"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8608bf2049a10f8d23043e1bb196707a1c1b3fe5"
Unverified Commit dcc49d8a authored by Billy Bradley, committed by GitHub

In assisted decoding, pass model_kwargs to model's forward call (fix prepare_inputs_for_generation in all models) (#25242)

* In assisted decoding, pass model_kwargs to model's forward call

Previously, assisted decoding would ignore any additional kwargs
that it did not explicitly handle. This was inconsistent with other
generation methods, which pass the model_kwargs through
prepare_inputs_for_generation and forward the returned dict to the
model's forward call.
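
For orientation, that convention looks roughly like the sketch below. ToyModel is a hypothetical stand-in, not the transformers GenerationMixin API; it only shows the kwargs routing:

class ToyModel:
    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        # Pass every kwarg through so forward() receives all of them.
        return {"input_ids": input_ids, **model_kwargs}

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        return len(input_ids)  # stand-in for real logits

model = ToyModel()
model_kwargs = {"attention_mask": [1, 1, 1], "token_type_ids": [0, 0, 0]}
model_inputs = model.prepare_inputs_for_generation([7, 8, 9], **model_kwargs)
logits = model.forward(**model_inputs)  # extra kwargs reach the model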

The prepare_inputs_for_generation method needs to be amended in all
models, as previously it kept only the last input ID whenever
past_key_values was passed.
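
The amended pattern, repeated in every diff below, can be read in isolation as the following self-contained sketch. cut_input_ids is a hypothetical helper name used only for illustration; in the real files the logic sits inline in each prepare_inputs_for_generation:

import torch

def cut_input_ids(input_ids, past_key_values):
    if past_key_values is not None:
        # past_key_values[0][0] is the first layer's cached keys, shaped
        # (batch, heads, seq_len, head_dim): shape[2] is the cached length.
        past_length = past_key_values[0][0].shape[2]

        # Some generation methods already pass only the last input ID
        if input_ids.shape[1] > past_length:
            # Assisted decoding passes the cached prefix plus several
            # candidate tokens: strip exactly the cached prefix.
            remove_prefix_length = past_length
        else:
            # Default to old behavior: keep only the final ID
            remove_prefix_length = input_ids.shape[1] - 1

        input_ids = input_ids[:, remove_prefix_length:]
    return input_ids

# Toy check: a cache of length 5 and 8 input IDs -> keep the last 3.
cache = ((torch.zeros(1, 2, 5, 4), torch.zeros(1, 2, 5, 4)),)
ids = torch.arange(8).unsqueeze(0)
assert cut_input_ids(ids, cache).shape[1] == 3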

* Improve variable names in _extend_attention_mask

* Refactor extending token_type_ids into a function

* Replace deepcopy with copy to optimize performance

* Update new persimmon model with llama changes for assisted generation

* Update new mistral model for assisted generation with prepare_inputs_for_generation

* Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation
parent 1e3c9dda
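
Related to the last bullet: where a model builds position_ids from the attention mask, taking only the final position no longer works once several candidate tokens survive the cut. A hedged illustration with made-up tensors (the real change is in the Persimmon and XGLM hunks below):

import torch

attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1]])
input_ids = torch.tensor([[11, 12, 13]])  # e.g. 3 candidate tokens remain

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

old = position_ids[:, -1].unsqueeze(-1)      # one position only: shape (1, 1)
new = position_ids[:, -input_ids.shape[1]:]  # one per token: shape (1, 3)
assert old.shape == (1, 1) and new.shape == (1, 3)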
@@ -1466,7 +1466,16 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -1719,7 +1728,16 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
...
@@ -1671,7 +1671,16 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
...
@@ -847,8 +847,17 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
@@ -856,7 +865,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
...
@@ -1798,9 +1798,18 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
         if decoder_attention_mask is None:
             decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "flattened_patches": flattened_patches,
...
@@ -1379,7 +1379,16 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
     ) -> Dict[str, Any]:
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -1739,7 +1748,16 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
...
@@ -1151,9 +1151,18 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...
@@ -1147,9 +1147,18 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...
@@ -1007,9 +1007,18 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...
@@ -1014,9 +1014,18 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...
@@ -1560,9 +1560,18 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
             if input_shape_ids is not None:
                 input_shape_ids = input_shape_ids[:, -1:]
             if input_pronunciation_ids is not None:
...
@@ -1178,9 +1178,18 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...
@@ -963,7 +963,16 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
...
@@ -2508,7 +2508,16 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "encoder_outputs": encoder_outputs,
...
@@ -1727,9 +1727,18 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
         encoder_outputs=None,
         **kwargs,
     ):
-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "decoder_input_ids": input_ids,
...
@@ -1810,9 +1810,18 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         encoder_outputs=None,
         **kwargs,
     ):
-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "decoder_input_ids": input_ids,
...
@@ -1003,7 +1003,16 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
...
@@ -1307,9 +1307,18 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel):
         encoder_outputs=None,
         **kwargs,
     ):
-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "decoder_input_ids": input_ids,
...
@@ -1810,9 +1810,17 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         attention_mask=None,
         **kwargs,
     ):
-        # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "encoder_outputs": encoder_outputs,
...
@@ -851,21 +851,30 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
     ):
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
+
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]
         else:
             position_ids = None

         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
...
@@ -1011,9 +1011,18 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...