Unverified Commit 7628b3a0 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Idefics: generate fix (#29320)

parent 2ce56d35
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
# limitations under the License. # limitations under the License.
""" PyTorch Idefics model.""" """ PyTorch Idefics model."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -187,35 +187,6 @@ def expand_inputs_for_generation( ...@@ -187,35 +187,6 @@ def expand_inputs_for_generation(
return input_ids, model_kwargs return input_ids, model_kwargs
def update_model_kwargs_for_generation(outputs, model_kwargs):
    """Refresh `model_kwargs` after one autoregressive decoding step.

    Carries the cache forward, extends the per-token masks by one position,
    shrinks the image attention mask to the newest token, and reuses the
    image features computed on the first forward pass.

    Args:
        outputs: the model output of the step just executed; `past_key_values`
            (if present) and `image_hidden_states` are read from it.
        model_kwargs (dict): keyword arguments for the next forward call,
            updated in place and also returned.

    Returns:
        dict: the updated `model_kwargs`.
    """
    # Downstream code expects the key to exist, even if only as None.
    model_kwargs["past_key_values"] = (
        outputs.past_key_values if "past_key_values" in outputs else None
    )

    # Extend token_type_ids by repeating the last column once.
    if "token_type_ids" in model_kwargs:
        tti = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat((tti, tti[:, -1:]), dim=-1)

    # Append a column of ones: the freshly generated token is attended to.
    if "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        new_col = mask.new_ones((mask.shape[0], 1))
        model_kwargs["attention_mask"] = torch.cat((mask, new_col), dim=-1)

    # Only the image-attention row for the most recent text token is needed.
    if "image_attention_mask" in model_kwargs:
        model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"][:, -1:, :]

    # Image features were precomputed; reuse them on subsequent steps.
    model_kwargs["image_hidden_states"] = outputs.image_hidden_states
    return model_kwargs
def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs): def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
...@@ -1580,9 +1551,26 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel): ...@@ -1580,9 +1551,26 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
): ):
return expand_inputs_for_generation(*args, **model_kwargs) return expand_inputs_for_generation(*args, **model_kwargs)
@staticmethod def _update_model_kwargs_for_generation(
def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder): self,
return update_model_kwargs_for_generation(outputs, model_kwargs) outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder, standardize_cache_format, model_inputs
)
if "image_attention_mask" in model_kwargs:
image_attention_mask = model_kwargs["image_attention_mask"]
last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
model_kwargs["image_attention_mask"] = last_mask
# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
return model_kwargs
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past, beam_idx):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment