Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
...@@ -867,7 +867,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -867,7 +867,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None, size: Optional[dict[str, int]] = None,
**kwargs: object, **kwargs: object,
): ) -> Qwen2VLImageProcessor:
return cached_image_processor_from_config( return cached_image_processor_from_config(
self.ctx.model_config, self.ctx.model_config,
**self._get_image_processor_kwargs(min_pixels=min_pixels, **self._get_image_processor_kwargs(min_pixels=min_pixels,
...@@ -886,7 +886,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -886,7 +886,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
) -> Mapping[str, int]: ) -> Mapping[str, int]:
return { return {
"image": self.get_max_image_tokens(), "image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len), "video": self.get_max_video_tokens(seq_len, mm_counts),
} }
def _get_vision_info( def _get_vision_info(
...@@ -1002,10 +1002,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -1002,10 +1002,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
return num_frames return num_frames
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(
mm_config = self.ctx.get_mm_config() self,
max_images = mm_config.get_limit_per_prompt("image") seq_len: int,
max_videos = mm_config.get_limit_per_prompt("video") mm_counts: Mapping[str, int],
) -> int:
max_images = mm_counts.get("image", 0)
max_videos = mm_counts.get("video", 0)
max_image_tokens = self.get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len - max_total_frames = self._get_max_video_frames(seq_len -
...@@ -1015,13 +1018,18 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -1015,13 +1018,18 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
return max(max_frames_per_video, 1) return max(max_frames_per_video, 1)
def get_max_video_tokens(self, seq_len: int) -> int: def get_max_video_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens( return self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len), num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
image_processor=None, image_processor=None,
) )
...@@ -1043,7 +1051,7 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): ...@@ -1043,7 +1051,7 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = { mm_data = {
"image": "image":
......
...@@ -21,9 +21,10 @@ import torch.nn as nn ...@@ -21,9 +21,10 @@ import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_in_doc_build from vllm.utils import is_in_doc_build
from .interfaces import (has_inner_state, is_attention_free, is_hybrid, from .interfaces import (has_inner_state, has_noops, is_attention_free,
supports_cross_encoding, supports_multimodal, is_hybrid, supports_cross_encoding,
supports_pp, supports_transcription, supports_v0_only) supports_multimodal, supports_pp,
supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model from .interfaces_base import is_text_generation_model
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -34,6 +35,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -34,6 +35,7 @@ _TEXT_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
# baichuan-7b, upper case 'C' in the class name # baichuan-7b, upper case 'C' in the class name
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
# baichuan-13b, lower case 'c' in the class name # baichuan-13b, lower case 'c' in the class name
...@@ -44,7 +46,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -44,7 +46,7 @@ _TEXT_GENERATION_MODELS = {
"CohereForCausalLM": ("commandr", "CohereForCausalLM"), "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"), "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
...@@ -71,6 +73,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -71,6 +73,7 @@ _TEXT_GENERATION_MODELS = {
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
# For decapoda-research/llama-* # For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
...@@ -118,7 +121,7 @@ _EMBEDDING_MODELS = { ...@@ -118,7 +121,7 @@ _EMBEDDING_MODELS = {
"RobertaModel": ("roberta", "RobertaEmbeddingModel"), "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GritLM": ("gritlm", "GritLM"), "GritLM": ("gritlm", "GritLM"),
...@@ -160,6 +163,7 @@ _CROSS_ENCODER_MODELS = { ...@@ -160,6 +163,7 @@ _CROSS_ENCODER_MODELS = {
_MULTIMODAL_MODELS = { _MULTIMODAL_MODELS = {
# [Decoder-only] # [Decoder-only]
"AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
"AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"), # noqa: E501
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
...@@ -176,6 +180,7 @@ _MULTIMODAL_MODELS = { ...@@ -176,6 +180,7 @@ _MULTIMODAL_MODELS = {
"MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501
"MiniCPMO": ("minicpmo", "MiniCPMO"), "MiniCPMO": ("minicpmo", "MiniCPMO"),
"MiniCPMV": ("minicpmv", "MiniCPMV"), "MiniCPMV": ("minicpmv", "MiniCPMV"),
"Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
...@@ -190,6 +195,8 @@ _MULTIMODAL_MODELS = { ...@@ -190,6 +195,8 @@ _MULTIMODAL_MODELS = {
# [Encoder-decoder] # [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
} }
...@@ -200,8 +207,8 @@ _SPECULATIVE_DECODING_MODELS = { ...@@ -200,8 +207,8 @@ _SPECULATIVE_DECODING_MODELS = {
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
} }
_FALLBACK_MODEL = { _TRANSFORMERS_MODELS = {
"TransformersModel": ("transformers", "TransformersModel"), "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
} }
# yapf: enable # yapf: enable
...@@ -211,7 +218,7 @@ _VLLM_MODELS = { ...@@ -211,7 +218,7 @@ _VLLM_MODELS = {
**_CROSS_ENCODER_MODELS, **_CROSS_ENCODER_MODELS,
**_MULTIMODAL_MODELS, **_MULTIMODAL_MODELS,
**_SPECULATIVE_DECODING_MODELS, **_SPECULATIVE_DECODING_MODELS,
**_FALLBACK_MODEL, **_TRANSFORMERS_MODELS,
} }
# This variable is used as the args for subprocess.run(). We # This variable is used as the args for subprocess.run(). We
...@@ -234,6 +241,7 @@ class _ModelInfo: ...@@ -234,6 +241,7 @@ class _ModelInfo:
has_inner_state: bool has_inner_state: bool
is_attention_free: bool is_attention_free: bool
is_hybrid: bool is_hybrid: bool
has_noops: bool
supports_transcription: bool supports_transcription: bool
supports_v0_only: bool supports_v0_only: bool
...@@ -251,6 +259,7 @@ class _ModelInfo: ...@@ -251,6 +259,7 @@ class _ModelInfo:
is_hybrid=is_hybrid(model), is_hybrid=is_hybrid(model),
supports_transcription=supports_transcription(model), supports_transcription=supports_transcription(model),
supports_v0_only=supports_v0_only(model), supports_v0_only=supports_v0_only(model),
has_noops=has_noops(model),
) )
...@@ -423,9 +432,9 @@ class _ModelRegistry: ...@@ -423,9 +432,9 @@ class _ModelRegistry:
normalized_arch = list( normalized_arch = list(
filter(lambda model: model in self.models, architectures)) filter(lambda model: model in self.models, architectures))
# make sure Transformers fallback are put at the last # make sure Transformers backend is put at the last as a fallback
if len(normalized_arch) != len(architectures): if len(normalized_arch) != len(architectures):
normalized_arch.append("TransformersModel") normalized_arch.append("TransformersForCausalLM")
return normalized_arch return normalized_arch
def inspect_model_cls( def inspect_model_cls(
...@@ -510,6 +519,13 @@ class _ModelRegistry: ...@@ -510,6 +519,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_hybrid return model_cls.is_hybrid
def is_noops_model(
self,
architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.has_noops
def is_transcription_model( def is_transcription_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, List[str]],
......
...@@ -13,7 +13,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -13,7 +13,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
...@@ -203,6 +203,18 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, ...@@ -203,6 +203,18 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
_pooler: An instance of Pooler used for pooling operations. _pooler: An instance of Pooler used for pooling operations.
""" """
jina_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
'emb_ln': "embeddings.LayerNorm",
'layers': "layer",
'mixer.Wqkv': "attention.self.qkv_proj",
'mixer.out_proj': "attention.output.dense",
'norm1': "attention.output.LayerNorm",
'mlp.fc1': "intermediate.dense",
'mlp.fc2': "output.dense",
'norm2': "output.LayerNorm",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -219,8 +231,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, ...@@ -219,8 +231,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
self._pooler = CrossEncodingPooler(config, self.classifier) self._pooler = CrossEncodingPooler(config, self.classifier)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
bert_weights, task_weights = roberta_task_weights_filter(weights) bert_weights, task_weights = roberta_task_weights_filter(weights)
bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
self.roberta.load_weights(bert_weights) self.roberta.load_weights(bert_weights)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
......
...@@ -208,8 +208,10 @@ class SiglipMLP(nn.Module): ...@@ -208,8 +208,10 @@ class SiglipMLP(nn.Module):
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
# Special handling for BNB quantization # Special handling for BNB and torchao quantization
if quant_config and quant_config.get_name() == "bitsandbytes": if quant_config and quant_config.get_name() in [
"bitsandbytes", "torchao"
]:
quantizable = True quantizable = True
else: else:
# For other quantization, we require the hidden size to be a # For other quantization, we require the hidden size to be a
......
# SPDX-License-Identifier: Apache-2.0
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py
# --------------------------------------------------------
# SkyworkR1V
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchEncoding, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# Literal markers used when building the image placeholder text: the full
# replacement is IMG_START, then a run of IMG_CONTEXT tokens, then IMG_END.
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel mean / std handed to the normalization step of the image
# transform built in this module.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
class SkyworkR1VImagePixelInputs(TypedDict):
    """Typed layout of image inputs that are supplied as pixel values."""

    # Discriminator tag: always the literal "pixel_values" for this variant.
    type: Literal["pixel_values"]
    pixel_values_flat: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """

    num_patches: torch.Tensor
    """Shape: `(batch_size * num_images)`"""

    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
    A boolean mask indicating which image embeddings correspond
    to patch tokens.

    Shape: `(batch_size * num_images, num_embeds)`
    """
class SkyworkR1VImageEmbeddingInputs(TypedDict):
    """Typed layout of image inputs that are supplied as embeddings."""

    # Discriminator tag: always the literal "image_embeds" for this variant.
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, list[torch.Tensor]]
    """
    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors of shape `(total_image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """
# Union of the two accepted image input layouts: pixel values or embeddings.
SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs,
                              SkyworkR1VImageEmbeddingInputs]
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
def build_transform(input_size: int):
    """Build the per-tile preprocessing pipeline.

    The returned transform converts non-RGB images to RGB, resizes to a
    square of side ``input_size`` with bicubic interpolation, converts to a
    tensor and normalizes with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Side length, in pixels, of the square output image.

    Returns:
        The composed torchvision transform.
    """

    def _ensure_rgb(img):
        # Only convert when needed so RGB images pass through untouched.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the candidate grid whose aspect ratio is nearest to the image's.

    Candidates are scanned in order. A strictly closer candidate always
    replaces the current choice. A candidate that ties the current best
    replaces it only when the image area exceeds half of the pixel area
    that candidate's grid would cover.

    Args:
        aspect_ratio: Width / height ratio of the source image.
        target_ratios: Candidate grids; each entry is divided as
            ``entry[0] / entry[1]`` to obtain its aspect ratio.
        width: Source image width.
        height: Source image height.
        image_size: Side length of one square tile.

    Returns:
        The selected entry of ``target_ratios``, or ``(1, 1)`` if no
        candidate was selected.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    image_area = width * height
    half_tile_area = 0.5 * image_size * image_size

    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)

        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue

        # Tie-break: prefer the later candidate only for large images.
        if gap == smallest_gap and image_area > half_tile_area * cols * rows:
            chosen = candidate

    return chosen
def resolve_skyworkr1v_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective lower and upper bounds on the patch count.

    Args:
        min_dynamic_patch: Lower bound used when dynamic sizing is enabled.
        max_dynamic_patch: Upper bound used when dynamic sizing is enabled.
        dynamic_image_size: When false, both bounds collapse to 1.
        use_thumbnail: When true and the upper bound is not 1, the upper
            bound is increased by one.

    Returns:
        ``(min_num, max_num)``.
    """
    if dynamic_image_size:
        low, high = min_dynamic_patch, max_dynamic_patch
    else:
        low = high = 1

    # A thumbnail only counts when more than one patch is possible.
    if use_thumbnail and high != 1:
        high += 1

    return low, high
def get_skyworkr1v_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """Enumerate every grid whose tile count lies in ``[min_num, max_num]``.

    Args:
        min_num: Smallest allowed product of the two grid dimensions.
        max_num: Largest allowed product of the two grid dimensions.

    Returns:
        The distinct ``(i, j)`` pairs, sorted by ascending ``i * j``.
    """

    def _tile_count(grid: tuple[int, int]) -> int:
        return grid[0] * grid[1]

    # A set removes the duplicates produced by overlapping `total` ranges.
    unique_grids = {(cols, rows)
                    for total in range(min_num, max_num + 1)
                    for cols in range(1, total + 1)
                    for rows in range(1, total + 1)
                    if min_num <= cols * rows <= max_num}
    return sorted(unique_grids, key=_tile_count)
def calculate_skyworkr1v_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Compute the block count and resize target for an image.

    Args:
        orig_width: Source image width.
        orig_height: Source image height.
        target_ratios: Candidate grids to choose from.
        image_size: Side length of one square tile.
        use_thumbnail: When true and the grid has more than one block, one
            extra block is counted.

    Returns:
        ``(blocks, target_width, target_height)``.
    """
    # Select the grid whose shape best matches the source aspect ratio.
    best_grid = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )
    cols, rows = best_grid[0], best_grid[1]

    blocks = cols * rows
    # The thumbnail is only counted when the grid has more than one block.
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, image_size * cols, image_size * rows
def dynamic_preprocess_skyworkr1v(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Resize an image to its target grid and cut it into square tiles.

    Args:
        image: Source image.
        target_ratios: Candidate grids to choose from.
        image_size: Side length of one square tile.
        use_thumbnail: When true and more than one tile was produced, a
            copy of the whole image resized to one tile is appended.

    Returns:
        The tiles in row-major order, optionally followed by the thumbnail.
    """
    orig_width, orig_height = image.size

    # The thumbnail is handled at the end, so ask for the bare block count.
    blocks, target_width, target_height = calculate_skyworkr1v_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    resized = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    tiles = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        left, top = col * image_size, row * image_size
        tiles.append(
            resized.crop((left, top, left + image_size, top + image_size)))

    assert len(tiles) == blocks

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B
def image_to_pixel_values_skyworkr1v(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Turn one image into a stacked tensor of preprocessed tiles.

    Args:
        image: Source image.
        input_size: Side length of one square tile.
        min_num: Smallest allowed tile count for the grid search.
        max_num: Largest allowed tile count for the grid search.
        use_thumbnail: Forwarded to the tiling step.

    Returns:
        The transformed tiles stacked along a new leading dimension.
    """
    candidate_grids = get_skyworkr1v_target_ratios(min_num, max_num)
    to_tensor = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_skyworkr1v(
        image,
        target_ratios=candidate_grids,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack(list(map(to_tensor, tiles)))
class BaseSkyworkR1VProcessor(ABC):
    """
    Hand-written replacement for a HF processor, since this model does not
    define one of its own.

    The image-token insertion logic is based on:
    https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> None:
        """Store the config/tokenizer and resolve the patching settings.

        Any of the three keyword overrides left as ``None`` is read from
        the attribute of the same name on ``config``.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        image_size: int = config.vision_config.image_size
        patch_size: int = config.vision_config.patch_size

        if min_dynamic_patch is None:
            min_dynamic_patch = config.min_dynamic_patch
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        # Tokens per tile: squared patches-per-side scaled by the squared
        # downsample ratio.
        patches_per_side = image_size // patch_size
        self.num_image_token = int(patches_per_side**2 *
                                   (config.downsample_ratio**2))

        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        """Token id compared against the encoded feature tokens."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Return the replacement details for one ``<image>`` placeholder."""
        raise NotImplementedError

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve the patch-count bounds, defaulting to instance settings.

        Every argument left as ``None`` is replaced by the value stored on
        this processor before delegating to
        ``resolve_skyworkr1v_min_max_num``.
        """
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_skyworkr1v_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> list[tuple[int, int]]:
        """Return the candidate grids for the resolved patch-count bounds."""
        bounds = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )
        low, high = bounds

        return get_skyworkr1v_target_ratios(low, high)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many image tokens an image of this size expands to."""
        # The thumbnail is accounted for in calculate_skyworkr1v_targets,
        # so it must not also widen the candidate grids here.
        candidate_grids = self.resolve_target_ratios(use_thumbnail=False)

        block_count, _, _ = calculate_skyworkr1v_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=candidate_grids,
            use_thumbnail=self.use_thumbnail,
        )

        return block_count * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each image into its stacked tile tensor."""
        # The thumbnail is added inside image_to_pixel_values_skyworkr1v,
        # so the bounds are resolved without it.
        low, high = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,
        )

        per_image_tensors = []
        for picture in images:
            per_image_tensors.append(
                image_to_pixel_values_skyworkr1v(
                    picture,
                    input_size=self.image_size,
                    min_num=low,
                    max_num=high,
                    use_thumbnail=self.use_thumbnail,
                ))
        return per_image_tensors

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        """Tokenize the text and preprocess the images in one call.

        For each image, the first remaining ``<image>`` occurrence in every
        prompt is replaced by that image's full replacement text before
        tokenization. When images are present the result also carries
        ``pixel_values_flat``, ``image_num_patches`` and ``embed_is_patch``.
        """
        # Normalize both inputs to lists.
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]

        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        image_inputs: dict[str, NestedTensors] = {}
        if len(images) != 0:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            image_inputs = {
                "pixel_values_flat":
                torch.cat(pixel_values_lst),
                "image_num_patches":
                torch.tensor([len(tiles) for tiles in pixel_values_lst]),
            }

            tokenizer = self.tokenizer
            image_token_id = self.image_token_id

            patch_masks: list[torch.Tensor] = []
            for tiles in pixel_values_lst:
                tile_count = tiles.shape[0]
                token_count = tile_count * self.num_image_token

                repl = self.get_image_repl(token_count, tile_count)
                encoded_features = tokenizer.encode(repl.features,
                                                    add_special_tokens=False)

                # Consume one placeholder per prompt for this image.
                text = [
                    prompt.replace('<image>', repl.full, 1) for prompt in text
                ]
                patch_masks.append(
                    torch.tensor(encoded_features) == image_token_id)

            image_inputs["embed_is_patch"] = patch_masks

        encoded_text = BatchEncoding(self.tokenizer(text),
                                     tensor_type=return_tensors)

        return {
            **encoded_text,
            **image_inputs,
        }
class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
    """Concrete processor that uses ``IMG_CONTEXT`` as the image token."""

    @property
    def image_token_id(self) -> int:
        """Look up the id of ``IMG_CONTEXT`` in the tokenizer vocabulary."""
        vocab = self.tokenizer.get_vocab()
        return vocab[IMG_CONTEXT]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the replacement text for one image placeholder.

        The features are ``feature_size`` repetitions of ``IMG_CONTEXT``;
        the full text wraps them in ``IMG_START`` / ``IMG_END``.
        ``num_patches`` is not used here.
        """
        context_run = IMG_CONTEXT * feature_size
        wrapped = IMG_START + context_run + IMG_END

        return PromptUpdateDetails(full=wrapped, features=context_run)
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
    """Processing info shared by SkyworkR1V variants (image modality only)."""

    @abstractmethod
    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> BaseSkyworkR1VProcessor:
        """Return the processor to use; implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Report the ``image`` modality with a limit of ``None``."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count of a single image."""
        max_tokens = self.get_max_image_tokens()
        return {"image": max_tokens}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[BaseSkyworkR1VProcessor],
    ) -> int:
        """Return the token count for an image of the given size.

        When ``processor`` is ``None`` one is obtained from
        ``get_hf_processor()``.
        """
        active = self.get_hf_processor() if processor is None else processor

        return active.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_max_image_tokens(self) -> int:
        """Return the token count of the most expensive image size."""
        widest, tallest = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=widest,
            image_height=tallest,
            processor=None,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Find the image size that yields the most image tokens.

        Each candidate grid is scaled by the processor's tile size; the
        first size with the strictly largest token count wins.

        Raises:
            ValueError: If no candidate produces a positive token count.
        """
        processor = self.get_hf_processor()
        tile_side = processor.image_size

        best_tokens, best_size = 0, None
        for cols, rows in processor.resolve_target_ratios():
            width, height = tile_side * cols, tile_side * rows

            tokens = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if tokens > best_tokens:
                best_tokens = tokens
                best_size = ImageSize(width=width, height=height)

        if best_tokens == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size
# Type variable bound to BaseSkyworkR1VProcessingInfo; used to parameterize
# the generic dummy-inputs builder and multimodal processor in this module.
_I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo)
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy processor inputs made of maximum-size images."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create a prompt of ``<image>`` placeholders plus matching images.

        The image count is ``mm_counts["image"]`` (0 when absent), and each
        dummy image uses the size that yields the most features.
        """
        size = self.info.get_image_size_with_most_features()
        width, height = size
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)

        return ProcessorInputs(
            prompt_text="<image>" * image_count,
            mm_data={"image": dummy_images},
        )
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor wiring for SkyworkR1V image inputs."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the parent processor call and attach ``image_token_id``."""
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        processor = self.info.get_hf_processor(**mm_kwargs)

        # Since there may be extra tokens in the feature placeholders,
        # the image token ID is passed on to the model so it can select
        # the tokens to merge from the vision encoder outputs.
        outputs["image_token_id"] = torch.tensor(processor.image_token_id)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto the image items."""
        patches_per_image = hf_inputs.get("image_num_patches", torch.empty(0))
        image_count = len(patches_per_image)

        return {
            "pixel_values_flat":
            MultiModalFieldConfig.flat_from_sizes("image", patches_per_image),
            "image_num_patches":
            MultiModalFieldConfig.batched("image"),
            "embed_is_patch":
            MultiModalFieldConfig.batched("image"),
            "image_embeds":
            MultiModalFieldConfig.batched("image"),
            "image_token_id":
            MultiModalFieldConfig.shared("image", image_count),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the update that expands each ``<image>`` placeholder."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "image_num_patches" in out_mm_kwargs:
            patches_tensor = out_mm_kwargs["image_num_patches"]
            assert isinstance(patches_tensor, torch.Tensor)
            patches_per_image = patches_tensor.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            embed_count = len(out_mm_kwargs["image_embeds"])
            patches_per_image = [None] * embed_count
        else:
            patches_per_image = []

        def _replacement_for(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(items, ImageEmbeddingItems):
                token_count = items.get_feature_size(item_idx)
            else:
                size = items.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=processor,
                )

            patch_count = patches_per_image[item_idx]
            if patch_count is not None:
                assert isinstance(patch_count, int)

            return processor.get_image_repl(token_count, patch_count)

        image_update = PromptReplacement(
            modality="image",
            target="<image>",
            replacement=_replacement_for,
        )
        return [image_update]
class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo):
    """Processing info that instantiates ``SkyworkR1VProcessor``."""

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> SkyworkR1VProcessor:
        """Build the processor, forwarding only the overrides that are set.

        Overrides left as ``None`` are not passed on, so the processor's
        own defaults apply for them.
        """
        overrides = {
            "min_dynamic_patch": min_dynamic_patch,
            "max_dynamic_patch": max_dynamic_patch,
            "dynamic_image_size": dynamic_image_size,
        }
        kwargs.update({
            option: value
            for option, value in overrides.items() if value is not None
        })

        return self.ctx.init_processor(
            SkyworkR1VProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )
@MULTIMODAL_REGISTRY.register_processor(
    SkyworkR1VMultiModalProcessor,
    info=SkyworkR1VProcessingInfo,
    dummy_inputs=SkyworkR1VDummyInputsBuilder)
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
    """Skywork-R1V chat model.

    Combines a vision model, the ``mlp1`` projector and a registered
    language model. Image features are projected to the language model's
    hidden size and merged into the text embeddings at the positions of
    the image context token.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.llm_arch_name = config.text_config.architectures[0]
        self.is_mono = self.llm_arch_name == 'SkyworkLM2VEForCausalLM'
        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.mlp1 = self._init_mlp1(config)

        # Filled in lazily from the `image_token_id` multimodal kwarg.
        self.img_context_token_id = None
        # Only populated for the mono architecture; consumed by `forward`.
        self.visual_token_mask = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _patch_quant_config(
            self, config: PretrainedConfig,
            quant_config: Optional[QuantizationConfig]) -> None:
        """Mark the vision model as unquantized for AWQ checkpoints.

        Does nothing unless ``quant_config`` is an ``AWQConfig`` (so a
        ``None`` config is accepted).
        """
        # the awq models from OpenGVLab missing `modules_to_not_convert`
        # patch the quant_config to add `modules_to_not_convert` back
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config",
                                   None)
        if (not quant_config.modules_to_not_convert
                and llm_quant_config is not None):
            quant_config.modules_to_not_convert.append("vision_model")

    @cached_property
    def sampler(self):
        """Sampler of the language model if it has one, else a default."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the vision model for the mono or non-mono architecture.

        For the non-mono case the encoder is truncated so that its last
        layer is the one selected by ``config.select_layer`` (negative
        values count from the end).
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        vision_feature_layer = config.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = (config.vision_config.num_hidden_layers +
                                 vision_feature_layer + 1)
        else:
            num_hidden_layers = vision_feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Build the projector from vision features to LLM hidden size."""
        llm_hidden_size = config.text_config.hidden_size
        # Input width after pixel shuffle: the channel dimension grows by
        # (1 / downsample_ratio) ** 2.
        projector_in_size = (config.vision_config.hidden_size *
                             int(1 / self.downsample_ratio)**2)

        return nn.Sequential(
            nn.LayerNorm(projector_in_size),
            ReplicatedLinear(projector_in_size,
                             llm_hidden_size,
                             return_bias=False),
            nn.GELU(),
            ReplicatedLinear(llm_hidden_size,
                             llm_hidden_size,
                             return_bias=False),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channel depth.

        Both spatial dimensions are multiplied by ``scale_factor`` and the
        channel dimension is divided by ``scale_factor ** 2``.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        # Version 'v1' keeps this layout; every other version swaps the two
        # spatial axes back.
        if self.ps_version != 'v1':
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode pixel values and project them with ``mlp1``."""
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        # Drop the first token along the sequence dimension
        # (presumably the class token -- confirm against the vision model).
        vit_embeds = vit_embeds[:, 1:, :]

        # The remaining tokens are laid out as a square grid.
        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every patch has shape ``(3, image_size, image_size)``.

        Raises:
            ValueError: If any patch has a different shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_expr}. "
                    f"You supplied {actual_dims}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
        """Turn raw multimodal kwargs into a typed image input.

        Returns ``None`` when neither pixel values nor image embeddings are
        supplied. Image embeddings take precedence over pixel values.

        Raises:
            ValueError: If a supplied field is neither a tensor nor a list.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        embed_is_patch = kwargs.pop("embed_is_patch", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # `get_input_embeddings` asserts that the image token ID is
            # known, so record it here as well when it was passed along.
            # Best effort only: anything other than a single unique ID is
            # ignored rather than raising on this path.
            token_ids = kwargs.get("image_token_id")
            if isinstance(token_ids, torch.Tensor):
                unique_ids = token_ids.flatten().unique()
                if unique_ids.numel() == 1:
                    self.img_context_token_id = unique_ids.item()

            return SkyworkR1VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        image_token_id = kwargs["image_token_id"]
        assert isinstance(image_token_id, torch.Tensor)
        self.img_context_token_id = image_token_id.flatten().unique().item()

        if pixel_values_flat is not None:
            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values_flat)}")

            if not isinstance(image_num_patches, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image_num_patches. "
                                 f"Got type: {type(image_num_patches)}")

            if not isinstance(embed_is_patch, (torch.Tensor, list)):
                raise ValueError("Incorrect type of embed_is_patch. "
                                 f"Got type: {type(embed_is_patch)}")

            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
            image_num_patches = flatten_bn(image_num_patches, concat=True)
            embed_is_patch = flatten_bn(embed_is_patch)

            return SkyworkR1VImagePixelInputs(
                type="pixel_values",
                pixel_values_flat=self._validate_pixel_values(
                    pixel_values_flat),
                num_patches=image_num_patches,
                embed_is_patch=embed_is_patch,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: SkyworkR1VImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
        """Produce per-image embeddings from a parsed image input.

        Precomputed embeddings are returned untouched; pixel values are
        encoded and then split into one tensor per image.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None

        image_embeds = self.extract_feature(image_input["pixel_values_flat"])

        num_patches = image_input["num_patches"]
        hidden_size = self.config.text_config.hidden_size

        # Only one image in the current batch
        if len(num_patches) == 1:
            return image_embeds.view(-1, hidden_size).unsqueeze(0)

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, hidden_size)
        image_feature_sizes = [
            patch_count * feature_size for patch_count in num_patches
        ]
        return image_embeds.split(image_feature_sizes)

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Store the image-token mask (mono architecture only)."""
        if self.is_mono:
            self.visual_token_mask = (
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
            self.visual_token_mask = None

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute image embeddings, or ``None`` without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        image_features = self._process_image_input(image_input)

        # Precomputed embeddings are passed through as-is; only encoder
        # outputs are scattered according to `embed_is_patch`.
        if image_input["type"] != "pixel_values":
            return image_features

        return scatter_patch_features(
            image_features,
            image_input["embed_is_patch"],
        )

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in the image embeddings, if any."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            assert self.img_context_token_id is not None
            self._set_visual_token_mask(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                select_patch_features(multimodal_embeddings),
                self.img_context_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run the language model on text and (optionally) image inputs."""
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs["visual_token_mask"] = self.visual_token_mask
            # The mask belongs to this step only; clear it after use.
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, skipping the listed unused prefixes.

        Returns whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        skip_prefixes = [
            "action_embed", "temporal_embed", "track_embed",
            "track_embed_decoder", "box_token", "cg_criterion", "cg_model",
            "loc_encoder", "loc_decoder", "sam", "temporal_token",
            "track_token"
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)
...@@ -24,6 +24,7 @@ from transformers import AutoModel, PretrainedConfig, PreTrainedModel ...@@ -24,6 +24,7 @@ from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig) ParallelConfig, VllmConfig)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -42,7 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -42,7 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, maybe_prefix) make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -109,12 +111,9 @@ def replace_linear_class( ...@@ -109,12 +111,9 @@ def replace_linear_class(
) )
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): class TransformersModel(nn.Module):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
logger.info("Using Transformers backend.") logger.info("Using Transformers backend.")
...@@ -132,9 +131,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -132,9 +131,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.quant_config = quant_config self.quant_config = quant_config
self.vocab_size = model_config.get_vocab_size()
self.unpadded_vocab_size = model_config.get_vocab_size()
self.pp_group = get_pp_group() self.pp_group = get_pp_group()
self.pp_size = self.pp_group.world_size self.pp_size = self.pp_group.world_size
self.pp_rank = self.pp_group.rank_in_group self.pp_rank = self.pp_group.rank_in_group
...@@ -142,13 +138,15 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -142,13 +138,15 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# Use meta device to delay allocating GPU tensors # Use meta device to delay allocating GPU tensors
with torch.device("meta"): with torch.device("meta"):
# FIXME(Isotr0py): We need to refactor this part in the future to
# avoid registering an extra model layer, otherwise we will need a
# weights mapper to rename weights.
self.model: PreTrainedModel = AutoModel.from_config( self.model: PreTrainedModel = AutoModel.from_config(
config, config,
attn_implementation="vllm", attn_implementation="vllm",
torch_dtype=model_config.dtype, torch_dtype=model_config.dtype,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
) )
prefix = self.model.base_model_prefix
self.pipeline_parallel() self.pipeline_parallel()
self.tensor_parallel() self.tensor_parallel()
...@@ -166,32 +164,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -166,32 +164,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# Attention layers # Attention layers
self.attention_instances = self.create_attention_instances() self.attention_instances = self.create_attention_instances()
# Output embeddings
if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings())
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
# Initialize buffers (e.g. rotary embedding inverse frequency) # Initialize buffers (e.g. rotary embedding inverse frequency)
self.init_buffers(self.model) self.init_buffers(self.model)
# Move remaining meta tensors to device (should happen last) # Move remaining meta tensors to device (should happen last)
self.meta_to_empty(self.model) self.meta_to_empty(self.model)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size)) config.hidden_size))
...@@ -246,15 +224,15 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -246,15 +224,15 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
if not self.pp_group.is_last_rank: if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer()) setattr(self.model, name, PPMissingLayer())
if not self.pp_group.is_last_rank:
self.lm_head = PPMissingLayer()
def tensor_parallel(self): def tensor_parallel(self):
""" """
Apply the model's tensor parallelization plan. Apply the model's tensor parallelization plan.
Currently only supports linear layers. Currently only supports linear layers.
""" """
if self.tp_size > 1 and self.config.base_model_tp_plan is None: if not self.model.supports_tp_plan:
if self.tp_size <= 1:
return
raise ValueError( raise ValueError(
f"{type(self.model)} does not support tensor parallel yet!") f"{type(self.model)} does not support tensor parallel yet!")
...@@ -329,6 +307,9 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -329,6 +307,9 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
for child in module.children(): for child in module.children():
self.meta_to_empty(child) self.meta_to_empty(child)
def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
...@@ -359,6 +340,92 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -359,6 +340,92 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into the matching parameters.

    Names lacking the ``model.`` prefix get it prepended first. Entries
    that are pipeline-parallel-missing or have no matching parameter are
    skipped.

    Returns:
        The (prefixed) names of the parameters that were loaded.
    """
    # Use "model" instead of base_model_prefix because
    # the base model attribute in vLLM is always `model`
    vllm_prefix = "model."
    known_params = dict(self.named_parameters())
    loaded: set[str] = set()

    for checkpoint_name, tensor in weights:
        if checkpoint_name.startswith(vllm_prefix):
            name = checkpoint_name
        else:
            name = vllm_prefix + checkpoint_name

        if is_pp_missing_parameter(name, self):
            continue
        if name not in known_params:
            continue

        target = known_params[name]
        # Parameters may carry their own loader; otherwise use the default.
        load_fn = getattr(target, "weight_loader", default_weight_loader)
        load_fn(target, tensor)
        loaded.add(name)

    return loaded
@support_torch_compile
class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Wrap a ``TransformersModel`` with an LM head on the last PP rank.

    Ranks other than the last get a ``PPMissingLayer`` in place of the
    head and no logits processor.
    """
    super().__init__()
    hf_config: PretrainedConfig = vllm_config.model_config.hf_config
    quantization: QuantizationConfig = vllm_config.quant_config

    self.config = hf_config

    self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)

    if not get_pp_group().is_last_rank:
        self.lm_head = PPMissingLayer()
    else:
        self.unpadded_vocab_size = hf_config.vocab_size
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quantization,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # With tied embeddings the head shares the input embedding weights.
        if hf_config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(
                self.model.get_input_embeddings())

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            hf_config.vocab_size,
            getattr(hf_config, "logit_scale", 1.0),
        )

    self.sampler = get_sampler()

    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)
# FIXME(Isotr0py): Don't use any weights mapper for the Transformers backend;
# it makes things complicated. This mapper should be removed once
# `TransformersModel` has been refactored.
@property
def hf_to_vllm_mapper(self):
    """Mapper that renames checkpoint weights for the nested model.

    Every direct child name of ``self.model.model`` is mapped to
    ``model.<name>`` as a prefix rule, and the substring ``model.`` is
    mapped to ``model.model.``.
    """
    child_prefixes = {}
    for child_name, _ in self.model.model.named_children():
        child_prefixes[child_name] = f"model.{child_name}"

    return WeightsMapper(
        orig_to_new_substr={"model.": "model.model."},
        orig_to_new_prefix=child_prefixes,
    )
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Forward all arguments positionally to the wrapped model."""
    return self.model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds,
    )
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -376,18 +443,9 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -376,18 +443,9 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) loader = AutoWeightsLoader(
loaded_params = set[str]() self,
for name, loaded_weight in weights: skip_prefixes=(["lm_head."]
# Necessary for some models which use remote code if self.config.tie_word_embeddings else None),
if not name.startswith(prefix := self.model.base_model_prefix): )
name = maybe_prefix(prefix, name) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
if is_pp_missing_parameter(name, self):
continue
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -8,7 +8,6 @@ from functools import cached_property ...@@ -8,7 +8,6 @@ from functools import cached_property
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
import torch import torch
import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from transformers import BatchFeature, ProcessorMixin from transformers import BatchFeature, ProcessorMixin
...@@ -160,7 +159,7 @@ class UltravoxMultiModalProcessor( ...@@ -160,7 +159,7 @@ class UltravoxMultiModalProcessor(
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
# Text-only input not supported in composite processor # Text-only input not supported in composite processor
if not mm_data or not mm_data.get("audios", []): if not mm_data.get("audios", []):
prompt_ids = self.info.get_tokenizer().encode( prompt_ids = self.info.get_tokenizer().encode(
prompt, add_special_tokens=False) prompt, add_special_tokens=False)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
......
...@@ -10,12 +10,14 @@ import torch.nn as nn ...@@ -10,12 +10,14 @@ import torch.nn as nn
from torch.func import functional_call from torch.func import functional_call
from transformers import PretrainedConfig from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
is_uva_available)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -156,6 +158,26 @@ class AutoWeightsLoader: ...@@ -156,6 +158,26 @@ class AutoWeightsLoader:
yield weight_qualname yield weight_qualname
def _add_loadable_non_param_tensors(self, module: nn.Module,
                                    child_params: Dict[str, torch.Tensor]):
    """
    Add tensor names that are not in the model params that may be in the
    safetensors, e.g., batch normalization stats.

    Only batch-norm modules are handled; ``child_params`` is updated in
    place. Statistics that the module does not keep in its state dict are
    skipped.
    """
    batch_norm_types = (
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.LazyBatchNorm1d,
        nn.LazyBatchNorm2d,
        nn.LazyBatchNorm3d,
        nn.SyncBatchNorm,
    )
    if not isinstance(module, batch_norm_types):
        return

    module_state_dict = module.state_dict()
    for stat_name in ("running_mean", "running_var",
                      "num_batches_tracked"):
        # With `track_running_stats=False` these buffers are None and are
        # left out of the state dict, so an unconditional lookup would
        # raise KeyError.
        if stat_name in module_state_dict:
            child_params[stat_name] = module_state_dict[stat_name]
def _load_module( def _load_module(
self, self,
base_prefix: str, base_prefix: str,
...@@ -184,6 +206,10 @@ class AutoWeightsLoader: ...@@ -184,6 +206,10 @@ class AutoWeightsLoader:
child_modules = dict(module.named_children()) child_modules = dict(module.named_children())
child_params = dict(module.named_parameters(recurse=False)) child_params = dict(module.named_parameters(recurse=False))
# Add missing tensors the weight loader needs to be able to load
# that aren't registered as params, e.g., batchnorm statistics.
self._add_loadable_non_param_tensors(module, child_params)
for child_prefix, child_weights in self._groupby_prefix(weights): for child_prefix, child_weights in self._groupby_prefix(weights):
prefix = self._get_qualname(base_prefix, child_prefix) prefix = self._get_qualname(base_prefix, child_prefix)
...@@ -495,7 +521,10 @@ def set_cpu_offload_max_bytes(max_bytes: int) -> None: ...@@ -495,7 +521,10 @@ def set_cpu_offload_max_bytes(max_bytes: int) -> None:
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
device = next(module.parameters()).device if (params := next(module.parameters(), None)) is None:
return module
device = params.device
if device == torch.device("cpu"): if device == torch.device("cpu"):
return module return module
...@@ -505,6 +534,14 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: ...@@ -505,6 +534,14 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
return module return module
pin_memory = is_pin_memory_available() pin_memory = is_pin_memory_available()
uva_available = is_uva_available()
if envs.VLLM_USE_V1:
assert uva_available, ("V1 CPU offloading requires"
" uva (pin memory) support")
uva_offloading = True
else:
uva_offloading = False
# offload parameters to CPU # offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed # use pin_memory if possible, which helps cudagraph capture speed
...@@ -523,11 +560,16 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: ...@@ -523,11 +560,16 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
device='cpu', device='cpu',
pin_memory=pin_memory) pin_memory=pin_memory)
cpu_data.copy_(p.data) cpu_data.copy_(p.data)
p.data = cpu_data if not uva_offloading:
p.data = cpu_data
else:
# keep the cpu data alive
p._vllm_offloaded_cpu_data = cpu_data
p.data = get_cuda_view_from_cpu_tensor(cpu_data)
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size() _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
offloaded_parameters = True offloaded_parameters = True
if offloaded_parameters: if offloaded_parameters and not uva_offloading:
original_forward = module.forward original_forward = module.forward
def forward(*args, **kwargs): def forward(*args, **kwargs):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
import torch import torch
...@@ -68,6 +69,9 @@ def get_vision_encoder_info( ...@@ -68,6 +69,9 @@ def get_vision_encoder_info(
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
return CLIPEncoderInfo(vision_config) return CLIPEncoderInfo(vision_config)
if isinstance(vision_config, PixtralVisionConfig): if isinstance(vision_config, PixtralVisionConfig):
# Need to sneak in spatial_merge_size for Mistral3
vision_config.spatial_merge_size = getattr(hf_config,
"spatial_merge_size", 1)
return PixtralHFEncoderInfo(vision_config) return PixtralHFEncoderInfo(vision_config)
if isinstance(vision_config, SiglipVisionConfig): if isinstance(vision_config, SiglipVisionConfig):
return SiglipEncoderInfo(vision_config) return SiglipEncoderInfo(vision_config)
...@@ -154,9 +158,8 @@ def resolve_visual_encoder_outputs( ...@@ -154,9 +158,8 @@ def resolve_visual_encoder_outputs(
def scatter_patch_features( def scatter_patch_features(
features: torch.Tensor, patches: Union[torch.Tensor, Sequence[torch.Tensor]],
num_embeds: torch.Tensor, embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]],
embed_is_patch: torch.Tensor,
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
""" """
Scatter the patch features into a contiguous tensor that corresponds Scatter the patch features into a contiguous tensor that corresponds
...@@ -166,23 +169,50 @@ def scatter_patch_features( ...@@ -166,23 +169,50 @@ def scatter_patch_features(
can be filtered out by :func`select_patch_features`. can be filtered out by :func`select_patch_features`.
Args: Args:
features: The patch features, concatenated across each image. patches: The patch features for each image.
Shape: `(num_patch, feature_depth)` Shape: `(num_images, <patch_dims>, feature_depth)`
num_embeds: The number of image embeddings for each image.
Shape: `(num_images,)`
embed_is_patch: A boolean mask indicating which image embeddings embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image. correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)` Shape: `(num_images, num_embeds)`
"""
num_embeds_per_image: list[int] = num_embeds.tolist()
embeds_flat = features.new_full( Note:
(sum(num_embeds_per_image), features.shape[-1]), The original code only considers patch tokens as feature
fill_value=torch.nan, tokens, but our processor considers all image-related tokens
) as feature tokens because the feature tokens need to be
embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2) consecutive in `input_ids`.
Example:
A simplified example for one image:
.. code-block::
return embeds_flat.split(num_embeds_per_image) Embedding tokens (from HF processor):
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
embed_is_patch (from HF processor):
[ False True True False True True False False ]
Encoder outputs (from model):
[ p1 p2 p3 p4 ]
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
if len(patches) != len(embed_is_patch):
raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. "
f"{len(embed_is_patch)=}")
def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor):
embed_one = patches_one.new_full(
(e_is_patch.shape[0], patches_one.shape[-1]),
fill_value=torch.nan,
)
embed_one[e_is_patch] = patches_one
return embed_one
return tuple(
get_embed_one(patches_one, e_is_patch)
for patches_one, e_is_patch in zip(patches, embed_is_patch))
def select_patch_features( def select_patch_features(
......
...@@ -191,7 +191,7 @@ class SamplingMetadata: ...@@ -191,7 +191,7 @@ class SamplingMetadata:
"SamplingMetadata(" "SamplingMetadata("
f"seq_groups={self.seq_groups}, " f"seq_groups={self.seq_groups}, "
f"selected_token_indices={self.selected_token_indices}, " f"selected_token_indices={self.selected_token_indices}, "
f"categorized_sample_indices={self.categorized_sample_indices}), ") f"categorized_sample_indices={self.categorized_sample_indices})")
def _prepare_seq_groups( def _prepare_seq_groups(
......
...@@ -149,7 +149,7 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): ...@@ -149,7 +149,7 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
return self.load_bytes(base64.b64decode(data)) return self.load_bytes(base64.b64decode(data))
def load_file(self, filepath: Path) -> torch.Tensor: def load_file(self, filepath: Path) -> torch.Tensor:
return torch.load(filepath) return torch.load(filepath, weights_only=True)
def encode_base64(self, media: torch.Tensor) -> str: def encode_base64(self, media: torch.Tensor) -> str:
return base64.b64encode(media.numpy()).decode('utf-8') return base64.b64encode(media.numpy()).decode('utf-8')
...@@ -665,6 +665,13 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -665,6 +665,13 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return cast(BatchedTensorInputs, json_mapped) return cast(BatchedTensorInputs, json_mapped)
def __delitem__(self, key: str) -> None:
    """Delete ``key`` from this mapping and from every per-modality item.

    The deletion on the mapping itself is delegated to the parent class.
    Afterwards each item stored in ``self._items_by_modality`` has the
    same key popped, so the per-item view does not keep a stale entry.
    """
    super().__delitem__(key)

    # `pop(key, None)` is used so that items which never carried this
    # key are silently left untouched.
    for modality_items in self._items_by_modality.values():
        for kwargs_item in modality_items:
            kwargs_item.pop(key, None)
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return False return False
...@@ -736,7 +743,7 @@ class MultiModalInputs(TypedDict): ...@@ -736,7 +743,7 @@ class MultiModalInputs(TypedDict):
mm_kwargs: MultiModalKwargs mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching.""" """Keyword arguments to be directly passed to the model after batching."""
mm_hashes: NotRequired[Optional["MultiModalHashDict"]] mm_hashes: Optional["MultiModalHashDict"]
"""The hashes of the multi-modal data.""" """The hashes of the multi-modal data."""
mm_placeholders: MultiModalPlaceholderDict mm_placeholders: MultiModalPlaceholderDict
......
...@@ -295,7 +295,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): ...@@ -295,7 +295,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]], ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
ModalityDataItems[Any, Any]] Optional[ModalityDataItems[Any, Any]]]
class MultiModalDataParser: class MultiModalDataParser:
...@@ -319,7 +319,15 @@ class MultiModalDataParser: ...@@ -319,7 +319,15 @@ class MultiModalDataParser:
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return data.ndim == 3 return data.ndim == 3
if is_list_of(data, torch.Tensor): if is_list_of(data, torch.Tensor):
return len(data) == 0 or data[0].ndim == 2 return data[0].ndim == 2
return False
def _is_empty(self, data: object) -> TypeGuard[None]:
if isinstance(data, list):
return len(data) == 0
if isinstance(data, (np.ndarray, torch.Tensor)):
return data.size == 0
return False return False
...@@ -341,7 +349,12 @@ class MultiModalDataParser: ...@@ -341,7 +349,12 @@ class MultiModalDataParser:
def _parse_audio_data( def _parse_audio_data(
self, self,
data: ModalityData[AudioItem], data: ModalityData[AudioItem],
) -> ModalityDataItems[Any, Any]: ) -> Optional[ModalityDataItems[Any, Any]]:
# also check single audio item with sampling rate
if self._is_empty(data) or (isinstance(data, tuple)
and self._is_empty(data[0])):
return None
if self._is_embeddings(data): if self._is_embeddings(data):
return AudioEmbeddingItems(data) return AudioEmbeddingItems(data)
...@@ -378,7 +391,10 @@ class MultiModalDataParser: ...@@ -378,7 +391,10 @@ class MultiModalDataParser:
def _parse_image_data( def _parse_image_data(
self, self,
data: ModalityData[ImageItem], data: ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any]: ) -> Optional[ModalityDataItems[Any, Any]]:
if self._is_empty(data):
return None
if self._is_embeddings(data): if self._is_embeddings(data):
return ImageEmbeddingItems(data) return ImageEmbeddingItems(data)
...@@ -396,7 +412,10 @@ class MultiModalDataParser: ...@@ -396,7 +412,10 @@ class MultiModalDataParser:
def _parse_video_data( def _parse_video_data(
self, self,
data: ModalityData[VideoItem], data: ModalityData[VideoItem],
) -> ModalityDataItems[Any, Any]: ) -> Optional[ModalityDataItems[Any, Any]]:
if self._is_empty(data):
return None
if self._is_embeddings(data): if self._is_embeddings(data):
return VideoEmbeddingItems(data) return VideoEmbeddingItems(data)
...@@ -427,6 +446,8 @@ class MultiModalDataParser: ...@@ -427,6 +446,8 @@ class MultiModalDataParser:
if k not in subparsers: if k not in subparsers:
raise ValueError(f"Unsupported modality: {k}") raise ValueError(f"Unsupported modality: {k}")
mm_items[k] = subparsers[k](v) # ignore empty embedding data
if (parsed_data := subparsers[k](v)) is not None:
mm_items[k] = parsed_data
return mm_items return mm_items
\ No newline at end of file
...@@ -12,7 +12,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, ...@@ -12,7 +12,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast) TypeVar, Union, cast)
import torch import torch
from cachetools import LRUCache
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -21,7 +20,7 @@ from vllm.jsontree import json_map_leaves, json_reduce_leaves ...@@ -21,7 +20,7 @@ from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens) encode_tokens)
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
......
...@@ -3,18 +3,18 @@ ...@@ -3,18 +3,18 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Generic, TypeVar, cast from typing import Generic, NamedTuple, Optional, TypeVar, cast
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from PIL import Image from PIL import Image
import vllm.envs as envs import vllm.envs as envs
from vllm.inputs import DummyData
from vllm.logger import init_logger from vllm.logger import init_logger
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs) MultiModalInputs, MultiModalKwargs,
MultiModalPlaceholderDict)
from .processing import BaseMultiModalProcessor, BaseProcessingInfo from .processing import BaseMultiModalProcessor, BaseProcessingInfo
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,6 +31,20 @@ class ProcessorInputs: ...@@ -31,6 +31,20 @@ class ProcessorInputs:
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
class DummyEncoderData(NamedTuple):
    """Profiling-only dummy input for the encoder side of a model."""

    # Token ids of the dummy encoder prompt.
    prompt_token_ids: list[int]
class DummyDecoderData(NamedTuple):
    """Profiling-only dummy input for the decoder side of a model."""

    # Token ids of the dummy decoder prompt.
    prompt_token_ids: list[int]

    # Keyword arguments produced by the multi-modal processor for the
    # dummy prompt.
    multi_modal_data: MultiModalKwargs

    # Placeholder positions that belong to the dummy prompt.
    multi_modal_placeholders: MultiModalPlaceholderDict
_I = TypeVar("_I", bound=BaseProcessingInfo) _I = TypeVar("_I", bound=BaseProcessingInfo)
...@@ -146,17 +160,19 @@ class MultiModalProfiler(Generic[_I]): ...@@ -146,17 +160,19 @@ class MultiModalProfiler(Generic[_I]):
def get_and_validate_mm_inputs( def get_and_validate_mm_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> tuple[MultiModalInputs, Mapping[str, int]]: ) -> tuple[MultiModalInputs, Mapping[str, int]]:
mm_counts = self.get_mm_limits() if mm_counts is None:
mm_counts = self.get_mm_limits()
info = self.processing_info info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item( mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
seq_len, mm_counts) seq_len, mm_counts)
if mm_counts.keys() != mm_max_tokens_per_item.keys(): if mm_counts.keys() - mm_max_tokens_per_item.keys():
raise AssertionError( raise AssertionError(
"The keys returned by `get_supported_mm_limits` " "The keys returned by `get_supported_mm_limits` "
f"({set(mm_counts.keys())}) should be the same as those " f"({set(mm_counts.keys())}) should be a subset of those "
"returned by `get_mm_max_tokens_per_item` " "returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})") f"({set(mm_max_tokens_per_item.keys())})")
...@@ -182,11 +198,9 @@ class MultiModalProfiler(Generic[_I]): ...@@ -182,11 +198,9 @@ class MultiModalProfiler(Generic[_I]):
def get_encoder_dummy_data( def get_encoder_dummy_data(
self, self,
seq_len: int, seq_len: int,
) -> DummyData: mm_counts: Optional[Mapping[str, int]] = None,
# Avoid circular import ) -> DummyEncoderData:
from vllm.sequence import SequenceData mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len, mm_counts)
mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs) mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
# For encoder-decoder models, use encoder prompt token ids instead of # For encoder-decoder models, use encoder prompt token ids instead of
...@@ -197,21 +211,17 @@ class MultiModalProfiler(Generic[_I]): ...@@ -197,21 +211,17 @@ class MultiModalProfiler(Generic[_I]):
num_tokens_to_pad = max(total_len, seq_len) - total_len num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
return DummyData( return DummyEncoderData(encoder_prompt_token_ids)
seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
multi_modal_data=None,
multi_modal_placeholders=None,
)
def get_decoder_dummy_data( def get_decoder_dummy_data(
self, self,
seq_len: int, seq_len: int,
) -> DummyData: mm_counts: Optional[Mapping[str, int]] = None,
# Avoid circular import ) -> DummyDecoderData:
from vllm.sequence import SequenceData (
mm_inputs,
(mm_inputs, total_placeholders_by_modality total_placeholders_by_modality,
) = self.get_and_validate_mm_inputs(seq_len) ) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids) total_len = len(prompt_token_ids)
...@@ -231,16 +241,11 @@ class MultiModalProfiler(Generic[_I]): ...@@ -231,16 +241,11 @@ class MultiModalProfiler(Generic[_I]):
"and/or reduce `mm_counts`.", seq_len, total_len, "and/or reduce `mm_counts`.", seq_len, total_len,
total_placeholders_by_modality) total_placeholders_by_modality)
return DummyData( if total_len < seq_len:
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), prompt_token_ids.extend([0] * (seq_len - total_len))
multi_modal_data=None,
multi_modal_placeholders=None,
)
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData( return DummyDecoderData(
seq_data=SequenceData.from_seqs(prompt_token_ids), prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_inputs["mm_kwargs"], multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=mm_inputs["mm_placeholders"], multi_modal_placeholders=mm_inputs["mm_placeholders"],
) )
...@@ -21,7 +21,8 @@ from .image import ImagePlugin ...@@ -21,7 +21,8 @@ from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache) ProcessingCache)
from .profiling import BaseDummyInputsBuilder, MultiModalProfiler from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
DummyEncoderData, MultiModalProfiler)
from .video import VideoPlugin from .video import VideoPlugin
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -256,10 +257,7 @@ class MultiModalRegistry: ...@@ -256,10 +257,7 @@ class MultiModalRegistry:
on underlying model configuration. on underlying model configuration.
""" """
if self.has_processor(model_config): if self.has_processor(model_config):
tokenizer = cached_tokenizer_from_config(model_config) processor = self.create_processor(model_config, disable_cache=True)
processor = self.create_processor(model_config,
tokenizer,
disable_cache=True)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config) mm_limits = self.get_mm_limits_per_prompt(model_config)
return processor.info.get_mm_max_tokens_per_item( return processor.info.get_mm_max_tokens_per_item(
...@@ -373,10 +371,7 @@ class MultiModalRegistry: ...@@ -373,10 +371,7 @@ class MultiModalRegistry:
This should be called after :meth:`init_mm_limits_per_prompt`. This should be called after :meth:`init_mm_limits_per_prompt`.
""" """
if self.has_processor(model_config): if self.has_processor(model_config):
tokenizer = cached_tokenizer_from_config(model_config) processor = self.create_processor(model_config, disable_cache=True)
processor = self.create_processor(model_config,
tokenizer,
disable_cache=True)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits() return profiler.get_mm_limits()
...@@ -436,8 +431,8 @@ class MultiModalRegistry: ...@@ -436,8 +431,8 @@ class MultiModalRegistry:
def create_processor( def create_processor(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
tokenizer: AnyTokenizer,
*, *,
tokenizer: Optional[AnyTokenizer] = None,
disable_cache: Optional[bool] = None, disable_cache: Optional[bool] = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]: ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
""" """
...@@ -446,6 +441,8 @@ class MultiModalRegistry: ...@@ -446,6 +441,8 @@ class MultiModalRegistry:
See also: See also:
:ref:`mm-processing` :ref:`mm-processing`
""" """
if tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config)
if disable_cache is None: if disable_cache is None:
disable_cache = model_config.disable_mm_preprocessor_cache disable_cache = model_config.disable_mm_preprocessor_cache
...@@ -456,3 +453,51 @@ class MultiModalRegistry: ...@@ -456,3 +453,51 @@ class MultiModalRegistry:
cache = None if disable_cache else self._processing_cache cache = None if disable_cache else self._processing_cache
return factories.build_processor(ctx, cache=cache) return factories.build_processor(ctx, cache=cache)
def get_decoder_dummy_data(
    self,
    model_config: "ModelConfig",
    seq_len: int,
    mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyDecoderData:
    """
    Build the dummy decoder inputs used to profile a model's memory usage.

    The model is identified by ``model_config``. A processor is created
    with its cache disabled, wrapped in a :class:`MultiModalProfiler`,
    and asked for decoder dummy data for ``seq_len`` and ``mm_counts``.

    Raises:
        AssertionError: If fewer than ``seq_len`` dummy tokens come back.
    """
    profiler = MultiModalProfiler(
        self.create_processor(model_config, disable_cache=True))
    decoder_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

    # Having more tokens is over-conservative but otherwise fine
    token_ids = decoder_data.prompt_token_ids
    if len(token_ids) >= seq_len:
        return decoder_data

    raise AssertionError(
        f"Expected at least {seq_len} dummy tokens for profiling, "
        f"but found {len(token_ids)} tokens instead.")
def get_encoder_dummy_data(
    self,
    model_config: "ModelConfig",
    seq_len: int,
    mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyEncoderData:
    """
    Build the dummy encoder inputs used to profile a model's memory usage.

    The model is identified by ``model_config``. A processor is created
    with its cache disabled, wrapped in a :class:`MultiModalProfiler`,
    and asked for encoder dummy data for ``seq_len`` and ``mm_counts``.

    Unlike the decoder variant, a result shorter than ``seq_len`` is only
    reported through ``logger.warning_once``; the data is still returned.
    """
    profiler = MultiModalProfiler(
        self.create_processor(model_config, disable_cache=True))
    encoder_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)

    # Having more tokens is over-conservative but otherwise fine
    token_ids = encoder_data.prompt_token_ids
    too_short = len(token_ids) < seq_len
    if too_short:
        logger.warning_once(
            f"Expected at least {seq_len} dummy encoder tokens for "
            f"profiling, but found {len(token_ids)} tokens instead.")

    return encoder_data
...@@ -13,8 +13,6 @@ import os ...@@ -13,8 +13,6 @@ import os
import vllm.envs as envs import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .audio import AudioMediaIO from .audio import AudioMediaIO
from .base import MediaIO from .base import MediaIO
...@@ -22,8 +20,6 @@ from .image import ImageEmbeddingMediaIO, ImageMediaIO ...@@ -22,8 +20,6 @@ from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange from .inputs import PlaceholderRange
from .video import VideoMediaIO from .video import VideoMediaIO
logger = init_logger(__name__)
_M = TypeVar("_M") _M = TypeVar("_M")
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -298,121 +294,6 @@ def encode_video_base64(frames: npt.NDArray) -> str: ...@@ -298,121 +294,6 @@ def encode_video_base64(frames: npt.NDArray) -> str:
return video_io.encode_base64(frames) return video_io.encode_base64(frames)
# Utilities for input processors
_T = TypeVar("_T", str, int)


def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> list[_T]:
    """Repeat ``token`` and optionally surround the run with pad tokens.

    Args:
        token: The token (string or id) to repeat.
        repeat_count: How many copies of ``token`` to emit.
        pad_token_left: Token placed before the run; omitted when ``None``.
        pad_token_right: Token placed after the run; omitted when ``None``.

    Returns:
        A new list: ``[pad_token_left?] + [token] * repeat_count +
        [pad_token_right?]``.
    """
    # A pad is skipped only when it is `None`; falsy values such as the
    # token id 0 are still emitted.
    prefix = [] if pad_token_left is None else [pad_token_left]
    suffix = [] if pad_token_right is None else [pad_token_right]
    return prefix + [token] * repeat_count + suffix
def repeat_and_pad_placeholder_tokens(
    tokenizer: AnyTokenizer,
    prompt: Optional[str],
    prompt_token_ids: list[int],
    *,
    placeholder_token_id: int,
    repeat_count: Union[int, list[int]],
    pad_token_left: Optional[int] = None,
    pad_token_right: Optional[int] = None,
) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
    """Expand placeholder tokens in a prompt and in its token ids.

    Each of the first ``len(repeat_count)`` occurrences of the placeholder
    is replaced by ``repeat_count[k]`` copies of it, optionally wrapped in
    ``pad_token_left`` / ``pad_token_right``. The same expansion is applied
    to the text prompt (when given) and to ``prompt_token_ids``.

    Args:
        tokenizer: Used only to decode the placeholder and pad token ids
            into strings for the text prompt.
        prompt: The text prompt, or ``None`` to skip text expansion.
        prompt_token_ids: Token ids of the prompt.
        placeholder_token_id: Id of the token to expand.
        repeat_count: Number of copies per placeholder occurrence; a
            single int is treated as a one-element list.
        pad_token_left: Optional token id inserted before each run.
        pad_token_right: Optional token id inserted after each run.

    Returns:
        A tuple ``(new_prompt, new_token_ids, placeholder_ranges)`` where
        ``new_prompt`` is ``None`` if ``prompt`` was ``None`` and each
        range records the offset of the first repeated placeholder (after
        the left pad, if any) and the repeat count as its length.
    """
    if isinstance(repeat_count, int):
        repeat_count = [repeat_count]

    if prompt is None:
        new_prompt = None
    else:
        placeholder_token_str = tokenizer.decode(placeholder_token_id)
        pad_token_str_left = (None if pad_token_left is None else
                              tokenizer.decode(pad_token_left))
        pad_token_str_right = (None if pad_token_right is None else
                               tokenizer.decode(pad_token_right))

        placeholder_token_count = prompt.count(placeholder_token_str)
        # This is an arbitrary number to distinguish between the two cases
        if placeholder_token_count > 16:
            logger.warning(
                "Please follow the prompt format that is "
                "documented on HuggingFace which does not involve "
                "repeating %s tokens.", placeholder_token_str)
        if placeholder_token_count < len(repeat_count):
            logger.warning(
                "The number of multi-modal placeholder tokens in the prompt "
                "is less than the number of multi-modal inputs. Extra "
                "placeholder tokens will be treated as plain text")
            # The truncated list is also what the token-id loop below
            # indexes into, since the same variable is rebound here.
            repeat_count = repeat_count[:placeholder_token_count]

        # Split at most once per expansion so that any later placeholder
        # occurrences stay inside the final part untouched.
        prompt_parts = prompt.split(placeholder_token_str,
                                    maxsplit=len(repeat_count))
        new_prompt = ""
        for i, repeat_count_item in enumerate(repeat_count):
            replacement_str = "".join(
                repeat_and_pad_token(
                    placeholder_token_str,
                    repeat_count=repeat_count_item,
                    pad_token_left=pad_token_str_left,
                    pad_token_right=pad_token_str_right,
                ))
            # The image tokens are removed to be consistent with HuggingFace
            new_prompt += prompt_parts[i] + replacement_str
        new_prompt += prompt_parts[-1]

    new_token_ids = list[int]()
    placeholder_ranges = list[PlaceholderRange]()
    placeholder_token_idx = 0
    for i, token in enumerate(prompt_token_ids):
        if token == placeholder_token_id:
            # NOTE(review): if `repeat_count` is empty at this point (for
            # example after the truncation above), this lookup raises
            # IndexError on the first placeholder id -- confirm callers
            # never hit that case.
            curr_repeat_count = repeat_count[placeholder_token_idx]
            replacement_ids = repeat_and_pad_token(
                placeholder_token_id,
                repeat_count=curr_repeat_count,
                pad_token_left=pad_token_left,
                pad_token_right=pad_token_right,
            )
            # The range starts at the first repeated placeholder, so the
            # left pad (when present) is skipped over.
            offset = len(new_token_ids)
            if pad_token_left is not None:
                offset += 1
            placeholder_ranges.append({
                "offset": offset,
                "length": curr_repeat_count,
            })
            new_token_ids.extend(replacement_ids)
            placeholder_token_idx += 1

            # No need to further scan the list since we replaced all tokens
            if placeholder_token_idx >= len(repeat_count):
                new_token_ids.extend(prompt_token_ids[i + 1:])
                break
        else:
            new_token_ids.append(token)

    return new_prompt, new_token_ids, placeholder_ranges
def consecutive_placeholder_ranges(
        num_items: int,
        item_size: int,
        initial_offset: int = 0) -> list[PlaceholderRange]:
    """Build ``num_items`` back-to-back placeholder ranges of equal length.

    The first range starts at ``initial_offset``; every following range
    starts ``item_size`` positions after the previous one.
    """
    ranges = []
    next_offset = initial_offset
    for _ in range(num_items):
        ranges.append(PlaceholderRange(offset=next_offset, length=item_size))
        next_offset += item_size
    return ranges
def merge_and_sort_multimodal_metadata( def merge_and_sort_multimodal_metadata(
mm_positions: "MultiModalPlaceholderDict", mm_positions: "MultiModalPlaceholderDict",
mm_hashes: Optional["MultiModalHashDict"], mm_hashes: Optional["MultiModalHashDict"],
...@@ -424,14 +305,10 @@ def merge_and_sort_multimodal_metadata( ...@@ -424,14 +305,10 @@ def merge_and_sort_multimodal_metadata(
Optionally if a MultiModalHashDict is given, same operation will be Optionally if a MultiModalHashDict is given, same operation will be
applied to the object and the sorted list of hashes will be returned. applied to the object and the sorted list of hashes will be returned.
Raises:
ValueError: If the input prompt has interleaved placeholders from
different modalities (e.g, "<image><audio><image> Describe the
content.")
Returns: Returns:
list[str]: Sorted list of involved modalities. list[str]: List of item modalities in order of their positions in
the input sequence.
list[PlaceholderRange]: Sorted list of all PlaceholderRanges from list[PlaceholderRange]: Sorted list of all PlaceholderRanges from
mm_positions. mm_positions.
Optional[list[str]]: Sorted list of all hashes from mm_hashes if Optional[list[str]]: Sorted list of all hashes from mm_hashes if
...@@ -445,50 +322,33 @@ def merge_and_sort_multimodal_metadata( ...@@ -445,50 +322,33 @@ def merge_and_sort_multimodal_metadata(
# For single modality, placeholder ranges and hashes are already sorted # For single modality, placeholder ranges and hashes are already sorted
# so we can return the list directly. # so we can return the list directly.
if len(modalities) == 1: if len(modalities) == 1:
if mm_hashes is None: modality = modalities[0]
return modalities, list(mm_positions[modalities[0]]), None placeholder_list = list(mm_positions[modality])
else:
return modalities, list(mm_positions[modalities[0]]), list( return [modality] * len(
mm_hashes[modalities[0]]) placeholder_list
), placeholder_list, None if not mm_hashes else mm_hashes[modality]
placeholder_lists_with_modality = [(modality, mm_positions[modality])
for modality in modalities] # Create a list of (modality, placeholder, hash) tuples for all placeholders
all_items = []
if mm_hashes is None: for modality in modalities:
sorted_placeholder_lists = sorted(placeholder_lists_with_modality, placeholder_list = list(mm_positions[modality])
key=lambda x: x[1][0]['offset']) hash_list: list[Optional[str]] = list(
sorted_hash_lists = None mm_hashes[modality]) if mm_hashes and modality in mm_hashes else [
else: None
hashes_lists = [ ] * len(placeholder_list)
mm_hashes[modality] for modality in modalities
if modality in mm_hashes for placeholder, hash_value in zip(placeholder_list, hash_list):
] all_items.append((modality, placeholder, hash_value))
sorted_pairs = sorted(zip(placeholder_lists_with_modality,
hashes_lists), # Sort all items by offset
key=lambda x: x[0][1][0]['offset']) all_items.sort(key=lambda x: x[1]['offset'])
sorted_placeholder_tuple, sorted_hash_tuple = zip(*sorted_pairs)
sorted_placeholder_lists = list(sorted_placeholder_tuple) # Split into separate lists
sorted_hash_lists = list(sorted_hash_tuple) sorted_modalities = [item[0] for item in all_items]
merged_placeholders = [item[1] for item in all_items]
sorted_modalities = [modality for modality, _ in sorted_placeholder_lists] merged_hashes = [str(item[2])
for item in all_items] if mm_hashes is not None else None
# Flatten sorted list of lists to a single list and verify there is no
# interleaving of placeholders from different modalities.
merged_placeholders: list[PlaceholderRange] = []
for modality, placeholder_list in sorted_placeholder_lists:
if merged_placeholders and placeholder_list[0][
'offset'] < merged_placeholders[-1]['offset']:
raise ValueError(
"Interleaved mixed-modality inference is currently not "
"supported.")
merged_placeholders.extend(placeholder_list)
if sorted_hash_lists is not None:
merged_hashes = []
for hash_list in sorted_hash_lists:
merged_hashes.extend(hash_list)
else:
merged_hashes = None
return sorted_modalities, merged_placeholders, merged_hashes return sorted_modalities, merged_placeholders, merged_hashes
...@@ -504,8 +364,7 @@ def group_mm_inputs_by_modality( ...@@ -504,8 +364,7 @@ def group_mm_inputs_by_modality(
Returns: Returns:
list[list[MultiModalKwargs]]: List of list of MultiModalKwargs, each list[list[MultiModalKwargs]]: List of list of MultiModalKwargs, each
inner list contains consecutive MultiModalKwargs with same modality, or inner list contains consecutive MultiModalKwargs with same modality.
one with multimodal modalities.
""" """
if not mm_inputs: if not mm_inputs:
return [] return []
......
...@@ -13,7 +13,7 @@ from PIL import Image ...@@ -13,7 +13,7 @@ from PIL import Image
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_video_processor from vllm.transformers_utils.processor import cached_get_video_processor
from vllm.utils import PlaceholderModule, is_list_of from vllm.utils import is_list_of
from .base import MediaIO, ModalityData from .base import MediaIO, ModalityData
from .image import ImageMediaIO, ImagePlugin from .image import ImageMediaIO, ImagePlugin
...@@ -22,11 +22,6 @@ from .inputs import MultiModalKwargs, VideoItem ...@@ -22,11 +22,6 @@ from .inputs import MultiModalKwargs, VideoItem
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
try:
import decord
except ImportError:
decord = PlaceholderModule("decord") # type: ignore[assignment]
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -117,6 +112,69 @@ def sample_frames_from_video(frames: npt.NDArray, ...@@ -117,6 +112,69 @@ def sample_frames_from_video(frames: npt.NDArray,
return sampled_frames return sampled_frames
class VideoLoader:
    """Base interface for backends that decode a video held in memory."""

    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Decode ``data`` into an array of frames.

        Args:
            data: The encoded video.
            num_frames: Number of frames requested; defaults to ``-1``.

        Raises:
            NotImplementedError: Always; subclasses provide the decoding.
        """
        raise NotImplementedError
class OpenCVVideoBackend(VideoLoader):
    """Video loader that decodes an in-memory video through OpenCV."""

    def get_cv2_video_api(self):
        """Pick an OpenCV backend able to read from a stream buffer.

        Returns:
            The first backend reported by
            ``cv2.videoio_registry.getStreamBufferedBackends()`` that is
            available and, when it is a plugin rather than built in, has a
            plugin version of at least ABI 1 / API 2. ``None`` if no
            backend qualifies.
        """
        import cv2.videoio_registry as vr

        for backend in vr.getStreamBufferedBackends():
            if not vr.hasBackend(backend):
                continue
            if not vr.isBackendBuiltIn(backend):
                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
                if abi < 1 or (abi == 1 and api < 2):
                    continue
            return backend
        return None

    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Decode ``data`` and return the selected frames as RGB ``uint8``.

        Args:
            data: The encoded video.
            num_frames: Number of frames to sample uniformly over the
                video. ``-1``, or a value larger than the number of frames
                reported by OpenCV, selects every frame.

        Returns:
            An array of shape ``(n, height, width, 3)`` where ``n`` is the
            number of selected frames.

        Raises:
            ValueError: If OpenCV cannot open the video stream.
            AssertionError: If fewer frames than selected could be decoded.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        try:
            if not cap.isOpened():
                raise ValueError("Could not open video stream")

            total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            full_read = num_frames == -1 or total_frames_num < num_frames
            if full_read:
                # Normalise the request so the final consistency check
                # compares against the number of frames actually selected
                # instead of -1 / an unreachable count.
                num_frames = total_frames_num
                frame_idx = list(range(num_frames))
            else:
                uniform_sampled_frames = np.linspace(0,
                                                     total_frames_num - 1,
                                                     num_frames,
                                                     dtype=int)
                frame_idx = uniform_sampled_frames.tolist()

            # Set membership keeps the per-frame lookup below O(1).
            wanted = set(frame_idx)

            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            frames = np.empty((len(frame_idx), height, width, 3),
                              dtype=np.uint8)

            i = 0
            for idx in range(total_frames_num):
                ok = cap.grab()  # next img
                if not ok:
                    break
                if idx in wanted:  # only decompress needed
                    ret, frame = cap.retrieve()
                    if ret:
                        frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        i += 1
        finally:
            # Always free the underlying decoder, also on error paths.
            cap.release()

        # we expect all frames loaded; raised explicitly so the check is
        # not stripped when Python runs with -O
        if i != num_frames:
            raise AssertionError(
                f"Expected reading {num_frames} frames, "
                f"but only loaded {i} frames from video.")
        return frames
class VideoMediaIO(MediaIO[npt.NDArray]): class VideoMediaIO(MediaIO[npt.NDArray]):
def __init__( def __init__(
...@@ -129,22 +187,10 @@ class VideoMediaIO(MediaIO[npt.NDArray]): ...@@ -129,22 +187,10 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
self.image_io = image_io self.image_io = image_io
self.num_frames = num_frames self.num_frames = num_frames
self.video_loader = OpenCVVideoBackend
def load_bytes(self, data: bytes) -> npt.NDArray: def load_bytes(self, data: bytes) -> npt.NDArray:
vr = decord.VideoReader(BytesIO(data), num_threads=1) return self.video_loader.load_bytes(data, self.num_frames)
total_frame_num = len(vr)
num_frames = self.num_frames
if total_frame_num > num_frames:
uniform_sampled_frames = np.linspace(0,
total_frame_num - 1,
num_frames,
dtype=int)
frame_idx = uniform_sampled_frames.tolist()
else:
frame_idx = list(range(0, total_frame_num))
return vr.get_batch(frame_idx).asnumpy()
def load_base64(self, media_type: str, data: str) -> npt.NDArray: def load_base64(self, media_type: str, data: str) -> npt.NDArray:
if media_type.lower() == "video/jpeg": if media_type.lower() == "video/jpeg":
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
import sys
from importlib.util import find_spec
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import psutil import psutil
...@@ -41,6 +43,9 @@ class CpuPlatform(Platform): ...@@ -41,6 +43,9 @@ class CpuPlatform(Platform):
use_mla: bool) -> str: use_mla: bool) -> str:
if selected_backend and selected_backend != _Backend.TORCH_SDPA: if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
logger.info("Using CPU MLA backend.")
return "vllm.attention.backends.cpu_mla.CPUMLABackend"
logger.info("Using Torch SDPA backend.") logger.info("Using Torch SDPA backend.")
return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
...@@ -68,8 +73,15 @@ class CpuPlatform(Platform): ...@@ -68,8 +73,15 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
ipex_avaliable = find_spec("intel_extension_for_pytorch") is not None
if cache_config and cache_config.block_size is None: if cache_config and cache_config.block_size is None:
cache_config.block_size = 16 cache_config.block_size = 128 if ipex_avaliable else 16
if not ipex_avaliable and cache_config.block_size != 16:
raise RuntimeError(
f"--block-size={cache_config.block_size} requires"
" intel_extension_for_pytorch")
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
if ((scheduler_config.chunked_prefill_enabled if ((scheduler_config.chunked_prefill_enabled
...@@ -133,9 +145,6 @@ class CpuPlatform(Platform): ...@@ -133,9 +145,6 @@ class CpuPlatform(Platform):
# Disable torch async compiling which won't work with daemonic processes # Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# MLA attention is not supported
os.environ["VLLM_MLA_DISABLE"] = "1"
# Intel OpenMP setting # Intel OpenMP setting
ld_prealod_str = os.getenv("LD_PRELOAD", "") ld_prealod_str = os.getenv("LD_PRELOAD", "")
if "libiomp5.so" in ld_prealod_str: if "libiomp5.so" in ld_prealod_str:
...@@ -152,6 +161,13 @@ class CpuPlatform(Platform): ...@@ -152,6 +161,13 @@ class CpuPlatform(Platform):
# To hint IPEX uses shared memory based AllReduce # To hint IPEX uses shared memory based AllReduce
os.environ["LOCAL_WORLD_SIZE"] = str( os.environ["LOCAL_WORLD_SIZE"] = str(
vllm_config.parallel_config.tensor_parallel_size) vllm_config.parallel_config.tensor_parallel_size)
if sys.platform == "darwin" and \
envs.VLLM_WORKER_MULTIPROC_METHOD == "fork":
if os.environ.get('VLLM_WORKER_MULTIPROC_METHOD', None) is None:
logger.warning(
"Default to spawn method on MacOS. If this is not desired,"
" set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
@classmethod @classmethod
def is_pin_memory_available(cls) -> bool: def is_pin_memory_available(cls) -> bool:
......
...@@ -20,8 +20,9 @@ from vllm.utils import import_pynvml ...@@ -20,8 +20,9 @@ from vllm.utils import import_pynvml
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import ModelConfig, VllmConfig
else: else:
ModelConfig = None
VllmConfig = None VllmConfig = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -100,7 +101,7 @@ class CudaPlatformBase(Platform): ...@@ -100,7 +101,7 @@ class CudaPlatformBase(Platform):
return True return True
@classmethod @classmethod
def is_full_nvlink(cls, device_ids: List[int]) -> bool: def is_fully_connected(cls, device_ids: List[int]) -> bool:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
...@@ -303,6 +304,14 @@ class CudaPlatformBase(Platform): ...@@ -303,6 +304,14 @@ class CudaPlatformBase(Platform):
def supports_fp8(cls) -> bool: def supports_fp8(cls) -> bool:
return cls.has_device_capability(89) return cls.has_device_capability(89)
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
    """Return ``True`` unconditionally.

    ``model_config`` is accepted but never inspected, so the result is the
    same for every model.

    NOTE(review): the name suggests this reports "V1" support for the given
    model on this platform -- confirm against the callers.
    """
    return True
@classmethod
def use_custom_allreduce(cls) -> bool:
    """Return ``True`` unconditionally; no class state is consulted.

    NOTE(review): presumably tells callers that the custom all-reduce path
    may be used on this platform -- confirm against the callers.
    """
    return True
# NVML utils # NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
...@@ -357,7 +366,7 @@ class NvmlCudaPlatform(CudaPlatformBase): ...@@ -357,7 +366,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
@classmethod @classmethod
@with_nvml_context @with_nvml_context
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
""" """
query if the set of gpus are fully connected by nvlink (1 hop) query if the set of gpus are fully connected by nvlink (1 hop)
""" """
...@@ -422,7 +431,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase): ...@@ -422,7 +431,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
return device_props.total_memory return device_props.total_memory
@classmethod @classmethod
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
logger.exception( logger.exception(
"NVLink detection not possible, as context support was" "NVLink detection not possible, as context support was"
" not found. Assuming no NVLink available.") " not found. Assuming no NVLink available.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment