Unverified Commit 9515c208 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Clean up processing logic (#37541)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent c63ca2b2
......@@ -1221,49 +1221,33 @@ class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessing
num_videos: int,
overrides: VideoDummyOptions | None = None,
):
if overrides:
if overrides.num_frames:
if overrides.num_frames > num_frames:
logger.warning(
"video.num_frames override (%d) exceeds model's "
"maximum number of frames (%d), will be ignored",
overrides.num_frames,
num_frames,
)
num_frames = min(num_frames, overrides.num_frames)
if overrides.width:
if overrides.width > width:
logger.warning(
"video.width override (%d) exceeds model's "
"maximum width (%d), will be ignored",
overrides.width,
width,
)
width = min(width, overrides.width)
if overrides.height:
if overrides.height > height:
logger.warning(
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height,
height,
)
height = min(height, overrides.height)
num_frames = max(num_frames, 2) # ernie4.5-vl requires at least 2 frames
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
# ernie4.5-vl requires at least 2 frames
num_frames = max(num_frames, 2)
if overrides and overrides.num_frames:
overrides.num_frames = max(overrides.num_frames, 2)
videos = super()._get_dummy_videos(
width=width,
height=height,
num_frames=num_frames,
num_videos=num_videos,
overrides=overrides,
)
videos = [v.copy() for v in videos]
video_items = []
for i in range(num_videos):
for video in videos:
video_num_frames = video.shape[0]
video_metadata = {
"fps": 2.0,
"duration": num_frames / 2.0,
"total_num_frames": num_frames,
"frames_indices": [i for i in range(num_frames)],
"duration": video_num_frames / 2.0,
"total_num_frames": video_num_frames,
"frames_indices": list(range(video_num_frames)),
"video_backend": "opencv",
"do_sample_frames": False,
}
video_item = (video.copy(), video_metadata)
video_items.append(video_item)
video_items.append((video, video_metadata))
return video_items
......
......@@ -1206,49 +1206,32 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
num_videos: int,
overrides: VideoDummyOptions | None = None,
) -> list[VideoItem]:
if overrides:
if overrides.num_frames:
if overrides.num_frames > num_frames:
logger.warning(
"video.num_frames override (%d) exceeds model's "
"maximum number of frames (%d), will be ignored",
overrides.num_frames,
num_frames,
)
num_frames = min(num_frames, overrides.num_frames)
if overrides.width:
if overrides.width > width:
logger.warning(
"video.width override (%d) exceeds model's "
"maximum width (%d), will be ignored",
overrides.width,
width,
)
width = min(width, overrides.width)
if overrides.height:
if overrides.height > height:
logger.warning(
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height,
height,
)
height = min(height, overrides.height)
num_frames = max(num_frames, 2) # GLM 4.6V requires 2 frames
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
# GLM 4.6V requires at least 2 frames
num_frames = max(num_frames, 2)
if overrides and overrides.num_frames:
overrides.num_frames = max(overrides.num_frames, 2)
videos = super()._get_dummy_videos(
width=width,
height=height,
num_frames=num_frames,
num_videos=num_videos,
overrides=overrides,
)
videos = [v.copy() for v in videos]
video_items = []
for i in range(num_videos):
for video in videos:
video_num_frames = video.shape[0]
video_metadata = {
"fps": 2.0,
"duration": num_frames / 2.0,
"total_num_frames": num_frames,
"frames_indices": [i for i in range(num_frames)],
"duration": video_num_frames / 2.0,
"total_num_frames": video_num_frames,
"frames_indices": list(range(video_num_frames)),
"video_backend": "opencv",
"do_sample_frames": False,
}
video_item = (video.copy(), video_metadata)
video_items.append(video_item)
video_items.append((video, video_metadata))
return video_items
......
......@@ -8,14 +8,13 @@
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Mapping, Sequence
import torch
from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargsItems
from vllm.multimodal.inputs import BatchedTensorInputs
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
......@@ -25,7 +24,6 @@ from vllm.multimodal.processing.processor import (
MultiModalProcessingInfo,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
TimingContext,
)
from vllm.transformers_utils.processors.h2ovl import H2OVLImageProcessor, H2OVLProcessor
......@@ -86,15 +84,12 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]):
def _get_prompt_updates(
def _get_prompt_repl_image(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
out_mm_data = out_mm_kwargs.get_data()
hf_processor: H2OVLProcessor,
out_mm_data: BatchedTensorInputs,
):
if "image_num_patches" in out_mm_data:
image_num_patches = out_mm_data["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
......@@ -130,13 +125,11 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn
return hf_processor.get_image_repl(num_patches, num_features=feature_size)
return [
PromptReplacement(
return PromptReplacement(
modality="image",
target="<image>",
replacement=get_replacement_internvl,
)
]
def _cached_apply_hf_processor(
self,
......
......@@ -27,6 +27,7 @@ from vllm.model_executor.models.intern_vit import (
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
BatchedTensorInputs,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
......@@ -238,11 +239,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
def _get_image_fields_config(self, hf_inputs: BatchFeature):
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
num_images = len(image_num_patches)
......@@ -255,15 +252,19 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_updates(
def _get_mm_fields_config(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
) -> Mapping[str, MultiModalFieldConfig]:
return self._get_image_fields_config(hf_inputs)
out_mm_data = out_mm_kwargs.get_data()
def _get_prompt_repl_image(
self,
mm_items: MultiModalDataItems,
hf_processor: InternVLProcessor,
out_mm_data: BatchedTensorInputs,
):
if "image_num_patches" in out_mm_data:
image_num_patches = out_mm_data["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
......@@ -296,12 +297,23 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
return hf_processor.get_image_repl(num_patches, num_features=feature_size)
return [
PromptReplacement(
return PromptReplacement(
modality="image",
target="<image>",
replacement=get_replacement_internvl,
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
out_mm_data = out_mm_kwargs.get_data()
return [
self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data),
]
......@@ -455,44 +467,35 @@ class InternVLMultiModalProcessor(
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
if self.info.ctx_video_token:
def _get_video_fields_config(self, hf_inputs: BatchFeature):
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
num_videos = len(video_num_patches)
video_fields = dict(
return dict(
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches
),
video_num_patches=MultiModalFieldConfig.batched("video"),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
)
else:
video_fields = {}
return image_fields | video_fields
def _get_prompt_updates(
def _get_mm_fields_config(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
prompt_repl = super()._get_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
out_mm_kwargs=out_mm_kwargs,
)
if self.info.ctx_video_token is None:
return prompt_repl
) -> Mapping[str, MultiModalFieldConfig]:
fields = self._get_image_fields_config(hf_inputs)
if self.info.ctx_video_token:
fields |= self._get_video_fields_config(hf_inputs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
return fields
out_mm_data = out_mm_kwargs.get_data()
def _get_prompt_repl_video(
self,
mm_items: MultiModalDataItems,
hf_processor: InternVLProcessor,
out_mm_data: BatchedTensorInputs,
):
if "video_num_patches" in out_mm_data:
video_num_patches = out_mm_data["video_num_patches"]
assert isinstance(video_num_patches, torch.Tensor)
......@@ -507,14 +510,30 @@ class InternVLMultiModalProcessor(
return hf_processor.get_video_repl(num_patches)
return [
*prompt_repl,
PromptReplacement(
return PromptReplacement(
modality="video",
target="<video>",
replacement=get_video_replacement_internvl,
),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
out_mm_data = out_mm_kwargs.get_data()
prompt_repls = [
self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data),
]
if self.info.ctx_video_token is not None:
prompt_repls.append(
self._get_prompt_repl_video(mm_items, hf_processor, out_mm_data)
)
return prompt_repls
@MULTIMODAL_REGISTRY.register_processor(
......
......@@ -1913,22 +1913,32 @@ class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]):
height: int,
num_frames: int,
num_videos: int,
overrides: VideoDummyOptions | None = None,
) -> list[VideoItem]:
video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
videos = super()._get_dummy_videos(
width=width,
height=height,
num_frames=num_frames,
num_videos=num_videos,
overrides=overrides,
)
videos = [v.copy() for v in videos]
video_items = []
for i in range(num_videos):
for video in videos:
video_num_frames = video.shape[0]
video_metadata = {
"fps": 2.0,
"duration": num_frames / 2.0,
"total_num_frames": num_frames,
"frames_indices": list(range(num_frames)),
"duration": video_num_frames / 2.0,
"total_num_frames": video_num_frames,
"frames_indices": list(range(video_num_frames)),
"video_backend": "decord",
"do_sample_frames": False,
"height": height,
"width": width,
}
video_item = (video.copy(), video_metadata)
video_items.append(video_item)
video_items.append((video, video_metadata))
return video_items
......
......@@ -7,7 +7,7 @@
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Mapping, Sequence
from collections.abc import Mapping
import torch
import torch.nn as nn
......@@ -16,7 +16,10 @@ from transformers import PretrainedConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
from vllm.multimodal.inputs import (
BatchedTensorInputs,
MultiModalDataDict,
)
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
......@@ -24,7 +27,6 @@ from vllm.multimodal.parse import (
)
from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.transformers_utils.processors.internvl import InternVLImageProcessor
......@@ -100,15 +102,12 @@ class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo])
class NVLMMultiModalProcessor(BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):
def _get_prompt_updates(
def _get_prompt_repl_image(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
out_mm_data = out_mm_kwargs.get_data()
hf_processor: NVLMProcessor,
out_mm_data: BatchedTensorInputs,
):
if "image_num_patches" in out_mm_data:
image_num_patches = out_mm_data["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
......@@ -146,13 +145,11 @@ class NVLMMultiModalProcessor(BaseInternVLMultiModalProcessor[NVLMProcessingInfo
)
# See note in dummy data regarding why we have the extra newline
return [
PromptReplacement(
return PromptReplacement(
modality="image",
target="<image>\n",
replacement=get_replacement_nvlm,
)
]
@MULTIMODAL_REGISTRY.register_processor(
......
......@@ -931,20 +931,30 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
height: int,
num_frames: int,
num_videos: int,
overrides: VideoDummyOptions | None = None,
) -> list[VideoItem]:
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
videos = super()._get_dummy_videos(
width=width,
height=height,
num_frames=num_frames,
num_videos=num_videos,
overrides=overrides,
)
videos = [v.copy() for v in videos]
video_items = []
for i in range(num_videos):
for video in videos:
video_num_frames = video.shape[0]
video_metadata = {
"fps": 2.0,
"duration": num_frames / 2.0,
"total_num_frames": num_frames,
"frames_indices": [i for i in range(num_frames)],
"duration": video_num_frames / 2.0,
"total_num_frames": video_num_frames,
"frames_indices": list(range(video_num_frames)),
"video_backend": "opencv",
"do_sample_frames": False,
}
video_item = (video.copy(), video_metadata)
video_items.append(video_item)
video_items.append((video, video_metadata))
return video_items
......
......@@ -7,12 +7,12 @@
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Iterable, Mapping
from typing import Annotated, Literal, TypeAlias
import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
......@@ -24,24 +24,8 @@ from vllm.model_executor.models.intern_vit import (
InternVisionPatchModel,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.internvl import (
InternVLImageProcessor,
......@@ -50,6 +34,11 @@ from vllm.transformers_utils.processors.internvl import (
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .internvl import (
BaseInternVLDummyInputsBuilder,
BaseInternVLMultiModalProcessor,
BaseInternVLProcessingInfo,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
......@@ -98,7 +87,7 @@ SkyworkR1VImageInputs: TypeAlias = (
)
class SkyworkR1VProcessingInfo(BaseProcessingInfo):
class SkyworkR1VProcessingInfo(BaseInternVLProcessingInfo):
def get_image_processor(self, **kwargs):
config = self.get_hf_config()
vision_config = config.vision_config
......@@ -128,46 +117,6 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo):
image_seq_length=image_seq_length,
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: InternVLProcessor,
) -> int:
return processor.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
image_processor = processor.image_processor
base_size = image_processor.image_size
target_ratios = processor.resolve_target_ratios()
largest_feature_size, largest_feature_pinpoint = 0, None
for wr, hr in target_ratios:
width, height = base_size * wr, base_size * hr
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
processor=processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width, height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
......@@ -196,102 +145,10 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingIn
}
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[SkyworkR1VProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_token_id = hf_processor.ctx_image_token_id
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
# tokens to merge from the vision encoder outputs
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
num_images = len(image_num_patches)
return dict(
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches
),
image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
out_mm_data = out_mm_kwargs.get_data()
if "image_num_patches" in out_mm_data:
image_num_patches = out_mm_data["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
image_num_patches = image_num_patches.tolist()
elif "image_embeds" in out_mm_data:
# TODO: Use image size information in dictionary embedding inputs
# to compute num_patches (similar to Qwen2-VL)
image_num_patches = [None] * len(out_mm_data["image_embeds"])
else:
image_num_patches = []
def get_replacement_skyworkr1v(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems)
)
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
feature_size = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
num_patches = image_num_patches[item_idx]
if num_patches is not None:
assert isinstance(num_patches, int)
return hf_processor.get_image_repl(num_patches, num_features=feature_size)
return [
PromptReplacement(
modality="image",
target="<image>",
replacement=get_replacement_skyworkr1v,
)
]
@MULTIMODAL_REGISTRY.register_processor(
SkyworkR1VMultiModalProcessor,
BaseInternVLMultiModalProcessor,
info=SkyworkR1VProcessingInfo,
dummy_inputs=SkyworkR1VDummyInputsBuilder,
dummy_inputs=BaseInternVLDummyInputsBuilder,
)
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment