Unverified Commit dcc49d8a authored by Billy Bradley, committed by GitHub


In assisted decoding, pass model_kwargs to model's forward call (fix prepare_input_for_generation in all models) (#25242)

* In assisted decoding, pass model_kwargs to model's forward call

Previously, assisted decoding would ignore any additional kwargs
that it did not explicitly handle. This was inconsistent with other
generation methods, which pass the model_kwargs through
prepare_inputs_for_generation and forward the returned dict to the
model's forward call.
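
For example (an illustrative sketch, not code from this change; `model`,
`assistant` and `inputs` are hypothetical names), extra model kwargs such
as token_type_ids returned by a tokenizer should now reach the main
model's forward pass during assisted decoding:

    # assumes `model` and `assistant` are causal LMs and `inputs` is a
    # tokenizer output that may contain token_type_ids
    outputs = model.generate(
        **inputs,                   # extra keys become model_kwargs
        assistant_model=assistant,  # enables assisted decoding
    )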

The prepare_inputs_for_generation method needs to be amended in all
models, as it previously kept only the last input ID when past_key_values
was passed.
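
The shared pattern, sketched on toy shapes (a standalone illustration of
the logic added below, assuming the usual cache layout in which
past_key_values[0][0].shape[2] is the cached sequence length):

    past_length = past_key_values[0][0].shape[2]
    if input_ids.shape[1] > past_length:
        # assisted decoding passes the whole candidate sequence:
        # drop only the prefix already covered by the cache
        remove_prefix_length = past_length
    else:
        # greedy/sampling already pass a single new token: keep the final ID
        remove_prefix_length = input_ids.shape[1] - 1
    input_ids = input_ids[:, remove_prefix_length:]

E.g. with past_length = 10 and a 13-token candidate sequence, the model now
receives the last 3 tokens instead of only the last one.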

* Improve variable names in _extend_attention_mask

* Refactor extending token_type_ids into a function

* Replace deepcopy with copy to optimize performance

* Update the new Persimmon model with the Llama changes for assisted generation

* Update the new Mistral model for assisted generation with prepare_inputs_for_generation

* Update position_ids creation in Falcon's prepare_inputs_for_generation to support assisted generation
parent 1e3c9dda
@@ -1297,6 +1297,43 @@ class GenerationMixin:
                 UserWarning,
             )

+    def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]:
+        if self.config.is_encoder_decoder:
+            key = "decoder_attention_mask"
+        else:
+            key = "attention_mask"
+
+        if key not in model_kwargs:
+            return model_kwargs
+
+        mask = model_kwargs[key]
+        mask_extension_length = new_mask_length - mask.shape[1]
+
+        if mask_extension_length < 0:
+            raise ValueError("Cannot extend attention mask to a length less than it already is")
+
+        model_kwargs[key] = torch.cat(
+            [mask, mask.new_ones((mask.shape[0], mask_extension_length))],
+            dim=-1,
+        )
+
+        return model_kwargs
+
+    def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
+        if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
+            return model_kwargs
+
+        token_type_ids = model_kwargs["token_type_ids"]
+        final_token_type = token_type_ids[:, -1].unsqueeze(-1)
+        extension_length = new_length - token_type_ids.shape[1]
+        token_type_copies = final_token_type.repeat(1, extension_length)
+        model_kwargs["token_type_ids"] = torch.cat(
+            [model_kwargs["token_type_ids"], token_type_copies],
+            dim=-1,
+        )
+
+        return model_kwargs
+
     @torch.no_grad()
     def generate(
         self,
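
Behaviour of the two new helpers on toy tensors (a standalone sketch, not
part of the diff):

    import torch

    # _extend_attention_mask pads the mask with ones up to the new length
    mask = torch.ones(1, 5)
    mask = torch.cat([mask, mask.new_ones((1, 3))], dim=-1)    # shape (1, 8)

    # _extend_token_type_ids repeats the final segment id for new positions
    token_type_ids = torch.tensor([[0, 0, 1]])
    pad = token_type_ids[:, -1].unsqueeze(-1).repeat(1, 2)
    token_type_ids = torch.cat([token_type_ids, pad], dim=-1)  # [[0, 0, 1, 1, 1]]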
@@ -4441,47 +4478,21 @@ class GenerationMixin:
             # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
             # we use this forward pass to also pick the subsequent logits in the original model.

-            # 2.1. Run a forward pass on the candidate sequence
-            if "past_key_values" in model_kwargs:
-                model_attn = torch.ones_like(candidate_input_ids)
-                model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
-                if self.config.is_encoder_decoder:
-                    outputs = self(
-                        decoder_input_ids=model_input_ids,
-                        decoder_attention_mask=model_attn,
-                        past_key_values=model_kwargs["past_key_values"],
-                        encoder_outputs=model_kwargs["encoder_outputs"],
-                        output_attentions=output_attentions,
-                        output_hidden_states=output_hidden_states,
-                        use_cache=True,
-                    )
-                else:
-                    outputs = self(
-                        model_input_ids,
-                        attention_mask=model_attn,
-                        past_key_values=model_kwargs["past_key_values"],
-                        output_attentions=output_attentions,
-                        output_hidden_states=output_hidden_states,
-                        use_cache=True,
-                    )
-            else:
-                if self.config.is_encoder_decoder:
-                    outputs = self(
-                        decoder_input_ids=candidate_input_ids,
-                        encoder_outputs=model_kwargs["encoder_outputs"],
-                        output_attentions=output_attentions,
-                        output_hidden_states=output_hidden_states,
-                        use_cache=True,
-                    )
-                else:
-                    outputs = self(
-                        candidate_input_ids,
-                        output_attentions=output_attentions,
-                        output_hidden_states=output_hidden_states,
-                        use_cache=True,
-                    )
+            # 2.1. Prepare the model inputs
+            candidate_kwargs = copy.copy(model_kwargs)
+            candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
+            candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+
+            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
+
+            # 2.2. Run a forward pass on the candidate sequence
+            outputs = self(
+                **model_inputs,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )

-            # 2.2. Process the new logits
+            # 2.3. Process the new logits
             new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
             if len(logits_processor) > 0:
                 for i in range(candidate_length):
...
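
A note on the copy.copy call above (the "Replace deepcopy with copy" item in
the commit message): a shallow copy suffices because the helpers only rebind
top-level keys of model_kwargs, so the cached tensors are never duplicated.
A minimal sketch of the distinction:

    import copy

    import torch

    model_kwargs = {"attention_mask": torch.ones(1, 5)}
    candidate_kwargs = copy.copy(model_kwargs)             # new dict, same tensors
    candidate_kwargs["attention_mask"] = torch.ones(1, 8)  # rebinding affects the copy only
    assert model_kwargs["attention_mask"].shape[1] == 5    # original left intact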
@@ -483,9 +483,18 @@ class BarkCausalModel(BarkPreTrainedModel):
         position_ids = kwargs.get("position_ids", None)

         if past_key_values is not None:
-            # only last token for inputs_ids if past is defined in kwargs
+            # Omit tokens covered by past_key_values
             seq_len = input_ids.shape[1]
-            input_ids = input_ids[:, [-1]]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

             # input_embeds have already been used and is not required anymore
             input_embeds = None
@@ -507,7 +516,7 @@ class BarkCausalModel(BarkPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]
         else:
             position_ids = None
...
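
The position_ids change follows the same idea: positions are rebuilt from
the attention mask and then cut to however many tokens the trimmed
input_ids still contains, rather than always to one. A toy trace
(illustrative values):

    import torch

    attention_mask = torch.tensor([[0, 1, 1, 1, 1]])
    position_ids = attention_mask.long().cumsum(-1) - 1  # [[-1, 0, 1, 2, 3]]
    position_ids.masked_fill_(attention_mask == 0, 1)    # [[1, 0, 1, 2, 3]]
    position_ids = position_ids[:, -3:]                  # [[1, 2, 3]] for 3 new tokens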
@@ -1443,7 +1443,16 @@ class BartForConditionalGeneration(BartPreTrainedModel):
     ):
         # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -1934,7 +1943,16 @@ class BartForCausalLM(BartPreTrainedModel):
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
...
@@ -1282,7 +1282,16 @@ class BertLMHeadModel(BertPreTrainedModel):
         # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "input_ids": input_ids,
...
@@ -993,9 +993,18 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...
@@ -2628,9 +2628,18 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...
@@ -2627,7 +2627,16 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
     ):
         # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
...
@@ -729,9 +729,18 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs
     ):
-        # only last token for inputs_ids if past is defined in kwargs
-        if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+        # only last tokens for inputs_ids if past is defined in kwargs
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
...
@@ -1392,7 +1392,16 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -1622,7 +1631,16 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
...
@@ -1359,7 +1359,16 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -1589,7 +1598,16 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
...
@@ -920,7 +920,16 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):
         # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "input_ids": input_ids,
...
@@ -844,9 +844,18 @@ class BloomForCausalLM(BloomPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
-        # only last token for input_ids if past is not None
-        if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+        # only last tokens for input_ids if past is not None
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

             # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
...
@@ -1542,9 +1542,18 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...
@@ -617,11 +617,20 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
-        # only last token for inputs_ids if past is defined in kwargs
+        # Omit tokens covered by past_key_values
         if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
             if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

         attention_mask = kwargs.get("attention_mask", None)
         position_ids = kwargs.get("position_ids", None)
@@ -631,7 +640,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         return {
             "input_ids": input_ids,
...
@@ -526,9 +526,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         self.lm_head = new_embeddings

     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
-        # only last token for inputs_ids if past is defined in kwargs
-        if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
+        # only last tokens for inputs_ids if past is defined in kwargs
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
...
@@ -1009,9 +1009,18 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...
@@ -843,8 +843,17 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
@@ -852,7 +861,7 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
...
@@ -1667,9 +1667,18 @@ class ElectraForCausalLM(ElectraPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
...
@@ -1223,7 +1223,16 @@ class ErnieForCausalLM(ErniePreTrainedModel):
         # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "input_ids": input_ids,
...
@@ -1228,7 +1228,16 @@ class FalconForCausalLM(FalconPreTrainedModel):
         **kwargs,
     ) -> dict:
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
         if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
@@ -1236,7 +1245,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         return {
             "input_ids": input_ids,
...