Unverified Commit 6c0b7f54 authored by Peter Salas's avatar Peter Salas Committed by GitHub
Browse files

[Core][VLM] Add precise multi-modal placeholder tracking (#8346)


Signed-off-by: default avatarPeter Salas <peter@fixie.ai>
parent d151fde8
......@@ -11,8 +11,8 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -30,6 +30,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import print_warning_once
......@@ -73,7 +74,11 @@ def dummy_seq_data_for_chameleon(
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_chameleon(
......@@ -97,14 +102,14 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_chameleon(
seq_data, ranges = dummy_seq_data_for_chameleon(
seq_len,
num_images,
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
)
mm_data = dummy_image_for_chameleon(num_images)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
def input_processor_for_chameleon(ctx: InputContext,
......@@ -120,9 +125,14 @@ def input_processor_for_chameleon(ctx: InputContext,
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
......
......@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData
......@@ -49,14 +50,13 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
return get_clip_image_feature_size(hf_config)
def dummy_seq_data_for_clip(
hf_config: CLIPVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
mm_key: str = "image"):
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
else:
......@@ -65,7 +65,11 @@ def dummy_seq_data_for_clip(
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_clip(
......@@ -117,6 +121,11 @@ def input_processor_for_clip(
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
......@@ -130,7 +139,7 @@ def input_processor_for_clip(
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
......@@ -141,7 +150,8 @@ def input_processor_for_clip(
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
......
......@@ -27,8 +27,8 @@ from transformers import FuyuConfig, FuyuImageProcessor
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
......@@ -37,9 +37,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
......@@ -103,7 +105,11 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
return SequenceData(token_ids), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_fuyu(
......@@ -119,15 +125,15 @@ def dummy_image_for_fuyu(
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
mm_data = dummy_image_for_fuyu(num_images,
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
data: Image.Image):
data: List[Image.Image]):
image_encoding = image_processor.preprocess(data, return_tensors="pt")
batch_images = torch.stack([img[0] for img in image_encoding["images"]
]).unsqueeze(1)
......@@ -158,8 +164,10 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
model_config = ctx.model_config
image_data = multi_modal_data["image"]
new_multi_modal_data = {}
image_list = image_data if isinstance(image_data, list) else [image_data]
# process image data
if isinstance(image_data, Image.Image):
if is_list_of(image_list, Image.Image):
# Fuyu's image_processor can also finish token padding
image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model)
......@@ -171,7 +179,7 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
])
new_multi_modal_data["image"] = image_patches
elif isinstance(image_data, torch.Tensor):
elif is_list_of(image_list, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
......@@ -198,12 +206,13 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
def input_mapper_for_fuyu(ctx: InputContext, data: object):
model_config = ctx.model_config
if isinstance(data, Image.Image):
data_list = data if isinstance(data, list) else [data]
if is_list_of(data_list, Image.Image):
# Fuyu's image_processor can also finish token padding
image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, data)
model_image_input = _fuyu_image_preprocess(image_processor, data_list)
data = torch.stack([
image_patch[0]
for image_patch in model_image_input["image_patches"]
......
......@@ -17,8 +17,8 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.quantization import (AWQConfig,
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
......@@ -379,7 +379,7 @@ class InternVLInputPipeline:
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
seq_data = dummy_seq_data_for_clip(
seq_data, ranges = dummy_seq_data_for_clip(
hf_config.vision_config,
seq_len,
num_images,
......@@ -398,7 +398,7 @@ class InternVLInputPipeline:
image_height_override=max_image_height,
)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
......
......@@ -10,7 +10,8 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
......@@ -111,7 +112,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
......@@ -120,9 +121,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
)
mm_data = dummy_image_for_clip(vision_config, num_images)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
......@@ -131,9 +132,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, PixtralVisionConfig):
seq_data = dummy_seq_data_for_pixtral_hf(
seq_data, ranges = dummy_seq_data_for_pixtral_hf(
vision_config,
seq_len,
num_images,
......@@ -142,7 +143,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
)
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
......
......@@ -12,7 +12,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
......@@ -180,7 +181,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
max_feat_height, max_feat_width = pinpoint
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
......@@ -195,9 +196,9 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
image_height_override=max_feat_height,
)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
......@@ -212,7 +213,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
image_height_override=max_feat_height,
)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
......
......@@ -11,8 +11,8 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
......@@ -108,33 +108,35 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
video_feature_size = frames_per_video * tokens_per_frame
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video",
)
pil_frame = dummy_image_for_clip(vision_config, num_images=1)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video}
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video",
)
pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video}
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
......@@ -145,6 +147,12 @@ def input_processor_for_llava_next_video(ctx: InputContext,
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "video" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config
......@@ -160,7 +168,7 @@ def input_processor_for_llava_next_video(ctx: InputContext,
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
......@@ -170,7 +178,8 @@ def input_processor_for_llava_next_video(ctx: InputContext,
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray):
raise NotImplementedError(
......
......@@ -15,8 +15,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
......@@ -218,31 +218,31 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
)
mm_key="video")
mm_data = dummy_video_for_clip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
)
mm_key="video")
mm_data = dummy_video_for_siglip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
......@@ -320,7 +320,7 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
......@@ -330,7 +330,8 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray):
video_feature_size = []
......
......@@ -36,8 +36,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
......@@ -277,7 +277,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)
return seq_data, mm_data
return DummyData(seq_data, mm_data)
def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
......
......@@ -36,7 +36,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
EncoderDecoderInputs, InputContext)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -176,13 +176,14 @@ def dummy_image(num_images: int, ):
def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
return dummy_decoder_seq_data(seq_len, num_images), None
return DummyData(dummy_decoder_seq_data(seq_len, num_images))
def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images)
return DummyData(dummy_encoder_seq_data(ctx, num_images),
dummy_image(num_images))
def _prepare_aspect_ratio_attention_mask(
......
......@@ -7,8 +7,8 @@ from transformers import PaliGemmaConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
......@@ -58,7 +58,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_siglip(
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
......@@ -66,7 +66,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
def input_processor_for_paligemma(ctx: InputContext,
......
......@@ -28,8 +28,8 @@ from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig,
PoolerConfig)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -380,7 +380,7 @@ def dummy_data_for_phi3v(ctx: InputContext,
image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)
seq_data = dummy_seq_data_for_clip(
seq_data, ranges = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
num_images,
......@@ -394,7 +394,7 @@ def dummy_data_for_phi3v(ctx: InputContext,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
@lru_cache
......
......@@ -17,8 +17,8 @@ from transformers.models.pixtral.modeling_pixtral import (
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -28,7 +28,8 @@ from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
......@@ -81,7 +82,12 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
)
mm_data = {"image": num_images * [image]}
return seq_data, mm_data
mm_placeholders = {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
return DummyData(seq_data, mm_data, mm_placeholders)
def input_mapper_for_pixtral(ctx: InputContext,
......@@ -630,13 +636,13 @@ def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
def dummy_seq_data_for_pixtral_hf(
hf_config: PixtralVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
hf_config: PixtralVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
mm_key: str = "image"):
if image_feature_size_override is None:
image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
else:
......@@ -645,7 +651,11 @@ def dummy_seq_data_for_pixtral_hf(
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_pixtral_hf(
......
......@@ -23,8 +23,8 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -810,7 +810,7 @@ def dummy_data_for_qwen(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, Optional[Dict]]:
) -> DummyData:
"""Build dummy data for warming up Qwen models; this will only contain text
matching the defaults for VLLM unless the model has a visual config.
......@@ -829,7 +829,7 @@ def dummy_data_for_qwen(
if not hasattr(hf_config, "visual"):
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
mm_data = None
return seq_data, mm_data
return DummyData(seq_data, mm_data)
# We have a visual component - use images to warm up
num_images = mm_counts["image"]
......@@ -861,7 +861,7 @@ def dummy_data_for_qwen(
# the data will get resized and the # of tokens per image is constant
image = Image.new("RGB", (224, 224), color=0)
mm_data = {"image": image if num_images == 1 else [image] * num_images}
return seq_data, mm_data
return DummyData(seq_data, mm_data)
class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
......
......@@ -31,8 +31,8 @@ from transformers import Qwen2AudioConfig, Qwen2AudioEncoder
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
......@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal, SupportsPP
......@@ -85,7 +86,8 @@ class Qwen2AudioMultiModalProjector(nn.Module):
def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_audios = mm_counts["audio"]
max_llm_audio_tokens = get_max_qwen2_audio_audio_tokens(ctx) * num_audios
max_tokens_per_audio = get_max_qwen2_audio_audio_tokens(ctx)
max_llm_audio_tokens = max_tokens_per_audio * num_audios
if seq_len - max_llm_audio_tokens - 2 < 0:
raise RuntimeError(
f"Qwen2-Audio cannot process {num_audios} audios in a prompt, "
......@@ -99,7 +101,12 @@ def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
(0, seq_len - max_llm_audio_tokens),
)
dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
return dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}
return DummyData(
dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
"audio":
consecutive_placeholder_ranges(num_items=num_audios,
item_size=max_tokens_per_audio)
})
def get_processor(
......
......@@ -44,8 +44,8 @@ from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
......@@ -744,9 +744,10 @@ def dummy_data_for_qwen2_vl(
dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
color=0)
return dummy_seqdata, {
"image": dummy_image if num_images == 1 else [dummy_image] * num_images
}
return DummyData(dummy_seqdata, {
"image":
dummy_image if num_images == 1 else [dummy_image] * num_images
})
def _get_llm_num_vision_tokens(
......
......@@ -23,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData
......@@ -61,6 +62,7 @@ def dummy_seq_data_for_siglip(
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
mm_key: str = "image",
):
if image_feature_size_override is None:
image_feature_size = get_siglip_image_feature_size(hf_config)
......@@ -70,7 +72,11 @@ def dummy_seq_data_for_siglip(
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_siglip(
......@@ -122,6 +128,11 @@ def input_processor_for_siglip(
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
......@@ -135,7 +146,7 @@ def input_processor_for_siglip(
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
......@@ -144,11 +155,10 @@ def input_processor_for_siglip(
)
# NOTE: Create a defensive copy of the original inputs
return token_inputs(
prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
......
......@@ -2,7 +2,6 @@
"""PyTorch Ultravox model."""
import math
from array import array
from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union, cast)
......@@ -17,27 +16,27 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import DecoderOnlyInputs, token_inputs
from vllm.inputs.registry import InputContext
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs, NestedTensors
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
NestedTensors)
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, merge_multimodal_embeddings)
init_vllm_registered_model,
merge_multimodal_embeddings_from_map)
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
......@@ -46,13 +45,13 @@ _AUDIO_TOKENS_PER_SECOND = 6.25
class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: NestedTensors
"""Shape: `(batch_size, num_audios, 80, M)"""
"""Shape: `(batch_size, num_audios, 80, M)`"""
class UltravoxAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
......@@ -79,17 +78,16 @@ def dummy_seq_data_for_ultravox(
seq_len: int,
audio_count: int,
):
audio_placeholder = array(
VLLM_TOKEN_ID_ARRAY_TYPE,
[_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
audio_length = min(get_ultravox_max_audio_tokens(ctx),
seq_len // audio_count)
# Add a separator between each chunk.
audio_token_ids = (audio_placeholder +
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - len(audio_token_ids))
return SequenceData(audio_token_ids + other_token_ids)
return SequenceData.from_prompt_token_counts(
(_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
(0, seq_len - audio_length * audio_count)), {
"audio":
consecutive_placeholder_ranges(num_items=audio_count,
item_size=audio_length)
}
def dummy_audio_for_ultravox(
......@@ -107,10 +105,10 @@ def dummy_data_for_ultravox(
mm_counts: Mapping[str, int],
):
audio_count = mm_counts["audio"]
seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
return (seq_data, mm_dict)
return DummyData(seq_data, mm_dict, ranges)
def input_mapper_for_ultravox(ctx: InputContext, data: object):
......@@ -164,6 +162,11 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
if multi_modal_data is None or "audio" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "audio" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
feature_extractor = whisper_feature_extractor(ctx)
audios = multi_modal_data["audio"]
if not isinstance(audios, list):
......@@ -197,7 +200,7 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
......@@ -208,7 +211,8 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"audio": ranges})
class StackAudioFrames(nn.Module):
......@@ -472,9 +476,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, audio_embeddings,
_AUDIO_PLACEHOLDER_TOKEN)
merge_multimodal_embeddings_from_map(
inputs_embeds, audio_embeddings,
attn_metadata.multi_modal_placeholder_index_maps["audio"])
input_ids = None
else:
inputs_embeds = None
......
......@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors
from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available
......@@ -326,6 +326,22 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
_embedding_count_expression(inner) for inner in embeddings)
def merge_multimodal_embeddings_from_map(
inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
placeholder map .
Note:
This updates ``inputs_embeds`` in place.
"""
flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
inputs_embeds[placeholder_map.dest] = flattened_embeddings[
placeholder_map.src]
return inputs_embeds
def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
......
from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
NestedTensors)
MultiModalDataDict, MultiModalInputs,
MultiModalPlaceholderDict, MultiModalPlaceholderMap,
MultiModalPlugin, NestedTensors)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
......@@ -17,6 +18,8 @@ __all__ = [
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalInputs",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",
"MultiModalPlugin",
"NestedTensors",
"MULTIMODAL_REGISTRY",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment