Unverified commit 2e32f5d2, authored by B-201, committed by GitHub
Browse files

[Bugfix] Fix Idefics3 fails during multi-image inference (#11080)


Signed-off-by: B-201 <Joy25810@foxmail.com>
parent 61b1d2f6
...@@ -60,7 +60,8 @@ class Idefics3ImagePixelInputs(TypedDict): ...@@ -60,7 +60,8 @@ class Idefics3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: torch.Tensor data: torch.Tensor
""" """
Shape: `(batch_size * num_images, num_channels, height, width)` Shape: `(batch_size * num_images * num_patches,
num_channels, height, width)`
""" """
pixel_attention_mask: Optional[torch.BoolTensor] pixel_attention_mask: Optional[torch.BoolTensor]
...@@ -520,13 +521,17 @@ class Idefics3Model(nn.Module): ...@@ -520,13 +521,17 @@ class Idefics3Model(nn.Module):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
return Idefics3ImagePixelInputs(type="pixel_values", if isinstance(pixel_values, list):
data=self._validate_pixel_values( pixel_values = torch.cat(pixel_values, dim=1)
flatten_bn(pixel_values, pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1)
concat=True)), else:
pixel_attention_mask=flatten_bn( pixel_values = flatten_bn(pixel_values)
pixel_attention_mask, pixel_attention_mask = flatten_bn(pixel_attention_mask)
concat=True))
return Idefics3ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask)
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment