Unverified commit 6c0b7f54, authored by Peter Salas, committed by GitHub
Browse files

[Core][VLM] Add precise multi-modal placeholder tracking (#8346)


Signed-off-by: Peter Salas <peter@fixie.ai>
parent d151fde8
...@@ -11,8 +11,8 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig ...@@ -11,8 +11,8 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -30,6 +30,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -30,6 +30,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
...@@ -73,7 +74,11 @@ def dummy_seq_data_for_chameleon( ...@@ -73,7 +74,11 @@ def dummy_seq_data_for_chameleon(
return SequenceData.from_prompt_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) ), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_chameleon( def dummy_image_for_chameleon(
...@@ -97,14 +102,14 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, ...@@ -97,14 +102,14 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]): mm_counts: Mapping[str, int]):
num_images = mm_counts["image"] num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_chameleon( seq_data, ranges = dummy_seq_data_for_chameleon(
seq_len, seq_len,
num_images, num_images,
image_token_id=CHAMELEON_IMAGE_TOKEN_ID, image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
) )
mm_data = dummy_image_for_chameleon(num_images) mm_data = dummy_image_for_chameleon(num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
def input_processor_for_chameleon(ctx: InputContext, def input_processor_for_chameleon(ctx: InputContext,
...@@ -120,9 +125,14 @@ def input_processor_for_chameleon(ctx: InputContext, ...@@ -120,9 +125,14 @@ def input_processor_for_chameleon(ctx: InputContext,
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return inputs return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
model_config = ctx.model_config model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
......
...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
...@@ -49,14 +50,13 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: ...@@ -49,14 +50,13 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
return get_clip_image_feature_size(hf_config) return get_clip_image_feature_size(hf_config)
def dummy_seq_data_for_clip( def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig,
hf_config: CLIPVisionConfig, seq_len: int,
seq_len: int, num_images: int,
num_images: int, *,
*, image_token_id: int,
image_token_id: int, image_feature_size_override: Optional[int] = None,
image_feature_size_override: Optional[int] = None, mm_key: str = "image"):
):
if image_feature_size_override is None: if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config) image_feature_size = get_clip_image_feature_size(hf_config)
else: else:
...@@ -65,7 +65,11 @@ def dummy_seq_data_for_clip( ...@@ -65,7 +65,11 @@ def dummy_seq_data_for_clip(
return SequenceData.from_prompt_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) ), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_clip( def dummy_image_for_clip(
...@@ -117,6 +121,11 @@ def input_processor_for_clip( ...@@ -117,6 +121,11 @@ def input_processor_for_clip(
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return inputs return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None: if image_feature_size_override is None:
...@@ -130,7 +139,7 @@ def input_processor_for_clip( ...@@ -130,7 +139,7 @@ def input_processor_for_clip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
...@@ -141,7 +150,8 @@ def input_processor_for_clip( ...@@ -141,7 +150,8 @@ def input_processor_for_clip(
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
......
...@@ -27,8 +27,8 @@ from transformers import FuyuConfig, FuyuImageProcessor ...@@ -27,8 +27,8 @@ from transformers import FuyuConfig, FuyuImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -37,9 +37,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -37,9 +37,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
...@@ -103,7 +105,11 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): ...@@ -103,7 +105,11 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images) [0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids) return SequenceData(token_ids), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_fuyu( def dummy_image_for_fuyu(
...@@ -119,15 +125,15 @@ def dummy_image_for_fuyu( ...@@ -119,15 +125,15 @@ def dummy_image_for_fuyu(
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int, def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]): mm_counts: Mapping[str, int]):
num_images = mm_counts["image"] num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images) seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
mm_data = dummy_image_for_fuyu(num_images, mm_data = dummy_image_for_fuyu(num_images,
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT) image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
data: Image.Image): data: List[Image.Image]):
image_encoding = image_processor.preprocess(data, return_tensors="pt") image_encoding = image_processor.preprocess(data, return_tensors="pt")
batch_images = torch.stack([img[0] for img in image_encoding["images"] batch_images = torch.stack([img[0] for img in image_encoding["images"]
]).unsqueeze(1) ]).unsqueeze(1)
...@@ -158,8 +164,10 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -158,8 +164,10 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
model_config = ctx.model_config model_config = ctx.model_config
image_data = multi_modal_data["image"] image_data = multi_modal_data["image"]
new_multi_modal_data = {} new_multi_modal_data = {}
image_list = image_data if isinstance(image_data, list) else [image_data]
# process image data # process image data
if isinstance(image_data, Image.Image): if is_list_of(image_list, Image.Image):
# Fuyu's image_processor can also finish token padding # Fuyu's image_processor can also finish token padding
image_processor: FuyuImageProcessor = cached_get_image_processor( image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model) model_config.model)
...@@ -171,7 +179,7 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -171,7 +179,7 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
]) ])
new_multi_modal_data["image"] = image_patches new_multi_modal_data["image"] = image_patches
elif isinstance(image_data, torch.Tensor): elif is_list_of(image_list, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet") raise NotImplementedError("Embeddings input is not supported yet")
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
...@@ -198,12 +206,13 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -198,12 +206,13 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
def input_mapper_for_fuyu(ctx: InputContext, data: object): def input_mapper_for_fuyu(ctx: InputContext, data: object):
model_config = ctx.model_config model_config = ctx.model_config
if isinstance(data, Image.Image): data_list = data if isinstance(data, list) else [data]
if is_list_of(data_list, Image.Image):
# Fuyu's image_processor can also finish token padding # Fuyu's image_processor can also finish token padding
image_processor: FuyuImageProcessor = cached_get_image_processor( image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model) model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, data) model_image_input = _fuyu_image_preprocess(image_processor, data_list)
data = torch.stack([ data = torch.stack([
image_patch[0] image_patch[0]
for image_patch in model_image_input["image_patches"] for image_patch in model_image_input["image_patches"]
......
...@@ -17,8 +17,8 @@ from transformers import PretrainedConfig ...@@ -17,8 +17,8 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.quantization import (AWQConfig, from vllm.model_executor.layers.quantization import (AWQConfig,
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -379,7 +379,7 @@ class InternVLInputPipeline: ...@@ -379,7 +379,7 @@ class InternVLInputPipeline:
model_config.tokenizer, model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code) trust_remote_code=model_config.trust_remote_code)
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
hf_config.vision_config, hf_config.vision_config,
seq_len, seq_len,
num_images, num_images,
...@@ -398,7 +398,7 @@ class InternVLInputPipeline: ...@@ -398,7 +398,7 @@ class InternVLInputPipeline:
image_height_override=max_image_height, image_height_override=max_image_height,
) )
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
......
...@@ -10,7 +10,8 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, ...@@ -10,7 +10,8 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -111,7 +112,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ...@@ -111,7 +112,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
image_feature_size = get_max_llava_image_tokens(ctx) image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
...@@ -120,9 +121,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ...@@ -120,9 +121,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
) )
mm_data = dummy_image_for_clip(vision_config, num_images) mm_data = dummy_image_for_clip(vision_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig): elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip( seq_data, ranges = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
...@@ -131,9 +132,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ...@@ -131,9 +132,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
) )
mm_data = dummy_image_for_siglip(vision_config, num_images) mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, PixtralVisionConfig): elif isinstance(vision_config, PixtralVisionConfig):
seq_data = dummy_seq_data_for_pixtral_hf( seq_data, ranges = dummy_seq_data_for_pixtral_hf(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
...@@ -142,7 +143,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ...@@ -142,7 +143,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
) )
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images) mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
......
...@@ -12,7 +12,8 @@ from typing_extensions import NotRequired ...@@ -12,7 +12,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -180,7 +181,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, ...@@ -180,7 +181,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
max_feat_height, max_feat_width = pinpoint max_feat_height, max_feat_width = pinpoint
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
...@@ -195,9 +196,9 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, ...@@ -195,9 +196,9 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
image_height_override=max_feat_height, image_height_override=max_feat_height,
) )
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig): elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip( seq_data, ranges = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
...@@ -212,7 +213,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, ...@@ -212,7 +213,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
image_height_override=max_feat_height, image_height_override=max_feat_height,
) )
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
......
...@@ -11,8 +11,8 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig, ...@@ -11,8 +11,8 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -108,33 +108,35 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, ...@@ -108,33 +108,35 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
video_feature_size = frames_per_video * tokens_per_frame video_feature_size = frames_per_video * tokens_per_frame
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_videos, num_videos,
image_token_id=hf_config.video_token_index, image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size, image_feature_size_override=video_feature_size,
mm_key="video",
) )
pil_frame = dummy_image_for_clip(vision_config, num_images=1) pil_frame = dummy_image_for_clip(vision_config, num_images=1)
np_frame = np.array(pil_frame["image"]) np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video} mm_data = {"video": mm_data_per_video}
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig): elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip( seq_data, ranges = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_videos, num_videos,
image_token_id=hf_config.video_token_index, image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size, image_feature_size_override=video_feature_size,
mm_key="video",
) )
pil_frame = dummy_image_for_siglip(vision_config, num_images=1) pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
np_frame = np.array(pil_frame["image"]) np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video} mm_data = {"video": mm_data_per_video}
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
...@@ -145,6 +147,12 @@ def input_processor_for_llava_next_video(ctx: InputContext, ...@@ -145,6 +147,12 @@ def input_processor_for_llava_next_video(ctx: InputContext,
multi_modal_data = inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data: if multi_modal_data is None or "video" not in multi_modal_data:
return inputs return inputs
if "multi_modal_placeholders" in inputs and "video" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
video_data = multi_modal_data["video"] video_data = multi_modal_data["video"]
model_config = ctx.model_config model_config = ctx.model_config
...@@ -160,7 +168,7 @@ def input_processor_for_llava_next_video(ctx: InputContext, ...@@ -160,7 +168,7 @@ def input_processor_for_llava_next_video(ctx: InputContext,
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
...@@ -170,7 +178,8 @@ def input_processor_for_llava_next_video(ctx: InputContext, ...@@ -170,7 +178,8 @@ def input_processor_for_llava_next_video(ctx: InputContext,
return token_inputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray): elif is_list_of(video_data, np.ndarray):
raise NotImplementedError( raise NotImplementedError(
......
...@@ -15,8 +15,8 @@ from typing_extensions import NotRequired ...@@ -15,8 +15,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -218,31 +218,31 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, ...@@ -218,31 +218,31 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_videos, num_videos,
image_token_id=hf_config.video_token_index, image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size, image_feature_size_override=video_feature_size,
) mm_key="video")
mm_data = dummy_video_for_clip(vision_config, mm_data = dummy_video_for_clip(vision_config,
num_frames=num_frames, num_frames=num_frames,
num_videos=num_videos) num_videos=num_videos)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig): elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip( seq_data, ranges = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_videos, num_videos,
image_token_id=hf_config.video_token_index, image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size, image_feature_size_override=video_feature_size,
) mm_key="video")
mm_data = dummy_video_for_siglip(vision_config, mm_data = dummy_video_for_siglip(vision_config,
num_frames=num_frames, num_frames=num_frames,
num_videos=num_videos) num_videos=num_videos)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
...@@ -320,7 +320,7 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, ...@@ -320,7 +320,7 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
...@@ -330,7 +330,8 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, ...@@ -330,7 +330,8 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
return token_inputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray): elif is_list_of(video_data, np.ndarray):
video_feature_size = [] video_feature_size = []
......
...@@ -36,8 +36,8 @@ from typing_extensions import NotRequired ...@@ -36,8 +36,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
...@@ -277,7 +277,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, ...@@ -277,7 +277,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images) mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data)
def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs): def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
......
...@@ -36,7 +36,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType ...@@ -36,7 +36,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
EncoderDecoderInputs, InputContext) EncoderDecoderInputs, InputContext)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -176,13 +176,14 @@ def dummy_image(num_images: int, ): ...@@ -176,13 +176,14 @@ def dummy_image(num_images: int, ):
def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]): mm_counts: Mapping[str, int]):
num_images = mm_counts["image"] num_images = mm_counts["image"]
return dummy_decoder_seq_data(seq_len, num_images), None return DummyData(dummy_decoder_seq_data(seq_len, num_images))
def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int, def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]): mm_counts: Mapping[str, int]):
num_images = mm_counts["image"] num_images = mm_counts["image"]
return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) return DummyData(dummy_encoder_seq_data(ctx, num_images),
dummy_image(num_images))
def _prepare_aspect_ratio_attention_mask( def _prepare_aspect_ratio_attention_mask(
......
...@@ -7,8 +7,8 @@ from transformers import PaliGemmaConfig ...@@ -7,8 +7,8 @@ from transformers import PaliGemmaConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -58,7 +58,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, ...@@ -58,7 +58,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
num_images = mm_counts["image"] num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_siglip( seq_data, ranges = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_images, num_images,
...@@ -66,7 +66,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, ...@@ -66,7 +66,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
) )
mm_data = dummy_image_for_siglip(vision_config, num_images) mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
def input_processor_for_paligemma(ctx: InputContext, def input_processor_for_paligemma(ctx: InputContext,
......
...@@ -28,8 +28,8 @@ from transformers import CLIPVisionConfig, PretrainedConfig ...@@ -28,8 +28,8 @@ from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig, from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig,
PoolerConfig) PoolerConfig)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -380,7 +380,7 @@ def dummy_data_for_phi3v(ctx: InputContext, ...@@ -380,7 +380,7 @@ def dummy_data_for_phi3v(ctx: InputContext,
image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)
seq_data = dummy_seq_data_for_clip( seq_data, ranges = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG, CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len, seq_len,
num_images, num_images,
...@@ -394,7 +394,7 @@ def dummy_data_for_phi3v(ctx: InputContext, ...@@ -394,7 +394,7 @@ def dummy_data_for_phi3v(ctx: InputContext,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
) )
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
@lru_cache @lru_cache
......
...@@ -17,8 +17,8 @@ from transformers.models.pixtral.modeling_pixtral import ( ...@@ -17,8 +17,8 @@ from transformers.models.pixtral.modeling_pixtral import (
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -28,7 +28,8 @@ from vllm.model_executor.models.utils import merge_multimodal_embeddings ...@@ -28,7 +28,8 @@ from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -81,7 +82,12 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, ...@@ -81,7 +82,12 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
) )
mm_data = {"image": num_images * [image]} mm_data = {"image": num_images * [image]}
return seq_data, mm_data mm_placeholders = {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
return DummyData(seq_data, mm_data, mm_placeholders)
def input_mapper_for_pixtral(ctx: InputContext, def input_mapper_for_pixtral(ctx: InputContext,
...@@ -630,13 +636,13 @@ def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int: ...@@ -630,13 +636,13 @@ def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
def dummy_seq_data_for_pixtral_hf( def dummy_seq_data_for_pixtral_hf(
hf_config: PixtralVisionConfig, hf_config: PixtralVisionConfig,
seq_len: int, seq_len: int,
num_images: int, num_images: int,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[int] = None,
): mm_key: str = "image"):
if image_feature_size_override is None: if image_feature_size_override is None:
image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config) image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
else: else:
...@@ -645,7 +651,11 @@ def dummy_seq_data_for_pixtral_hf( ...@@ -645,7 +651,11 @@ def dummy_seq_data_for_pixtral_hf(
return SequenceData.from_prompt_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) ), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_pixtral_hf( def dummy_image_for_pixtral_hf(
......
...@@ -23,8 +23,8 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -23,8 +23,8 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -810,7 +810,7 @@ def dummy_data_for_qwen( ...@@ -810,7 +810,7 @@ def dummy_data_for_qwen(
ctx: InputContext, ctx: InputContext,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, Optional[Dict]]: ) -> DummyData:
"""Build dummy data for warming up Qwen models; this will only contain text """Build dummy data for warming up Qwen models; this will only contain text
matching the defaults for VLLM unless the model has a visual config. matching the defaults for VLLM unless the model has a visual config.
...@@ -829,7 +829,7 @@ def dummy_data_for_qwen( ...@@ -829,7 +829,7 @@ def dummy_data_for_qwen(
if not hasattr(hf_config, "visual"): if not hasattr(hf_config, "visual"):
seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
mm_data = None mm_data = None
return seq_data, mm_data return DummyData(seq_data, mm_data)
# We have a visual component - use images to warm up # We have a visual component - use images to warm up
num_images = mm_counts["image"] num_images = mm_counts["image"]
...@@ -861,7 +861,7 @@ def dummy_data_for_qwen( ...@@ -861,7 +861,7 @@ def dummy_data_for_qwen(
# the data will get resized and the # of tokens per image is constant # the data will get resized and the # of tokens per image is constant
image = Image.new("RGB", (224, 224), color=0) image = Image.new("RGB", (224, 224), color=0)
mm_data = {"image": image if num_images == 1 else [image] * num_images} mm_data = {"image": image if num_images == 1 else [image] * num_images}
return seq_data, mm_data return DummyData(seq_data, mm_data)
class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
......
...@@ -31,8 +31,8 @@ from transformers import Qwen2AudioConfig, Qwen2AudioEncoder ...@@ -31,8 +31,8 @@ from transformers import Qwen2AudioConfig, Qwen2AudioEncoder
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
...@@ -85,7 +86,8 @@ class Qwen2AudioMultiModalProjector(nn.Module): ...@@ -85,7 +86,8 @@ class Qwen2AudioMultiModalProjector(nn.Module):
def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int, def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]): mm_counts: Mapping[str, int]):
num_audios = mm_counts["audio"] num_audios = mm_counts["audio"]
max_llm_audio_tokens = get_max_qwen2_audio_audio_tokens(ctx) * num_audios max_tokens_per_audio = get_max_qwen2_audio_audio_tokens(ctx)
max_llm_audio_tokens = max_tokens_per_audio * num_audios
if seq_len - max_llm_audio_tokens - 2 < 0: if seq_len - max_llm_audio_tokens - 2 < 0:
raise RuntimeError( raise RuntimeError(
f"Qwen2-Audio cannot process {num_audios} audios in a prompt, " f"Qwen2-Audio cannot process {num_audios} audios in a prompt, "
...@@ -99,7 +101,12 @@ def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int, ...@@ -99,7 +101,12 @@ def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
(0, seq_len - max_llm_audio_tokens), (0, seq_len - max_llm_audio_tokens),
) )
dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.) dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
return dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios} return DummyData(
dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
"audio":
consecutive_placeholder_ranges(num_items=num_audios,
item_size=max_tokens_per_audio)
})
def get_processor( def get_processor(
......
...@@ -44,8 +44,8 @@ from vllm.attention.selector import _Backend ...@@ -44,8 +44,8 @@ from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.activation import QuickGELU
...@@ -744,9 +744,10 @@ def dummy_data_for_qwen2_vl( ...@@ -744,9 +744,10 @@ def dummy_data_for_qwen2_vl(
dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
color=0) color=0)
return dummy_seqdata, { return DummyData(dummy_seqdata, {
"image": dummy_image if num_images == 1 else [dummy_image] * num_images "image":
} dummy_image if num_images == 1 else [dummy_image] * num_images
})
def _get_llm_num_vision_tokens( def _get_llm_num_vision_tokens(
......
...@@ -23,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -23,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
...@@ -61,6 +62,7 @@ def dummy_seq_data_for_siglip( ...@@ -61,6 +62,7 @@ def dummy_seq_data_for_siglip(
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[int] = None,
mm_key: str = "image",
): ):
if image_feature_size_override is None: if image_feature_size_override is None:
image_feature_size = get_siglip_image_feature_size(hf_config) image_feature_size = get_siglip_image_feature_size(hf_config)
...@@ -70,7 +72,11 @@ def dummy_seq_data_for_siglip( ...@@ -70,7 +72,11 @@ def dummy_seq_data_for_siglip(
return SequenceData.from_prompt_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) ), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_siglip( def dummy_image_for_siglip(
...@@ -122,6 +128,11 @@ def input_processor_for_siglip( ...@@ -122,6 +128,11 @@ def input_processor_for_siglip(
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return inputs return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None: if image_feature_size_override is None:
...@@ -135,7 +146,7 @@ def input_processor_for_siglip( ...@@ -135,7 +146,7 @@ def input_processor_for_siglip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
...@@ -144,11 +155,10 @@ def input_processor_for_siglip( ...@@ -144,11 +155,10 @@ def input_processor_for_siglip(
) )
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return token_inputs( return token_inputs(prompt_token_ids=new_token_ids,
prompt_token_ids=new_token_ids, prompt=new_prompt,
prompt=new_prompt, multi_modal_data=multi_modal_data,
multi_modal_data=multi_modal_data, multi_modal_placeholders={"image": ranges})
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import math import math
from array import array
from functools import cached_property, lru_cache from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union, cast) TypedDict, Union, cast)
...@@ -17,27 +16,27 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder ...@@ -17,27 +16,27 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
from vllm.inputs.data import DecoderOnlyInputs, token_inputs InputContext, token_inputs)
from vllm.inputs.registry import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
from vllm.multimodal.base import MultiModalInputs, NestedTensors NestedTensors)
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import IntermediateTensors, SequenceData
SequenceData)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, merge_multimodal_embeddings) init_vllm_registered_model,
merge_multimodal_embeddings_from_map)
_AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25 _AUDIO_TOKENS_PER_SECOND = 6.25
...@@ -46,13 +45,13 @@ _AUDIO_TOKENS_PER_SECOND = 6.25 ...@@ -46,13 +45,13 @@ _AUDIO_TOKENS_PER_SECOND = 6.25
class UltravoxAudioFeatureInputs(TypedDict): class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]
data: NestedTensors data: NestedTensors
"""Shape: `(batch_size, num_audios, 80, M)""" """Shape: `(batch_size, num_audios, 80, M)`"""
class UltravoxAudioEmbeddingInputs(TypedDict): class UltravoxAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"] type: Literal["audio_embeds"]
data: NestedTensors data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
...@@ -79,17 +78,16 @@ def dummy_seq_data_for_ultravox( ...@@ -79,17 +78,16 @@ def dummy_seq_data_for_ultravox(
seq_len: int, seq_len: int,
audio_count: int, audio_count: int,
): ):
audio_placeholder = array( audio_length = min(get_ultravox_max_audio_tokens(ctx),
VLLM_TOKEN_ID_ARRAY_TYPE, seq_len // audio_count)
[_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
# Add a separator between each chunk. return SequenceData.from_prompt_token_counts(
audio_token_ids = (audio_placeholder + (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count (0, seq_len - audio_length * audio_count)), {
other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, "audio":
[0]) * (seq_len - len(audio_token_ids)) consecutive_placeholder_ranges(num_items=audio_count,
item_size=audio_length)
return SequenceData(audio_token_ids + other_token_ids) }
def dummy_audio_for_ultravox( def dummy_audio_for_ultravox(
...@@ -107,10 +105,10 @@ def dummy_data_for_ultravox( ...@@ -107,10 +105,10 @@ def dummy_data_for_ultravox(
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
): ):
audio_count = mm_counts["audio"] audio_count = mm_counts["audio"]
seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count) seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
mm_dict = dummy_audio_for_ultravox(ctx, audio_count) mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
return (seq_data, mm_dict) return DummyData(seq_data, mm_dict, ranges)
def input_mapper_for_ultravox(ctx: InputContext, data: object): def input_mapper_for_ultravox(ctx: InputContext, data: object):
...@@ -164,6 +162,11 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -164,6 +162,11 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
if multi_modal_data is None or "audio" not in multi_modal_data: if multi_modal_data is None or "audio" not in multi_modal_data:
return inputs return inputs
if "multi_modal_placeholders" in inputs and "audio" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
feature_extractor = whisper_feature_extractor(ctx) feature_extractor = whisper_feature_extractor(ctx)
audios = multi_modal_data["audio"] audios = multi_modal_data["audio"]
if not isinstance(audios, list): if not isinstance(audios, list):
...@@ -197,7 +200,7 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -197,7 +200,7 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
...@@ -208,7 +211,8 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -208,7 +211,8 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data,
multi_modal_placeholders={"audio": ranges})
class StackAudioFrames(nn.Module): class StackAudioFrames(nn.Module):
...@@ -472,9 +476,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -472,9 +476,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
inputs_embeds = merge_multimodal_embeddings( merge_multimodal_embeddings_from_map(
input_ids, inputs_embeds, audio_embeddings, inputs_embeds, audio_embeddings,
_AUDIO_PLACEHOLDER_TOKEN) attn_metadata.multi_modal_placeholder_index_maps["audio"])
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
......
...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -326,6 +326,22 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: ...@@ -326,6 +326,22 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
_embedding_count_expression(inner) for inner in embeddings) _embedding_count_expression(inner) for inner in embeddings)
def merge_multimodal_embeddings_from_map(
        inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
        placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
    """
    Write multimodal embeddings into ``inputs_embeds`` at mapped positions.

    ``multimodal_embeddings`` is flattened first; the entries selected by
    ``placeholder_map.src`` are then assigned to the positions of
    ``inputs_embeds`` selected by ``placeholder_map.dest``.

    Note:
        ``inputs_embeds`` is updated in place; the same tensor is returned
        for convenience.
    """
    # Gather the source entries before touching the destination tensor.
    selected_embeddings = _flatten_embeddings(multimodal_embeddings)[
        placeholder_map.src]
    inputs_embeds[placeholder_map.dest] = selected_embeddings
    return inputs_embeds
def _merge_multimodal_embeddings( def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor, is_multimodal: torch.Tensor,
......
from .base import (BatchedTensorInputs, MultiModalDataBuiltins, from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalInputs, MultiModalPlugin, MultiModalDataDict, MultiModalInputs,
NestedTensors) MultiModalPlaceholderDict, MultiModalPlaceholderMap,
MultiModalPlugin, NestedTensors)
from .registry import MultiModalRegistry from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry() MULTIMODAL_REGISTRY = MultiModalRegistry()
...@@ -17,6 +18,8 @@ __all__ = [ ...@@ -17,6 +18,8 @@ __all__ = [
"MultiModalDataBuiltins", "MultiModalDataBuiltins",
"MultiModalDataDict", "MultiModalDataDict",
"MultiModalInputs", "MultiModalInputs",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",
"MultiModalPlugin", "MultiModalPlugin",
"NestedTensors", "NestedTensors",
"MULTIMODAL_REGISTRY", "MULTIMODAL_REGISTRY",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment