Unverified Commit 12eb528b authored by Arthur, committed by GitHub

[CI] Remove `past` in favor of `past_key_values` (#21443)

* rename `past` to `past_key_values`

* update more `past` occurrences that were skipped

* fixup

* remove changes made to rag

* refactor `_reorder_cache` to use `past_key_values`

* fix the GIT model's `prepare_inputs_for_generation` so tests pass when `use_cache=False` is needed
parent 5b493762
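
All of the hunks below apply the same mechanical rename, so it is worth stating once what `_reorder_cache` does: during beam search, generation keeps one key/value cache entry per beam, and whenever beams are reordered the cached tensors must be gathered along their batch/beam dimension with the surviving beam indices. The following is a minimal, self-contained sketch of that pattern with made-up shapes; apart from `_reorder_cache`, `past_key_values`, and `beam_idx`, every name below is hypothetical and not taken from the diff.

import torch

# Hypothetical cache layout: one (key, value) pair per layer, each tensor of
# shape (num_beams, num_heads, seq_len, head_dim).
num_layers, beams, heads, seq, dim = 2, 4, 8, 5, 16
past_key_values = tuple(
    (torch.randn(beams, heads, seq, dim), torch.randn(beams, heads, seq, dim))
    for _ in range(num_layers)
)

# beam_idx[i] says which old beam the new beam i continues from.
beam_idx = torch.tensor([2, 2, 0, 1])

def _reorder_cache(past_key_values, beam_idx):
    # Same pattern as the hunks below: gather the cached states along the
    # batch/beam dimension (dim 0) using the surviving beam indices.
    reordered_past = ()
    for layer_past in past_key_values:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),
        )
    return reordered_past

reordered = _reorder_cache(past_key_values, beam_idx)
# New beam 0 now carries the cache of old beam 2.
assert torch.equal(reordered[0][0][0], past_key_values[0][0][2])
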
@@ -1148,9 +1148,9 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
@@ -1022,9 +1022,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1029,9 +1029,9 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1575,9 +1575,9 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
         }

     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache
-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1187,9 +1187,9 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
@@ -583,9 +583,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

     def prepare_inputs_for_generation(
-        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
         input_dict = {
             "attention_mask": attention_mask,
@@ -603,6 +603,6 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
             " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
         )

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)
+        return self.decoder._reorder_cache(past_key_values, beam_idx)
@@ -1418,8 +1418,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
         }

     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -967,8 +967,8 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
         }

     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -2368,7 +2368,7 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
@@ -2378,12 +2378,12 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
         **kwargs,
     ):
         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]

         return {
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -2393,9 +2393,9 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
         }

     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1774,15 +1774,15 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return self._shift_right(labels)

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         # if decoder past is not included in output
         # speedy decoding is disabled and no need to reorder
-        if past is None:
+        if past_key_values is None:
             logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past
+            return past_key_values
         reordered_decoder_past = ()
-        for layer_past_states in past:
+        for layer_past_states in past_key_values:
             # get the correct batch idx from layer past batch dim
             # batch dim of `past` is at 2nd position
             reordered_layer_past_states = ()
@@ -1773,15 +1773,15 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return self._shift_right(labels)

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         # if decoder past is not included in output
         # speedy decoding is disabled and no need to reorder
-        if past is None:
+        if past_key_values is None:
             logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past
+            return past_key_values
         reordered_decoder_past = ()
-        for layer_past_states in past:
+        for layer_past_states in past_key_values:
             # get the correct batch idx from layer past batch dim
             # batch dim of `past` is at 2nd position
             reordered_layer_past_states = ()
@@ -1007,8 +1007,8 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
         }

     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -669,6 +669,6 @@ class VisionEncoderDecoderModel(PreTrainedModel):
             " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
         )

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)
+        return self.decoder._reorder_cache(past_key_values, beam_idx)
@@ -1396,8 +1396,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     #
     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -942,8 +942,8 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
         }

     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -2117,9 +2117,9 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
     @staticmethod
     # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
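
A note on the `layer_past[:2] + layer_past[2:]` slice used in the hunk above: in encoder-decoder models each layer's cache tuple is (self-attention key, self-attention value, cross-attention key, cross-attention value), and the cross-attention states come from the encoder output, which is identical for every beam, so only the first two entries need reordering. A toy illustration with hypothetical tensors:

import torch

beam_idx = torch.tensor([1, 0])
# Hypothetical 4-tuple: (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value).
layer_past = tuple(torch.randn(2, 3) for _ in range(4))

reordered = tuple(p.index_select(0, beam_idx) for p in layer_past[:2]) + layer_past[2:]

# Self-attention states were gathered; cross-attention states pass through untouched.
assert torch.equal(reordered[2], layer_past[2])
assert torch.equal(reordered[0], layer_past[0].index_select(0, beam_idx))
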
@@ -2364,9 +2364,9 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
     @staticmethod
     # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1026,9 +1026,9 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -988,9 +988,9 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -1180,9 +1180,9 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],)
         return reordered_past
@@ -2905,9 +2905,9 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
         }

     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -3344,9 +3344,9 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
         }

     @staticmethod
-    def _reorder_cache(past, beam_idx):
+    def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
 {% endif -%}
@@ -180,7 +180,7 @@ class TFGPT2ModelTester:
         self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
         self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

-        output, past = outputs.to_tuple()
+        output, past_key_values = outputs.to_tuple()

         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -191,7 +191,9 @@ class TFGPT2ModelTester:
         next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1)

         output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
-        output_from_past = model(next_tokens, token_type_ids=next_token_types, past=past)["last_hidden_state"]
+        output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past_key_values)[
+            "last_hidden_state"
+        ]

         # select random slice
         random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1]))
@@ -213,7 +215,7 @@ class TFGPT2ModelTester:
         attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)

         # first forward pass
-        output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
+        output, past_key_values = model(input_ids, attention_mask=attn_mask).to_tuple()

         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -233,7 +235,9 @@ class TFGPT2ModelTester:
         # get two different outputs
         output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
-        output_from_past = model(next_tokens, past=past, attention_mask=attn_mask)["last_hidden_state"]
+        output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
+            "last_hidden_state"
+        ]

         # select random slice
         random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1]))
@@ -256,7 +260,7 @@ class TFGPT2ModelTester:
         # first forward pass
         outputs = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, use_cache=True)

-        output, past = outputs.to_tuple()
+        output, past_key_values = outputs.to_tuple()

         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@@ -272,7 +276,10 @@ class TFGPT2ModelTester:
             next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask
         )["last_hidden_state"]
         output_from_past = model(
-            next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past=past
+            next_tokens,
+            token_type_ids=next_token_types,
+            attention_mask=next_attention_mask,
+            past_key_values=past_key_values,
         )["last_hidden_state"]

         self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])
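
The renamed tester methods above all follow the same equivalence check: a full forward pass without a cache must produce the same last-token hidden state as a pass over only the new token plus the cached states. A rough PyTorch sketch of the same idea, assuming a transformers version in which GPT-2 accepts `past_key_values`; the tiny randomly initialized config is purely for illustration, not part of this PR.

import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny random model so the sketch runs quickly and offline.
config = GPT2Config(n_layer=2, n_head=2, n_embd=32, vocab_size=100)
model = GPT2LMHeadModel(config).eval()

input_ids = torch.randint(0, 100, (1, 5))
next_token = torch.randint(0, 100, (1, 1))

with torch.no_grad():
    # First pass: request the cache.
    past_key_values = model(input_ids, use_cache=True).past_key_values

    # No-cache reference: full sequence in one go.
    full = model(torch.cat([input_ids, next_token], dim=-1)).logits[:, -1]

    # Cached pass: only the new token, plus the cached states.
    cached = model(next_token, past_key_values=past_key_values).logits[:, -1]

torch.testing.assert_close(full, cached, rtol=1e-4, atol=1e-4)

If the cache bookkeeping (or a rename like the one in this PR) were wrong, the two logits would diverge and the assertion would fail, which is exactly what the TFGPT2ModelTester checks guard against.
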