Unverified Commit bb103b29 authored by kYLe's avatar kYLe Committed by GitHub
Browse files

[Bugfix] Added `embed_is_patch` mask for fuyu model (#15731)


Signed-off-by: default avatarKyle Huang <kylhuang@nvidia.com>
parent 248e76c4
......@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -39,10 +39,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
......@@ -64,6 +66,11 @@ class FuyuImagePatchInputs(TypedDict):
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
"""
class FuyuProcessingInfo(BaseProcessingInfo):
......@@ -183,6 +190,19 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
processed_outputs["image_patches"] = image_patches[0]
# get patch grid size for each image
embed_is_patch = []
for image in images:
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image.width,
image_height=image.height,
)
mask = torch.tensor(([True] * ncols + [False]) * nrows)
embed_is_patch.append(mask)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
def _apply_hf_processor_tokens_only(
......@@ -202,7 +222,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"))
return dict(image_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates(
self,
......@@ -301,11 +322,15 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
image_patches = kwargs.pop("image_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
if image_patches is not None:
if not isinstance(image_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image patches. "
f"Got type: {type(image_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
image_patches_flat = flatten_bn(image_patches)
return FuyuImagePatchInputs(
......@@ -313,6 +338,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
flat_data=self._validate_pixel_values(
flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat],
embed_is_patch=embed_is_patch,
)
return None
......@@ -333,7 +359,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
#return vision_embeddings
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["embed_is_patch"],
))
def get_input_embeddings(
self,
......@@ -343,8 +374,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
_IMAGE_TOKEN_ID)
input_ids, inputs_embeds,
select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
return inputs_embeds
def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment