Unverified Commit f5d0f478 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Improve error message for too many mm items (#22114)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent b690e348
...@@ -579,10 +579,7 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( ...@@ -579,10 +579,7 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",
message="coroutine 'async_get_and_parse_image' was never awaited") message="coroutine 'async_get_and_parse_image' was never awaited")
with pytest.raises( with pytest.raises(ValueError, match="At most"):
ValueError,
match="At most 2 image\\(s\\) may be provided in one request\\."
):
parse_chat_messages( parse_chat_messages(
[{ [{
"role": "role":
...@@ -622,10 +619,7 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( ...@@ -622,10 +619,7 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",
message="coroutine 'async_get_and_parse_image' was never awaited") message="coroutine 'async_get_and_parse_image' was never awaited")
with pytest.raises( with pytest.raises(ValueError, match="At most"):
ValueError,
match="At most 2 image\\(s\\) may be provided in one request\\."
):
parse_chat_messages( parse_chat_messages(
[{ [{
"role": "role":
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
from contextlib import nullcontext from contextlib import nullcontext
from typing import Optional, cast from typing import Optional, cast
from unittest.mock import MagicMock
import numpy as np import numpy as np
import pytest import pytest
...@@ -957,15 +956,14 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): ...@@ -957,15 +956,14 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
) )
processor = MULTIMODAL_REGISTRY.create_processor(model_config) processor = MULTIMODAL_REGISTRY.create_processor(model_config)
profiler = MultiModalProfiler(processor) processor._supported_mm_limits = {"image": num_supported}
mock_supported_mm_limits = MagicMock(return_value={"image": num_supported}) profiler = MultiModalProfiler(processor)
processor.info.get_supported_mm_limits = mock_supported_mm_limits
if is_valid: if is_valid:
exc_ctx = nullcontext() exc_ctx = nullcontext()
else: else:
exc_ctx = pytest.raises(ValueError, match="The model only supports") exc_ctx = pytest.raises(ValueError, match="At most")
with exc_ctx: with exc_ctx:
profiler.get_decoder_dummy_data( profiler.get_decoder_dummy_data(
...@@ -1002,7 +1000,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): ...@@ -1002,7 +1000,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
if is_valid: if is_valid:
exc_ctx = nullcontext() exc_ctx = nullcontext()
else: else:
exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image") exc_ctx = pytest.raises(ValueError, match="At most")
with exc_ctx: with exc_ctx:
processor.apply( processor.apply(
......
...@@ -535,9 +535,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -535,9 +535,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return self._model_config return self._model_config
@cached_property @cached_property
def model_cls(self): def model_cls(self) -> type[SupportsMultiModal]:
from vllm.model_executor.model_loader import get_model_cls from vllm.model_executor.model_loader import get_model_cls
return get_model_cls(self.model_config) model_cls = get_model_cls(self.model_config)
return cast(type[SupportsMultiModal], model_cls)
@property @property
def allowed_local_media_path(self): def allowed_local_media_path(self):
...@@ -547,31 +548,23 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -547,31 +548,23 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def mm_registry(self): def mm_registry(self):
return MULTIMODAL_REGISTRY return MULTIMODAL_REGISTRY
@cached_property
def mm_processor(self):
return self.mm_registry.create_processor(self.model_config)
def add(self, modality: ModalityStr, item: _T) -> Optional[str]: def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
""" """
Add a multi-modal item to the current prompt and returns the Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any. placeholder string to use, if any.
""" """
mm_registry = self.mm_registry
model_config = self.model_config
model_cls = cast(SupportsMultiModal, self.model_cls)
input_modality = modality.replace("_embeds", "") input_modality = modality.replace("_embeds", "")
num_items = len(self._items_by_modality[modality]) + 1
mm_processor = mm_registry.create_processor(model_config) self.mm_processor.validate_num_items(input_modality, num_items)
allowed_counts = mm_processor.info.get_allowed_mm_limits()
allowed_count = allowed_counts.get(input_modality, 0)
current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request. You can set `--limit-mm-per-prompt` to "
"increase this limit if the model supports it.")
self._items_by_modality[modality].append(item) self._items_by_modality[modality].append(item)
return model_cls.get_placeholder_str(modality, current_count) return self.model_cls.get_placeholder_str(modality, num_items)
@abstractmethod @abstractmethod
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
...@@ -1156,6 +1155,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1156,6 +1155,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self.data_parser = self._get_data_parser() self.data_parser = self._get_data_parser()
# Avoid unnecessary recomputation
self._supported_mm_limits = self.info.get_supported_mm_limits()
self._allowed_mm_limits = self.info.get_allowed_mm_limits()
@property
def supported_mm_limits(self):
return self._supported_mm_limits
@property
def allowed_mm_limits(self):
return self._allowed_mm_limits
def __call__( def __call__(
self, self,
prompt: str, prompt: str,
...@@ -1176,6 +1187,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1176,6 +1187,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
""" """
return MultiModalDataParser() return MultiModalDataParser()
def validate_num_items(
self,
modality: str,
num_items: int,
) -> None:
supported_limit = self.supported_mm_limits.get(modality, 0)
allowed_limit = self.allowed_mm_limits.get(modality, 0)
if supported_limit is None:
supported_limit = allowed_limit
limit = min(supported_limit, allowed_limit)
if num_items > limit:
msg = (f"At most {limit} {modality}(s) may be provided in "
"one prompt.")
if num_items <= supported_limit:
msg += " Set `--limit-mm-per-prompt` to increase this limit."
raise ValueError(msg)
def _to_mm_items( def _to_mm_items(
self, self,
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
...@@ -1188,26 +1221,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1188,26 +1221,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
""" """
mm_items = self.data_parser.parse_mm_data(mm_data) mm_items = self.data_parser.parse_mm_data(mm_data)
supported_mm_limits = self.info.get_supported_mm_limits()
allowed_mm_limits = self.info.get_allowed_mm_limits()
for modality, items in mm_items.items(): for modality, items in mm_items.items():
supported_limit = supported_mm_limits.get(modality, 0) self.validate_num_items(modality, len(items))
allowed_limit = allowed_mm_limits.get(modality, 0)
num_items = len(items)
if supported_limit is not None and num_items > supported_limit:
raise ValueError(
f"The model only supports at most {supported_limit} "
f"{modality} items, but you passed {num_items} "
f"{modality} items in the same prompt.")
if num_items > allowed_limit:
raise ValueError(
"You set or defaulted to "
f"'{json.dumps({modality: allowed_limit})}' in "
f"`--limit-mm-per-prompt`, but passed {num_items} "
f"{modality} items in the same prompt.")
return mm_items return mm_items
......
...@@ -156,7 +156,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -156,7 +156,7 @@ class MultiModalProfiler(Generic[_I]):
return self.processor.dummy_inputs return self.processor.dummy_inputs
def get_mm_limits(self) -> Mapping[str, int]: def get_mm_limits(self) -> Mapping[str, int]:
return self.processing_info.get_allowed_mm_limits() return self.processor.allowed_mm_limits
def _get_dummy_mm_inputs( def _get_dummy_mm_inputs(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment