Unverified Commit e6c9053f authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Clean up `scatter_patch_features` (#15559)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 43ed4143
...@@ -30,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -30,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
...@@ -60,7 +59,7 @@ class Gemma3ImagePixelInputs(TypedDict): ...@@ -60,7 +59,7 @@ class Gemma3ImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_images, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
...@@ -593,6 +592,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -593,6 +592,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
pixel_values = flatten_bn(pixel_values, concat=True) pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True) num_crops = flatten_bn(num_crops, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return Gemma3ImagePixelInputs( return Gemma3ImagePixelInputs(
type="pixel_values", type="pixel_values",
...@@ -635,14 +635,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -635,14 +635,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
image_features = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
if kwargs.get("v0_path", False): return scatter_patch_features(
return image_features image_features,
image_input["embed_is_patch"],
return flatten_2d_lists( )
scatter_patch_features(*args) for args in zip(
image_features,
image_input["embed_is_patch"],
))
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -671,7 +667,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -671,7 +667,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
......
...@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
...@@ -66,13 +65,13 @@ class InternVLImagePixelInputs(TypedDict): ...@@ -66,13 +65,13 @@ class InternVLImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_images, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
class InternVLImageEmbeddingInputs(TypedDict): class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
data: NestedTensors data: Union[torch.Tensor, list[torch.Tensor]]
""" """
A tensor of shape `(num_images, total_image_feature_size, hidden_size)` A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)` or a list of tensors of shape `(total_image_feature_size, hidden_size)`
...@@ -867,6 +866,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -867,6 +866,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return InternVLImagePixelInputs( return InternVLImagePixelInputs(
type="pixel_values", type="pixel_values",
...@@ -881,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -881,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def _process_image_input( def _process_image_input(
self, self,
image_input: InternVLImageInputs, image_input: InternVLImageInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
...@@ -921,15 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -921,15 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
image_features = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
if (kwargs.get("v0_path", False) if image_input["type"] != "pixel_values":
or image_input["type"] != "pixel_values"):
return image_features return image_features
return flatten_2d_lists( return scatter_patch_features(
scatter_patch_features(*args) for args in zip( image_features,
image_features, image_input["embed_is_patch"],
image_input["embed_is_patch"], )
))
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -964,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -964,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings) vision_embeddings)
......
...@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...@@ -73,7 +72,7 @@ class PixtralHFImagePixelInputs(TypedDict): ...@@ -73,7 +72,7 @@ class PixtralHFImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_images, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
...@@ -618,6 +617,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -618,6 +617,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. " raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}") f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralHFImagePixelInputs( return PixtralHFImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
...@@ -713,18 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -713,18 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None: if image_input is None:
return None return None
vision_embeddings = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
if (kwargs.get("v0_path", False) if image_input["type"] != "pixel_values_pixtral":
or image_input["type"] != "pixel_values_pixtral"):
# The path is used for pixtral (V0 only) and llava (V0/V1) # The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings return image_features
return flatten_2d_lists( return scatter_patch_features(
scatter_patch_features(*args) for args in zip( image_features,
vision_embeddings, image_input["embed_is_patch"],
image_input["embed_is_patch"], )
))
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -790,7 +789,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -790,7 +789,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings) vision_embeddings)
......
...@@ -49,7 +49,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -49,7 +49,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion, PromptUpdate) PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant) SupportsMultiModal, SupportsPP, SupportsQuant)
...@@ -72,17 +71,17 @@ POOLING_SIZE = 2 ...@@ -72,17 +71,17 @@ POOLING_SIZE = 2
class MolmoImageInputs(TypedDict): class MolmoImageInputs(TypedDict):
images: Union[torch.Tensor, list[torch.Tensor]] images: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`""" """Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]] image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
"""Shape: `(batch_size, num_crops, num_patch)`""" """Shape: `(batch_size * num_images, num_crops, num_patch)`"""
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]] feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
A boolean mask indicating which image features correspond A boolean mask indicating which image features correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_crops, num_patch)` Shape: `(batch_size * num_images, num_crops, num_patch)`
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
...@@ -90,7 +89,7 @@ class MolmoImageInputs(TypedDict): ...@@ -90,7 +89,7 @@ class MolmoImageInputs(TypedDict):
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
num_crops: Union[torch.Tensor, list[torch.Tensor]] num_crops: Union[torch.Tensor, list[torch.Tensor]]
...@@ -696,9 +695,10 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): ...@@ -696,9 +695,10 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
return image_features return image_features
def forward( def forward(
self, images: torch.Tensor, image_masks: torch.Tensor self,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: images: torch.Tensor,
image_masks: torch.Tensor,
) -> torch.Tensor:
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501 # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
batch_size, num_image = images.shape[:2] batch_size, num_image = images.shape[:2]
images = images.to(device=self.device, dtype=self.dtype) images = images.to(device=self.device, dtype=self.dtype)
...@@ -1491,6 +1491,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1491,6 +1491,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
f"Got type: {type(img_patch_id)}") f"Got type: {type(img_patch_id)}")
self.img_patch_id = img_patch_id.flatten().unique().item() self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch)
return MolmoImageInputs( return MolmoImageInputs(
images=images, images=images,
image_masks=image_masks, image_masks=image_masks,
...@@ -1502,13 +1504,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1502,13 +1504,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def _process_image_input( def _process_image_input(
self, self,
image_input: MolmoImageInputs, image_input: MolmoImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor]]: ) -> list[torch.Tensor]:
if isinstance(image_input["images"], list): images = image_input["images"]
image_masks = image_input["image_masks"]
feat_is_patch = image_input["feat_is_patch"]
num_crops = image_input["num_crops"]
if isinstance(images, list):
# Call the vision backbone on the whole batch at once # Call the vision backbone on the whole batch at once
images_flat = flatten_bn(image_input["images"], concat=True) images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if (image_masks := image_masks_flat = (None if image_masks is None else flatten_bn(
image_input["image_masks"]) is None image_masks, concat=True))
else flatten_bn(image_masks, concat=True))
image_features_flat = self.vision_backbone( image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0), images=images_flat.unsqueeze(0),
...@@ -1517,63 +1523,19 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1517,63 +1523,19 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
).squeeze(0) ).squeeze(0)
# Reconstruct the batch dimension # Reconstruct the batch dimension
image_features = image_features_flat.split( num_crops_per_image = [nc.sum().item() for nc in num_crops]
image_input["num_crops"].sum(-1).tolist()) image_features = image_features_flat.split(num_crops_per_image)
else: else:
image_features = self.vision_backbone( image_features = self.vision_backbone(
images=image_input["images"], images=images,
image_masks=image_input["image_masks"], image_masks=image_masks,
) )
return image_features # Only the features corresponding to patch tokens are relevant
return [
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
num_crops: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Note:
The original code only considers patch tokens as feature
tokens, but our processor considers all image-related tokens
as feature tokens because the feature tokens need to be
consecutive in `input_ids`.
Example:
A simplified example for one item in the batch:
.. code-block::
Embedding tokens (from HF processor):
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
embed_is_patch (from HF processor):
[ False True True False True True False False ]
Encoder outputs (from model):
[ p1 p2 0 p3 p4 0 ]
feat_is_patch (from HF processor):
[ True True False True True False ]
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
num_crops_per_image = num_crops.tolist()
feats_per_image = features.split(num_crops_per_image)
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
features = torch.cat([
feats[f_is_patch] feats[f_is_patch]
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image) for feats, f_is_patch in zip(image_features, feat_is_patch)
]) ]
return scatter_patch_features(features, embed_is_patch)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
...@@ -1583,13 +1545,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1583,13 +1545,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_features = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
return flatten_2d_lists( return scatter_patch_features(
self._get_mm_embeds(*args) for args in zip( image_features,
image_features, image_input["embed_is_patch"],
image_input["feat_is_patch"], )
image_input["num_crops"],
image_input["embed_is_patch"],
))
def get_input_embeddings( def get_input_embeddings(
self, self,
......
...@@ -42,7 +42,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs ...@@ -42,7 +42,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer, from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config) cached_tokenizer_from_config)
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
...@@ -74,7 +73,7 @@ class PixtralImagePixelInputs(TypedDict): ...@@ -74,7 +73,7 @@ class PixtralImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_images, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
...@@ -387,6 +386,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -387,6 +386,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of embed_is_patch. " raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}") f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralImagePixelInputs( return PixtralImagePixelInputs(
type="pixel_values", type="pixel_values",
images=flatten_bn(images), images=flatten_bn(images),
...@@ -428,14 +429,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -428,14 +429,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
image_features = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
if kwargs.get("v0_path", False): return scatter_patch_features(
return image_features image_features,
image_input["embed_is_patch"],
return flatten_2d_lists( )
scatter_patch_features(*args) for args in zip(
image_features,
image_input["embed_is_patch"],
))
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -467,7 +464,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -467,7 +464,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings) vision_embeddings)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
import torch import torch
...@@ -154,8 +155,8 @@ def resolve_visual_encoder_outputs( ...@@ -154,8 +155,8 @@ def resolve_visual_encoder_outputs(
def scatter_patch_features( def scatter_patch_features(
features: torch.Tensor, patches: Union[torch.Tensor, Sequence[torch.Tensor]],
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]], embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]],
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
""" """
Scatter the patch features into a contiguous tensor that corresponds Scatter the patch features into a contiguous tensor that corresponds
...@@ -165,8 +166,8 @@ def scatter_patch_features( ...@@ -165,8 +166,8 @@ def scatter_patch_features(
can be filtered out by :func`select_patch_features`. can be filtered out by :func`select_patch_features`.
Args: Args:
features: The patch features, concatenated across each image. patches: The patch features for each image.
Shape: `(num_patch, feature_depth)` Shape: `(num_images, <patch_dims>, feature_depth)`
embed_is_patch: A boolean mask indicating which image embeddings embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image. correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)` Shape: `(num_images, num_embeds)`
...@@ -194,21 +195,21 @@ def scatter_patch_features( ...@@ -194,21 +195,21 @@ def scatter_patch_features(
The resulting embedding tensor is: The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ] [ nan p1 p2 nan p3 p4 nan nan ]
""" """
num_embeds_per_image = [ if len(patches) != len(embed_is_patch):
e_is_patch.numel() for e_is_patch in embed_is_patch raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. "
] f"{len(embed_is_patch)=}")
if isinstance(embed_is_patch, torch.Tensor):
embed_is_patch_flat = embed_is_patch.view(-1) def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor):
else: embed_one = patches_one.new_full(
embed_is_patch_flat = torch.cat(embed_is_patch) (e_is_patch.shape[0], patches_one.shape[-1]),
fill_value=torch.nan,
embeds_flat = features.new_full( )
(sum(num_embeds_per_image), features.shape[-1]), embed_one[e_is_patch] = patches_one.flatten(0, -2)
fill_value=torch.nan, return embed_one
)
embeds_flat[embed_is_patch_flat] = features.flatten(0, -2) return tuple(
get_embed_one(patches_one, e_is_patch)
return embeds_flat.split(num_embeds_per_image) for patches_one, e_is_patch in zip(patches, embed_is_patch))
def select_patch_features( def select_patch_features(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment