Unverified Commit dcc49d8a authored by Billy Bradley, committed by GitHub

In assisted decoding, pass model_kwargs to model's forward call (fix prepare_inputs_for_generation in all models) (#25242)

* In assisted decoding, pass model_kwargs to model's forward call

Previously, assisted decoding would ignore any additional kwargs that it
did not explicitly handle. This was inconsistent with other generation
methods, which pass the model_kwargs through
prepare_inputs_for_generation and forward the returned dict to the
model's forward call.

The prepare_inputs_for_generation method needs to be amended in all
models, as previously it kept only the last input ID when past_key_values
was passed (a minimal sketch of the new trimming logic follows the commit
notes below).

* Improve variable names in _extend_attention_mask

* Refactor extending token_type_ids into a function

* Replace deepcopy with copy to optimize performance

* Update new persimmon model with llama changes for assisted generation

* Update new mistral model for assisted generation with prepare_inputs_for_generation

* Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation
parent 1e3c9dda
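
To make the change concrete, here is a minimal, self-contained sketch of the trimming logic that the diff below repeats in every model. The tensor values are invented for illustration; in the real code `past_length` is read from the key/value cache shapes.

```python
import torch

# 3 tokens already covered by the KV cache + 2 assistant-proposed candidates
input_ids = torch.tensor([[11, 22, 33, 44, 55]])
past_length = 3  # in the models below: past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
    remove_prefix_length = past_length
else:
    # Default to old behavior: keep only final ID
    remove_prefix_length = input_ids.shape[1] - 1

print(input_ids[:, remove_prefix_length:])  # tensor([[44, 55]])
```

The old `input_ids[:, -1:]` slice would have kept only the final ID (55) and silently dropped the other candidate token, which is exactly why assisted decoding needed this change.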
...@@ -1005,11 +1005,20 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
-        # only last token for inputs_ids if past is defined in kwargs
+        # Omit tokens covered by past_key_values
         if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
             if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

         attention_mask = kwargs.get("attention_mask", None)
         position_ids = kwargs.get("position_ids", None)
...@@ -1019,7 +1028,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]
         else:
             position_ids = None
...@@ -1038,6 +1047,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
                 "token_type_ids": token_type_ids,
             }
         )
+
         return model_inputs

     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
...@@ -1201,11 +1211,20 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
-        # only last token for inputs_ids if past is defined in kwargs
+        # Omit tokens covered by past_key_values
         if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
             if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

         attention_mask = kwargs.get("attention_mask", None)
         position_ids = kwargs.get("position_ids", None)
...@@ -1215,7 +1234,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]
         else:
             position_ids = None
......
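
The position_ids change follows the same idea: after the prefix cut, the cumulative-sum positions are trimmed to one position per surviving token instead of a single final position. A toy walk-through with invented values:

```python
import torch

attention_mask = torch.tensor([[0, 1, 1, 1, 1]])      # one left-padded row
position_ids = attention_mask.long().cumsum(-1) - 1   # [[-1, 0, 1, 2, 3]]
position_ids.masked_fill_(attention_mask == 0, 1)     # [[ 1, 0, 1, 2, 3]]

input_ids = torch.tensor([[44, 55]])                  # the two candidates kept above
position_ids = position_ids[:, -input_ids.shape[1] :]
print(position_ids)                                   # tensor([[2, 3]])
```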
...@@ -737,11 +737,23 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
-        # only last token for inputs_ids if past is defined in kwargs
+        # Omit tokens covered by past_key_values
         if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if self.config.multi_query:
+                past_length = past_key_values[0].shape[1]
+            else:
+                past_length = past_key_values[0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
             if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

         attention_mask = kwargs.get("attention_mask", None)
         position_ids = kwargs.get("position_ids", None)
...@@ -751,7 +763,7 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]
         else:
             position_ids = None
......
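
GPTBigCode is the one model above that probes a different cache axis per attention variant. A small sketch of why, assuming the usual layouts (multi-query attention stores one fused key/value tensor per layer with the sequence on axis 1, while the multi-head path keeps the sequence on axis 2; the shapes here are an assumption for illustration):

```python
import torch

batch, num_heads, seq_len, head_dim = 2, 4, 7, 64

# multi_query=True: a single fused (batch, seq_len, 2 * head_dim) tensor per layer
mqa_layer_cache = torch.zeros(batch, seq_len, 2 * head_dim)
# multi_query=False: the familiar (batch, num_heads, seq_len, ...) layout
mha_layer_cache = torch.zeros(batch, num_heads, seq_len, 2 * head_dim)

past_length_mqa = mqa_layer_cache.shape[1]  # 7, as in the multi_query branch
past_length_mha = mha_layer_cache.shape[2]  # 7, as in the else branch
```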
...@@ -680,11 +680,20 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
-        # only last token for inputs_ids if past is defined in kwargs
+        # Omit tokens covered by past_key_values
         if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
             if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

         attention_mask = kwargs.get("attention_mask", None)
         position_ids = kwargs.get("position_ids", None)
...@@ -694,7 +703,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
......
...@@ -808,10 +808,21 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
         input_shape = input_ids.shape

         # cut decoder_input_ids if past is used
-        if past_key_values and past_key_values[0] is not None:
-            input_ids = input_ids[:, -1:]
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
...@@ -819,7 +830,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
...@@ -830,7 +841,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}

         model_inputs.update(
             {
                 "attention_mask": attention_mask,
......
...@@ -785,11 +785,20 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
-        # only last token for inputs_ids if past is defined in kwargs
+        # Omit tokens covered by past_key_values
         if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
             if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

         attention_mask = kwargs.get("attention_mask", None)
         position_ids = kwargs.get("position_ids", None)
...@@ -799,7 +808,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
......
...@@ -912,11 +912,20 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[bool] = None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
-        # only last token for inputs_ids if past is defined in kwargs
+        # Omit tokens covered by past_key_values
         if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
             if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

         attention_mask = kwargs.get("attention_mask", None)
         position_ids = kwargs.get("position_ids", None)
...@@ -926,7 +935,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]
         else:
             position_ids = None

         return {
......
...@@ -1080,8 +1080,17 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
...@@ -1089,7 +1098,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
......
...@@ -2103,9 +2103,18 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
         encoder_outputs=None,
         **kwargs,
     ):
-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "decoder_input_ids": input_ids,
......
...@@ -1367,7 +1367,16 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
......
...@@ -1509,7 +1509,16 @@ class MarianMTModel(MarianPreTrainedModel):
     ) -> Dict:
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
...@@ -1740,7 +1749,16 @@ class MarianForCausalLM(MarianPreTrainedModel):
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
......
...@@ -948,7 +948,16 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
         # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "input_ids": input_ids,
......
...@@ -1413,7 +1413,16 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
...@@ -1897,7 +1906,16 @@ class MBartForCausalLM(MBartPreTrainedModel):
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
......
...@@ -1251,9 +1251,18 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
......
...@@ -1083,8 +1083,18 @@ class MistralForCausalLM(MistralPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
+        # Omit tokens covered by past_key_values
         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
...@@ -1092,7 +1102,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
......
...@@ -605,9 +605,18 @@ class MptForCausalLM(MptPreTrainedModel):
         use_cache: Optional[bool] = None,
         **kwargs,
     ) -> dict:
-        # only last token for input_ids if past is not None
-        if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+        # only last tokens for input_ids if past is not None
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
......
...@@ -1836,9 +1836,18 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
         encoder_outputs=None,
         **kwargs,
     ):
-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "decoder_input_ids": input_ids,
......
...@@ -1995,9 +1995,17 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
         if decoder_attention_mask is not None:
             decoder_attention_mask = decoder_attention_mask.repeat((2, 1))

+        # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
......
...@@ -1572,7 +1572,16 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
...@@ -2054,7 +2063,16 @@ class MvpForCausalLM(MvpPreTrainedModel):
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
......
...@@ -1808,7 +1808,16 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
......
...@@ -981,8 +981,17 @@ class OPTForCausalLM(OPTPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
......