"...git@developer.sourcefind.cn:2222/renzhc/diffusers_dcu.git" did not exist on "85f9d92883d5a948e3cecb0e62a80780d897e3e1"
Commit fb35feea authored by zhuwenwen's avatar zhuwenwen
Browse files

[Frontend] Require flag for loading text and image embeds

parent 9e94b9d8
...@@ -26,6 +26,13 @@ class MultiModalConfig: ...@@ -26,6 +26,13 @@ class MultiModalConfig:
For example, to allow up to 16 images and 2 videos per prompt: For example, to allow up to 16 images and 2 videos per prompt:
`{"image": 16, "video": 2}`""" `{"image": 16, "video": 2}`"""
enable_mm_embeds: bool = False
"""If `True`, enables passing multimodal embeddings:
for `LLM` class, this refers to tensor inputs under `multi_modal_data`;
for the OpenAI-compatible server, this refers to chat messages with content
`"type": "*_embeds"`.
WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!"""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities. """Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set For example, to set num_frames for video, set
......
...@@ -379,6 +379,7 @@ class EngineArgs: ...@@ -379,6 +379,7 @@ class EngineArgs:
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
limit_mm_per_prompt: dict[str, int] = \ limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt") get_field(MultiModalConfig, "limit_per_prompt")
enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
media_io_kwargs: dict[str, dict[str, media_io_kwargs: dict[str, dict[str,
Any]] = get_field(MultiModalConfig, Any]] = get_field(MultiModalConfig,
...@@ -796,6 +797,9 @@ class EngineArgs: ...@@ -796,6 +797,9 @@ class EngineArgs:
) )
multimodal_group.add_argument("--limit-mm-per-prompt", multimodal_group.add_argument("--limit-mm-per-prompt",
**multimodal_kwargs["limit_per_prompt"]) **multimodal_kwargs["limit_per_prompt"])
multimodal_group.add_argument(
"--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
)
multimodal_group.add_argument("--media-io-kwargs", multimodal_group.add_argument("--media-io-kwargs",
**multimodal_kwargs["media_io_kwargs"]) **multimodal_kwargs["media_io_kwargs"])
multimodal_group.add_argument( multimodal_group.add_argument(
...@@ -1034,6 +1038,7 @@ class EngineArgs: ...@@ -1034,6 +1038,7 @@ class EngineArgs:
enable_prompt_embeds=self.enable_prompt_embeds, enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
enable_mm_embeds=self.enable_mm_embeds,
interleave_mm_strings=self.interleave_mm_strings, interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs, media_io_kwargs=self.media_io_kwargs,
skip_mm_profiling=self.skip_mm_profiling, skip_mm_profiling=self.skip_mm_profiling,
......
...@@ -844,6 +844,10 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -844,6 +844,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
) )
@property
def model_config(self) -> ModelConfig:
return self._tracker.model_config
def parse_image( def parse_image(
self, image_url: Optional[str], uuid: Optional[str] = None self, image_url: Optional[str], uuid: Optional[str] = None
...@@ -858,6 +862,12 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -858,6 +862,12 @@ class MultiModalContentParser(BaseMultiModalContentParser):
image_embeds: Union[str, dict[str, str], None], image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None, uuid: Optional[str] = None,
) -> None: ) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
if isinstance(image_embeds, dict): if isinstance(image_embeds, dict):
embeds = { embeds = {
k: self._connector.fetch_image_embedding(v) k: self._connector.fetch_image_embedding(v)
...@@ -929,6 +939,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -929,6 +939,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
) )
@property
def model_config(self) -> ModelConfig:
return self._tracker.model_config
def parse_image( def parse_image(
self, image_url: Optional[str], uuid: Optional[str] = None self, image_url: Optional[str], uuid: Optional[str] = None
...@@ -945,6 +959,12 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -945,6 +959,12 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
image_embeds: Union[str, dict[str, str], None], image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None, uuid: Optional[str] = None,
) -> None: ) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
future: asyncio.Future[Union[str, dict[str, str], None]] = ( future: asyncio.Future[Union[str, dict[str, str], None]] = (
asyncio.Future() asyncio.Future()
) )
......
...@@ -135,14 +135,17 @@ class BaseRenderer(ABC): ...@@ -135,14 +135,17 @@ class BaseRenderer(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@classmethod
def load_prompt_embeds( def load_prompt_embeds(
cls, self,
prompt_embeds: Union[bytes, list[bytes]], prompt_embeds: Union[bytes, list[bytes]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
) -> list[EngineEmbedsPrompt]: ) -> list[EngineEmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects.""" """Load and validate base64-encoded embeddings into prompt objects."""
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
tensor = torch.load( tensor = torch.load(
......
...@@ -1296,6 +1296,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1296,6 +1296,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
""" """
mm_items = self.data_parser.parse_mm_data(mm_data) mm_items = self.data_parser.parse_mm_data(mm_data)
mm_config = self.info.ctx.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
for modality, items in mm_items.items():
if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
raise ValueError(
f"You must set `--enable-mm-embeds` to input "
f"`{modality}_embeds`"
)
for modality, items in mm_items.items(): for modality, items in mm_items.items():
self.validate_num_items(modality, len(items)) self.validate_num_items(modality, len(items))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment