Unverified Commit c4df59ad authored by Reagan Lee's avatar Reagan Lee Committed by GitHub
Browse files

Add embedding input functionality for disabled modalities [remake] (#32493)


Signed-off-by: default avatarReagan Lee <“reaganjlee@gmail.com”>
Signed-off-by: default avatarReagan Lee <reaganjlee@gmail.com>
Signed-off-by: default avatarReagan Lee <96998476+reaganjlee@users.noreply.github.com>
Co-authored-by: default avatarReagan Lee <“reaganjlee@gmail.com”>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 785cf28f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref
import pytest
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.distributed import cleanup_dist_env_and_memory
MODEL = "llava-hf/llava-1.5-7b-hf"
PROMPT = "USER: <image>\nDescribe this image briefly.\nASSISTANT:"
TEXT_ONLY_PROMPT = "USER: What is 2 + 2?\nASSISTANT:"
@pytest.fixture(scope="module")
def llm():
"""LLM with enable_mm_embeds=True and all modality limits zeroed out."""
llm = LLM(
model=MODEL,
max_model_len=2048,
enforce_eager=True,
gpu_memory_utilization=0.8,
enable_mm_embeds=True,
limit_mm_per_prompt={"image": 0},
)
yield weakref.proxy(llm)
del llm
cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_generate_with_embedding(llm: LLM):
"""Pre-computed embedding produces tokens without hanging."""
embedding = ImageAsset("stop_sign").image_embeds
outputs = llm.generate(
{"prompt": PROMPT, "multi_modal_data": {"image": embedding}},
sampling_params=SamplingParams(max_tokens=32, temperature=0.0),
)
assert len(outputs) == 1
assert len(outputs[0].outputs[0].text) > 0
@pytest.mark.skip_global_cleanup
def test_raw_image_rejected(llm: LLM):
"""Raw image input is still rejected when limit=0."""
raw_image = ImageAsset("stop_sign").pil_image
with pytest.raises(ValueError, match=r"At most 0 image\(s\)"):
llm.generate(
{"prompt": PROMPT, "multi_modal_data": {"image": raw_image}},
sampling_params=SamplingParams(max_tokens=16),
)
@pytest.mark.skip_global_cleanup
def test_text_only_prompt(llm: LLM):
"""Text-only prompts still work under this config."""
outputs = llm.generate(
TEXT_ONLY_PROMPT,
sampling_params=SamplingParams(max_tokens=16, temperature=0.0),
)
assert len(outputs) == 1
assert len(outputs[0].outputs[0].text) > 0
...@@ -901,7 +901,7 @@ def test_find_mm_placeholders( ...@@ -901,7 +901,7 @@ def test_find_mm_placeholders(
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("limit", "num_supported", "is_valid"), ("num_images", "limit", "is_valid"),
[ [
(0, 0, True), (0, 0, True),
(0, 1, True), (0, 1, True),
...@@ -912,7 +912,7 @@ def test_find_mm_placeholders( ...@@ -912,7 +912,7 @@ def test_find_mm_placeholders(
(2, 2, True), (2, 2, True),
], ],
) )
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
limit_mm_per_prompt = {"image": limit} limit_mm_per_prompt = {"image": limit}
model_config = ModelConfig( model_config = ModelConfig(
...@@ -921,33 +921,46 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): ...@@ -921,33 +921,46 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
) )
processor = MULTIMODAL_REGISTRY.create_processor(model_config) processor = MULTIMODAL_REGISTRY.create_processor(model_config)
processor.info.get_supported_mm_limits = lambda: {"image": num_supported}
rng = np.random.RandomState(0)
image = random_image(rng, min_wh=128, max_wh=256)
if num_images == 0:
mm_data = {}
elif num_images == 1:
mm_data = {"image": image}
else:
mm_data = {"image": [image] * num_images}
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most") exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
with exc_ctx: with exc_ctx:
MULTIMODAL_REGISTRY.get_dummy_mm_inputs( processor.apply(
model_config, "<image>" * num_images,
mm_counts=limit_mm_per_prompt, mm_items=processor.info.parse_mm_data(mm_data),
processor=processor, hf_processor_mm_kwargs={},
) )
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("num_images", "limit", "is_valid"), ("user_limit", "supported_limit"),
[ [
(0, 0, True), (0, 0),
(0, 1, True), (0, 1),
(1, 0, False), (1, 0), # user wants 1, model supports 0 → capped to 0
(1, 1, True), (1, 1),
(1, 2, True), (1, 2),
(2, 1, False), (2, 1), # user wants 2, model supports 1 → capped to 1
(2, 2, True), (2, 2),
(5, 1), # large user limit, low model support → capped to 1
(1, 5),
(10, 0), # large user limit, no model support → capped to 0
], ],
) )
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): def test_budget_caps_prevent_dummy_input_validation_failure(
limit_mm_per_prompt = {"image": limit} model_id, user_limit, supported_limit
):
limit_mm_per_prompt = {"image": user_limit}
model_config = ModelConfig( model_config = ModelConfig(
model=model_id, model=model_id,
...@@ -955,24 +968,21 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): ...@@ -955,24 +968,21 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
) )
processor = MULTIMODAL_REGISTRY.create_processor(model_config) processor = MULTIMODAL_REGISTRY.create_processor(model_config)
processor.info.get_supported_mm_limits = lambda: {"image": supported_limit}
rng = np.random.RandomState(0) # This is what budget.py uses to derive mm_counts
image = random_image(rng, min_wh=128, max_wh=256) allowed = processor.info.allowed_mm_limits
if num_images == 0:
mm_data = {}
elif num_images == 1:
mm_data = {"image": image}
else:
mm_data = {"image": [image] * num_images}
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most") assert allowed["image"] <= supported_limit, (
f"allowed_mm_limits['image']={allowed['image']} exceeds "
f"supported_limit={supported_limit}"
)
with exc_ctx: assert allowed["image"] <= user_limit, (
processor.apply( f"allowed_mm_limits['image']={allowed['image']} exceeds user_limit={user_limit}"
"<image>" * num_images, )
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={}, assert allowed["image"] == min(user_limit, supported_limit)
)
class DummyProcessor: class DummyProcessor:
......
...@@ -76,6 +76,11 @@ class MultiModalConfig: ...@@ -76,6 +76,11 @@ class MultiModalConfig:
for the OpenAI-compatible server, this refers to chat messages with content for the OpenAI-compatible server, this refers to chat messages with content
`"type": "*_embeds"`. `"type": "*_embeds"`.
When enabled with `--limit-mm-per-prompt` set to 0 for a modality,
precomputed embeddings skip count validation for that modality,
saving memory by not loading encoder modules while still enabling
embeddings as an input. Limits greater than 0 still apply to embeddings.
WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed. WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!""" Only enable this flag for trusted users!"""
media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict)
......
...@@ -528,7 +528,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -528,7 +528,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else: else:
num_items = len(self._items_by_modality[original_modality]) + 1 num_items = len(self._items_by_modality[original_modality]) + 1
self.mm_processor.info.validate_num_items(input_modality, num_items) mm_config = self.model_config.multimodal_config
if (
mm_config is not None
and mm_config.enable_mm_embeds
and mm_config.get_limit_per_prompt(input_modality) == 0
and original_modality.endswith("_embeds")
):
# Skip validation: embeddings bypass limit when enable_mm_embeds=True
pass
else:
self.mm_processor.info.validate_num_items(input_modality, num_items)
# Track original modality for vision_chunk items # Track original modality for vision_chunk items
if use_vision_chunk: if use_vision_chunk:
......
...@@ -3,11 +3,14 @@ ...@@ -3,11 +3,14 @@
from collections.abc import Mapping from collections.abc import Mapping
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.registry import MultiModalRegistry from vllm.multimodal.registry import MultiModalRegistry
from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
logger = init_logger(__name__)
def get_mm_max_toks_per_item( def get_mm_max_toks_per_item(
model_config: ModelConfig, model_config: ModelConfig,
...@@ -59,11 +62,26 @@ class MultiModalBudget: ...@@ -59,11 +62,26 @@ class MultiModalBudget:
processor = mm_registry.create_processor(model_config, cache=cache) processor = mm_registry.create_processor(model_config, cache=cache)
self.cache = cache self.cache = cache
mm_config = model_config.get_multimodal_config()
enable_mm_embeds = mm_config is not None and mm_config.enable_mm_embeds
supported_mm_limits = processor.info.supported_mm_limits
self.mm_limits = mm_limits = processor.info.allowed_mm_limits self.mm_limits = mm_limits = processor.info.allowed_mm_limits
active_modalities = { # Modalities that pass through the MM encoder tower
modality for modality, limit in mm_limits.items() if limit > 0 tower_modalities = {
modality
for modality in supported_mm_limits
if mm_limits.get(modality, 0) > 0
} }
# Modalities that bypass the tower (pre-computed embeddings only)
embed_only_modalities = {
modality
for modality in supported_mm_limits
if enable_mm_embeds and mm_limits.get(modality, 0) == 0
}
active_modalities = tower_modalities | embed_only_modalities
all_mm_max_toks_per_item = get_mm_max_toks_per_item( all_mm_max_toks_per_item = get_mm_max_toks_per_item(
model_config, model_config,
...@@ -72,19 +90,32 @@ class MultiModalBudget: ...@@ -72,19 +90,32 @@ class MultiModalBudget:
mm_counts=dict.fromkeys(active_modalities, 1), mm_counts=dict.fromkeys(active_modalities, 1),
) )
if embed_only_modalities:
logger.info_once(
"enable_mm_embeds is True; modalities handled as embedding-only: %s",
tuple(embed_only_modalities),
)
# Some models (e.g., Qwen3Omni with use_audio_in_video=True) share # Some models (e.g., Qwen3Omni with use_audio_in_video=True) share
# placeholders between modalities, so not all active modalities will # placeholders between modalities, so not all active modalities will
# have their own entry in the returned dict. We filter to only include # have their own entry in the returned dict. We filter to only include
# modalities that have independent placeholder tokens. # modalities that have independent placeholder tokens.
mm_max_toks_per_item = { active_mm_max_toks_per_item = {
modality: all_mm_max_toks_per_item[modality] modality: all_mm_max_toks_per_item[modality]
for modality in active_modalities for modality in active_modalities
if modality in all_mm_max_toks_per_item if modality in all_mm_max_toks_per_item
} }
tower_mm_max_toks_per_item = {
modality: active_mm_max_toks_per_item[modality]
for modality in tower_modalities
if modality in active_mm_max_toks_per_item
}
# Encoder budget is computed from all active modalities (including
# embedding-only ones that need encoder cache space).
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
scheduler_config, scheduler_config,
mm_max_toks_per_item, active_mm_max_toks_per_item,
) )
self.encoder_compute_budget = encoder_compute_budget self.encoder_compute_budget = encoder_compute_budget
...@@ -93,13 +124,15 @@ class MultiModalBudget: ...@@ -93,13 +124,15 @@ class MultiModalBudget:
mm_max_items_per_prompt = dict[str, int]() mm_max_items_per_prompt = dict[str, int]()
mm_max_items_per_batch = dict[str, int]() mm_max_items_per_batch = dict[str, int]()
for modality, max_toks_per_item in mm_max_toks_per_item.items(): # Per-prompt/per-batch limits are only relevant for tower modalities
# (embedding-only modalities don't go through the encoder tower).
for modality, max_toks_per_item in tower_mm_max_toks_per_item.items():
( (
mm_max_items_per_prompt[modality], mm_max_items_per_prompt[modality],
mm_max_items_per_batch[modality], mm_max_items_per_batch[modality],
) = self._get_max_items(modality, max_toks_per_item) ) = self._get_max_items(modality, max_toks_per_item)
self.mm_max_toks_per_item = mm_max_toks_per_item self.mm_max_toks_per_item = tower_mm_max_toks_per_item
self.mm_max_items_per_prompt: Mapping[str, int] = mm_max_items_per_prompt self.mm_max_items_per_prompt: Mapping[str, int] = mm_max_items_per_prompt
self.mm_max_items_per_batch: Mapping[str, int] = mm_max_items_per_batch self.mm_max_items_per_batch: Mapping[str, int] = mm_max_items_per_batch
......
...@@ -681,16 +681,22 @@ class BaseProcessingInfo: ...@@ -681,16 +681,22 @@ class BaseProcessingInfo:
mm_items = self.data_parser.parse_mm_data(mm_data) mm_items = self.data_parser.parse_mm_data(mm_data)
if validate: if validate:
mm_config = self.ctx.model_config.get_multimodal_config() mm_config = self.ctx.get_mm_config()
if not mm_config.enable_mm_embeds:
for modality, items in mm_items.items(): for modality, items in mm_items.items():
if isinstance(items, (EmbeddingItems, DictEmbeddingItems)): if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
if not mm_config.enable_mm_embeds:
raise ValueError( raise ValueError(
f"You must set `--enable-mm-embeds` to input " f"You must set `--enable-mm-embeds` to input "
f"`{modality}_embeds`" f"`{modality}_embeds`"
) )
if mm_config.get_limit_per_prompt(modality) == 0:
for modality, items in mm_items.items(): logger.debug(
"Skipping count validation for modality "
"'%s' (embeddings with limit=0)",
modality,
)
continue
self.validate_num_items(modality, len(items)) self.validate_num_items(modality, len(items))
return mm_items return mm_items
......
...@@ -95,7 +95,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -95,7 +95,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
""" """
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data) dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False)
tokenization_kwargs = {"truncation": False} tokenization_kwargs = {"truncation": False}
......
...@@ -1395,7 +1395,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1395,7 +1395,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
missing_modality_data.append(data) missing_modality_data.append(data)
mm_missing_data[modality] = missing_modality_data mm_missing_data[modality] = missing_modality_data
mm_missing_items = self.info.parse_mm_data(mm_missing_data) mm_missing_items = self.info.parse_mm_data(mm_missing_data, validate=False)
return mm_is_cached, mm_missing_items return mm_is_cached, mm_missing_items
......
...@@ -138,6 +138,11 @@ class MultiModalRegistry: ...@@ -138,6 +138,11 @@ class MultiModalRegistry:
mm_config.get_limit_per_prompt(modality) == 0 mm_config.get_limit_per_prompt(modality) == 0
for modality in info.supported_mm_limits for modality in info.supported_mm_limits
): ):
# If enable_mm_embeds is True, we still need MM infrastructure
# to process pre-computed embeddings even though encoder won't run
if mm_config.enable_mm_embeds:
return True
logger.info_once( logger.info_once(
"All limits of multimodal modalities supported by the model " "All limits of multimodal modalities supported by the model "
"are set to 0, running in text-only mode." "are set to 0, running in text-only mode."
......
...@@ -1259,6 +1259,9 @@ class GPUModelRunner( ...@@ -1259,6 +1259,9 @@ class GPUModelRunner(
mm_budget = self.mm_budget mm_budget = self.mm_budget
assert mm_budget is not None assert mm_budget is not None
if not mm_budget.mm_max_toks_per_item:
return {} # No tower modalities (embed-only mode)
dummy_modality = mm_budget.get_modality_with_max_tokens() dummy_modality = mm_budget.get_modality_with_max_tokens()
return self._get_mm_dummy_batch(dummy_modality, num_seqs) return self._get_mm_dummy_batch(dummy_modality, num_seqs)
...@@ -5116,40 +5119,50 @@ class GPUModelRunner( ...@@ -5116,40 +5119,50 @@ class GPUModelRunner(
assert mm_budget is not None assert mm_budget is not None
if (encoder_budget := mm_budget.get_encoder_budget()) > 0: if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text if not mm_budget.mm_max_toks_per_item:
# modality with the max possible input tokens even when # All modality limits are 0 — embedding-only mode.
# it supports multiple. # Budget is non-zero for embedding storage, but
dummy_modality = mm_budget.get_modality_with_max_tokens() # there's no encoder to profile.
max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[ logger.info(
dummy_modality "Skipping encoder profiling for embedding-only "
] "mode (all modality limits=0 with "
"enable_mm_embeds=True).",
)
else:
# NOTE: Currently model is profiled with a single
# non-text modality with the max possible input
# tokens even when it supports multiple.
dummy_modality = mm_budget.get_modality_with_max_tokens()
max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[
dummy_modality
]
logger.info( logger.info(
"Encoder cache will be initialized with a budget of " "Encoder cache will be initialized with a "
"%s tokens, and profiled with %s %s items of the " "budget of %s tokens, and profiled with "
"maximum feature size.", "%s %s items of the maximum feature size.",
encoder_budget, encoder_budget,
max_mm_items_per_batch, max_mm_items_per_batch,
dummy_modality, dummy_modality,
) )
# Create dummy batch of multimodal inputs. # Create dummy batch of multimodal inputs.
batched_dummy_mm_inputs = self._get_mm_dummy_batch( batched_dummy_mm_inputs = self._get_mm_dummy_batch(
dummy_modality, dummy_modality,
max_mm_items_per_batch, max_mm_items_per_batch,
) )
# Run multimodal encoder. # Run multimodal encoder.
dummy_encoder_outputs = self.model.embed_multimodal( dummy_encoder_outputs = self.model.embed_multimodal(
**batched_dummy_mm_inputs **batched_dummy_mm_inputs
) )
sanity_check_mm_encoder_outputs( sanity_check_mm_encoder_outputs(
dummy_encoder_outputs, dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch, expected_num_items=max_mm_items_per_batch,
) )
for i, output in enumerate(dummy_encoder_outputs): for i, output in enumerate(dummy_encoder_outputs):
self.encoder_cache[f"tmp_{i}"] = output self.encoder_cache[f"tmp_{i}"] = output
# Add `is_profile` here to pre-allocate communication buffers # Add `is_profile` here to pre-allocate communication buffers
hidden_states, last_hidden_states = self._dummy_run( hidden_states, last_hidden_states = self._dummy_run(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment