"csrc/vscode:/vscode.git/clone" did not exist on "541a2ef892720489f770569417bc1bc4436dbb21"
Unverified Commit 58fab50d authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[Frontend] Require flag for loading text and image embeds (#27204)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent db6f28d8
...@@ -75,6 +75,14 @@ class MultiModalConfig: ...@@ -75,6 +75,14 @@ class MultiModalConfig:
{"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512, {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
"height": 512}} "height": 512}}
""" """
enable_mm_embeds: bool = False
"""If `True`, enables passing multimodal embeddings:
for `LLM` class, this refers to tensor inputs under `multi_modal_data`;
for the OpenAI-compatible server, this refers to chat messages with content
`"type": "*_embeds"`.
WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!"""
media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities. """Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set For example, to set num_frames for video, set
......
...@@ -438,6 +438,7 @@ class EngineArgs: ...@@ -438,6 +438,7 @@ class EngineArgs:
limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field( limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
MultiModalConfig, "limit_per_prompt" MultiModalConfig, "limit_per_prompt"
) )
enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
media_io_kwargs: dict[str, dict[str, Any]] = get_field( media_io_kwargs: dict[str, dict[str, Any]] = get_field(
MultiModalConfig, "media_io_kwargs" MultiModalConfig, "media_io_kwargs"
...@@ -896,6 +897,9 @@ class EngineArgs: ...@@ -896,6 +897,9 @@ class EngineArgs:
multimodal_group.add_argument( multimodal_group.add_argument(
"--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"] "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
) )
multimodal_group.add_argument(
"--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
)
multimodal_group.add_argument( multimodal_group.add_argument(
"--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"] "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
) )
...@@ -1159,6 +1163,7 @@ class EngineArgs: ...@@ -1159,6 +1163,7 @@ class EngineArgs:
enable_prompt_embeds=self.enable_prompt_embeds, enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
enable_mm_embeds=self.enable_mm_embeds,
interleave_mm_strings=self.interleave_mm_strings, interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs, media_io_kwargs=self.media_io_kwargs,
skip_mm_profiling=self.skip_mm_profiling, skip_mm_profiling=self.skip_mm_profiling,
......
...@@ -811,6 +811,10 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -811,6 +811,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
) )
@property
def model_config(self) -> ModelConfig:
return self._tracker.model_config
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None: def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
image = self._connector.fetch_image(image_url) if image_url else None image = self._connector.fetch_image(image_url) if image_url else None
...@@ -822,6 +826,12 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -822,6 +826,12 @@ class MultiModalContentParser(BaseMultiModalContentParser):
image_embeds: str | dict[str, str] | None, image_embeds: str | dict[str, str] | None,
uuid: str | None = None, uuid: str | None = None,
) -> None: ) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
if isinstance(image_embeds, dict): if isinstance(image_embeds, dict):
embeds = { embeds = {
k: self._connector.fetch_image_embedding(v) k: self._connector.fetch_image_embedding(v)
...@@ -886,6 +896,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -886,6 +896,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
) )
@property
def model_config(self) -> ModelConfig:
return self._tracker.model_config
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None: def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
image_coro = self._connector.fetch_image_async(image_url) if image_url else None image_coro = self._connector.fetch_image_async(image_url) if image_url else None
...@@ -897,6 +911,12 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -897,6 +911,12 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
image_embeds: str | dict[str, str] | None, image_embeds: str | dict[str, str] | None,
uuid: str | None = None, uuid: str | None = None,
) -> None: ) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future() future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
if isinstance(image_embeds, dict): if isinstance(image_embeds, dict):
......
...@@ -156,14 +156,17 @@ class BaseRenderer(ABC): ...@@ -156,14 +156,17 @@ class BaseRenderer(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@classmethod
def load_prompt_embeds( def load_prompt_embeds(
cls, self,
prompt_embeds: bytes | list[bytes], prompt_embeds: bytes | list[bytes],
truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None, truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None,
cache_salt: str | None = None, cache_salt: str | None = None,
) -> list[EngineEmbedsPrompt]: ) -> list[EngineEmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects.""" """Load and validate base64-encoded embeddings into prompt objects."""
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
tensor = torch.load( tensor = torch.load(
......
...@@ -1308,6 +1308,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1308,6 +1308,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
""" """
mm_items = self.data_parser.parse_mm_data(mm_data) mm_items = self.data_parser.parse_mm_data(mm_data)
mm_config = self.info.ctx.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
for modality, items in mm_items.items():
if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
raise ValueError(
f"You must set `--enable-mm-embeds` to input "
f"`{modality}_embeds`"
)
for modality, items in mm_items.items(): for modality, items in mm_items.items():
self.validate_num_items(modality, len(items)) self.validate_num_items(modality, len(items))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment