Unverified Commit 7628b3a0 authored by Joao Gante, committed by GitHub

Idefics: generate fix (#29320)

parent 2ce56d35
@@ -19,7 +19,7 @@
 # limitations under the License.
 """ PyTorch Idefics model."""
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -187,35 +187,6 @@ def expand_inputs_for_generation(
     return input_ids, model_kwargs
 
 
-def update_model_kwargs_for_generation(outputs, model_kwargs):
-    # must have this key set to at least None
-    if "past_key_values" in outputs:
-        model_kwargs["past_key_values"] = outputs.past_key_values
-    else:
-        model_kwargs["past_key_values"] = None
-
-    # update token_type_ids with last value
-    if "token_type_ids" in model_kwargs:
-        token_type_ids = model_kwargs["token_type_ids"]
-        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
-
-    # update attention masks
-    if "attention_mask" in model_kwargs:
-        attention_mask = model_kwargs["attention_mask"]
-        model_kwargs["attention_mask"] = torch.cat(
-            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
-        )
-    if "image_attention_mask" in model_kwargs:
-        image_attention_mask = model_kwargs["image_attention_mask"]
-        last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
-        model_kwargs["image_attention_mask"] = last_mask
-
-    # Get the precomputed image_hidden_states
-    model_kwargs["image_hidden_states"] = outputs.image_hidden_states
-    return model_kwargs
-
-
 def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
     token_type_ids = kwargs.get("token_type_ids", None)
     # only last token for inputs_ids if past is defined in kwargs
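
The helper deleted above duplicated bookkeeping that `GenerationMixin._update_model_kwargs_for_generation` already performs: carrying `past_key_values` forward and extending `token_type_ids` and `attention_mask` by one position per decoding step. Below is a minimal sketch of that per-step mask growth, using only `torch` and hypothetical shapes:

```python
# Minimal sketch (dummy shapes, torch only) of the per-step attention_mask
# growth the removed helper performed; after this change the same update is
# delegated to GenerationMixin's default implementation.
import torch

attention_mask = torch.ones(2, 4, dtype=torch.long)  # batch of 2, 4 tokens decoded so far

# One decoding step later, a column of ones is appended for the new token:
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
print(attention_mask.shape)  # torch.Size([2, 5])
```

Keeping a model-local copy of this logic is presumably what allowed the two code paths to drift apart as the base implementation evolved; delegating to the base class removes the duplication.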
@@ -1580,9 +1551,26 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
     ):
         return expand_inputs_for_generation(*args, **model_kwargs)
 
-    @staticmethod
-    def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder):
-        return update_model_kwargs_for_generation(outputs, model_kwargs)
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
+        model_inputs: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs, model_kwargs, is_encoder_decoder, standardize_cache_format, model_inputs
+        )
+
+        if "image_attention_mask" in model_kwargs:
+            image_attention_mask = model_kwargs["image_attention_mask"]
+            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
+            model_kwargs["image_attention_mask"] = last_mask
+
+        # Get the precomputed image_hidden_states
+        model_kwargs["image_hidden_states"] = outputs.image_hidden_states
+        return model_kwargs
 
     @staticmethod
     def _reorder_cache(past, beam_idx):
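
The replacement override first defers to the parent class for the standard updates, then applies the two Idefics-specific ones: `image_attention_mask` is reduced to the mask of the last text position, which is the only one consumed once decoding proceeds token by token, and the precomputed `image_hidden_states` are carried into the next step. A self-contained sketch of the slicing behavior, with hypothetical dimensions:

```python
# Hedged sketch of the image_attention_mask update in the new override.
# Dimensions are hypothetical: (batch, text_seq_len, num_images).
import torch

image_attention_mask = torch.ones(2, 5, 3)

# Keep only the mask for the last text position; decoding proceeds one
# token at a time, so earlier positions are no longer consumed:
last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
print(last_mask.shape)  # torch.Size([2, 1, 3])
```

In practice this path is exercised by any `IdeficsForVisionText2Text.generate()` call that passes images; each decoding step after the first should reuse the cached `image_hidden_states` together with the one-position mask.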