Commit fb35feea authored by zhuwenwen's avatar zhuwenwen
Browse files

[Frontend] Require flag for loading text and image embeds

parent 9e94b9d8
......@@ -26,6 +26,13 @@ class MultiModalConfig:
For example, to allow up to 16 images and 2 videos per prompt:
`{"image": 16, "video": 2}`"""
enable_mm_embeds: bool = False
"""If `True`, enables passing multimodal embeddings:
for `LLM` class, this refers to tensor inputs under `multi_modal_data`;
for the OpenAI-compatible server, this refers to chat messages with content
`"type": "*_embeds"`.
WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!"""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
......
......@@ -379,6 +379,7 @@ class EngineArgs:
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt")
enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
media_io_kwargs: dict[str, dict[str,
Any]] = get_field(MultiModalConfig,
......@@ -796,6 +797,9 @@ class EngineArgs:
)
multimodal_group.add_argument("--limit-mm-per-prompt",
**multimodal_kwargs["limit_per_prompt"])
multimodal_group.add_argument(
"--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
)
multimodal_group.add_argument("--media-io-kwargs",
**multimodal_kwargs["media_io_kwargs"])
multimodal_group.add_argument(
......@@ -1034,6 +1038,7 @@ class EngineArgs:
enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
enable_mm_embeds=self.enable_mm_embeds,
interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs,
skip_mm_profiling=self.skip_mm_profiling,
......
......@@ -845,6 +845,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains,
)
@property
def model_config(self) -> ModelConfig:
return self._tracker.model_config
def parse_image(
self, image_url: Optional[str], uuid: Optional[str] = None
) -> None:
......@@ -858,6 +862,12 @@ class MultiModalContentParser(BaseMultiModalContentParser):
image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None,
) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
......@@ -930,6 +940,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains,
)
@property
def model_config(self) -> ModelConfig:
return self._tracker.model_config
def parse_image(
self, image_url: Optional[str], uuid: Optional[str] = None
) -> None:
......@@ -945,6 +959,12 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None,
) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
future: asyncio.Future[Union[str, dict[str, str], None]] = (
asyncio.Future()
)
......
......@@ -135,14 +135,17 @@ class BaseRenderer(ABC):
"""
raise NotImplementedError
@classmethod
def load_prompt_embeds(
cls,
self,
prompt_embeds: Union[bytes, list[bytes]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
cache_salt: Optional[str] = None,
) -> list[EngineEmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects."""
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
tensor = torch.load(
......
......@@ -1296,6 +1296,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
mm_config = self.info.ctx.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
for modality, items in mm_items.items():
if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
raise ValueError(
f"You must set `--enable-mm-embeds` to input "
f"`{modality}_embeds`"
)
for modality, items in mm_items.items():
self.validate_num_items(modality, len(items))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment