Unverified Commit 58fab50d authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[Frontend] Require flag for loading text and image embeds (#27204)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent db6f28d8
......@@ -75,6 +75,14 @@ class MultiModalConfig:
{"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
"height": 512}}
"""
enable_mm_embeds: bool = False
"""If `True`, enables passing multimodal embeddings:
for `LLM` class, this refers to tensor inputs under `multi_modal_data`;
for the OpenAI-compatible server, this refers to chat messages with content
`"type": "*_embeds"`.
WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!"""
media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
......
......@@ -438,6 +438,7 @@ class EngineArgs:
limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
MultiModalConfig, "limit_per_prompt"
)
enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
media_io_kwargs: dict[str, dict[str, Any]] = get_field(
MultiModalConfig, "media_io_kwargs"
......@@ -896,6 +897,9 @@ class EngineArgs:
multimodal_group.add_argument(
"--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
)
multimodal_group.add_argument(
"--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
)
multimodal_group.add_argument(
"--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
)
......@@ -1159,6 +1163,7 @@ class EngineArgs:
enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
enable_mm_embeds=self.enable_mm_embeds,
interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs,
skip_mm_profiling=self.skip_mm_profiling,
......
......@@ -811,6 +811,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains,
)
@property
def model_config(self) -> ModelConfig:
return self._tracker.model_config
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
image = self._connector.fetch_image(image_url) if image_url else None
......@@ -822,6 +826,12 @@ class MultiModalContentParser(BaseMultiModalContentParser):
image_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
......@@ -886,6 +896,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains,
)
@property
def model_config(self) -> ModelConfig:
return self._tracker.model_config
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
image_coro = self._connector.fetch_image_async(image_url) if image_url else None
......@@ -897,6 +911,12 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
image_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
if isinstance(image_embeds, dict):
......
......@@ -156,14 +156,17 @@ class BaseRenderer(ABC):
"""
raise NotImplementedError
@classmethod
def load_prompt_embeds(
cls,
self,
prompt_embeds: bytes | list[bytes],
truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None,
cache_salt: str | None = None,
) -> list[EngineEmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects."""
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
tensor = torch.load(
......
......@@ -1308,6 +1308,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
mm_config = self.info.ctx.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
for modality, items in mm_items.items():
if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
raise ValueError(
f"You must set `--enable-mm-embeds` to input "
f"`{modality}_embeds`"
)
for modality, items in mm_items.items():
self.validate_num_items(modality, len(items))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment