Unverified commit 64a9c252, authored by Roger Wang, committed by GitHub
Browse files

[UX] Add `--language-model-only` for hybrid models (#34120)


Signed-off-by: Roger Wang <hey@rogerw.io>
parent d0d97e29
...@@ -297,6 +297,7 @@ class ModelConfig: ...@@ -297,6 +297,7 @@ class ModelConfig:
multimodal_config: MultiModalConfig | None = None multimodal_config: MultiModalConfig | None = None
"""Configuration for multimodal model. If `None`, this will be inferred """Configuration for multimodal model. If `None`, this will be inferred
from the architecture of `self.model`.""" from the architecture of `self.model`."""
language_model_only: InitVar[bool] = False
limit_mm_per_prompt: InitVar[dict[str, int | dict[str, int]] | None] = None limit_mm_per_prompt: InitVar[dict[str, int | dict[str, int]] | None] = None
enable_mm_embeds: InitVar[bool | None] = None enable_mm_embeds: InitVar[bool | None] = None
media_io_kwargs: InitVar[dict[str, dict[str, Any]] | None] = None media_io_kwargs: InitVar[dict[str, dict[str, Any]] | None] = None
...@@ -411,6 +412,7 @@ class ModelConfig: ...@@ -411,6 +412,7 @@ class ModelConfig:
def __post_init__( def __post_init__(
self, self,
# Multimodal config init vars # Multimodal config init vars
language_model_only: bool,
limit_mm_per_prompt: dict[str, int | dict[str, int]] | None, limit_mm_per_prompt: dict[str, int | dict[str, int]] | None,
enable_mm_embeds: bool | None, enable_mm_embeds: bool | None,
media_io_kwargs: dict[str, dict[str, Any]] | None, media_io_kwargs: dict[str, dict[str, Any]] | None,
...@@ -576,6 +578,7 @@ class ModelConfig: ...@@ -576,6 +578,7 @@ class ModelConfig:
mm_encoder_tp_mode = "weights" mm_encoder_tp_mode = "weights"
mm_config_kwargs = dict( mm_config_kwargs = dict(
language_model_only=language_model_only,
limit_per_prompt=limit_mm_per_prompt, limit_per_prompt=limit_mm_per_prompt,
enable_mm_embeds=enable_mm_embeds, enable_mm_embeds=enable_mm_embeds,
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,
......
...@@ -54,8 +54,12 @@ DummyOptions: TypeAlias = ( ...@@ -54,8 +54,12 @@ DummyOptions: TypeAlias = (
class MultiModalConfig: class MultiModalConfig:
"""Controls the behavior of multimodal models.""" """Controls the behavior of multimodal models."""
language_model_only: bool = False
"""If True, disables all multimodal inputs by setting all modality limits
to 0. Equivalent to setting --limit-mm-per-prompt to 0 for every
modality."""
limit_per_prompt: dict[str, DummyOptions] = Field(default_factory=dict) limit_per_prompt: dict[str, DummyOptions] = Field(default_factory=dict)
"""The maximum number of input items and options allowed per """The maximum number of input items and options allowed per
prompt for each modality. prompt for each modality.
Defaults to 999 for each modality. Defaults to 999 for each modality.
...@@ -63,11 +67,11 @@ class MultiModalConfig: ...@@ -63,11 +67,11 @@ class MultiModalConfig:
{"image": 16, "video": 2} {"image": 16, "video": 2}
Configurable format (with options): Configurable format (with options):
{"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512}, {"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
"image": {"count": 5, "width": 512, "height": 512}} "image": {"count": 5, "width": 512, "height": 512}}
Mixed format (combining both): Mixed format (combining both):
{"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512, {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
"height": 512}} "height": 512}}
""" """
enable_mm_embeds: bool = False enable_mm_embeds: bool = False
...@@ -215,6 +219,7 @@ class MultiModalConfig: ...@@ -215,6 +219,7 @@ class MultiModalConfig:
the final hidden states. the final hidden states.
""" """
factors: list[Any] = [ factors: list[Any] = [
self.language_model_only,
self.mm_encoder_attn_backend.name self.mm_encoder_attn_backend.name
if self.mm_encoder_attn_backend is not None if self.mm_encoder_attn_backend is not None
else None, else None,
...@@ -228,6 +233,9 @@ class MultiModalConfig: ...@@ -228,6 +233,9 @@ class MultiModalConfig:
Get the maximum number of input items allowed per prompt Get the maximum number of input items allowed per prompt
for the given modality (backward compatible). for the given modality (backward compatible).
""" """
if self.language_model_only:
return 0
limit_data = self.limit_per_prompt.get(modality) limit_data = self.limit_per_prompt.get(modality)
if limit_data is None: if limit_data is None:
......
...@@ -454,6 +454,7 @@ class EngineArgs: ...@@ -454,6 +454,7 @@ class EngineArgs:
allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
enforce_eager: bool = ModelConfig.enforce_eager enforce_eager: bool = ModelConfig.enforce_eager
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
language_model_only: bool = MultiModalConfig.language_model_only
limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field( limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
MultiModalConfig, "limit_per_prompt" MultiModalConfig, "limit_per_prompt"
) )
...@@ -975,6 +976,9 @@ class EngineArgs: ...@@ -975,6 +976,9 @@ class EngineArgs:
title="MultiModalConfig", title="MultiModalConfig",
description=MultiModalConfig.__doc__, description=MultiModalConfig.__doc__,
) )
multimodal_group.add_argument(
"--language-model-only", **multimodal_kwargs["language_model_only"]
)
multimodal_group.add_argument( multimodal_group.add_argument(
"--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"] "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
) )
...@@ -1291,6 +1295,7 @@ class EngineArgs: ...@@ -1291,6 +1295,7 @@ class EngineArgs:
skip_tokenizer_init=self.skip_tokenizer_init, skip_tokenizer_init=self.skip_tokenizer_init,
enable_prompt_embeds=self.enable_prompt_embeds, enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
language_model_only=self.language_model_only,
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
enable_mm_embeds=self.enable_mm_embeds, enable_mm_embeds=self.enable_mm_embeds,
interleave_mm_strings=self.interleave_mm_strings, interleave_mm_strings=self.interleave_mm_strings,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment