Unverified Commit 64a9c252 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[UX] Add `--language-model-only` for hybrid models (#34120)


Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
parent d0d97e29
......@@ -297,6 +297,7 @@ class ModelConfig:
multimodal_config: MultiModalConfig | None = None
"""Configuration for multimodal model. If `None`, this will be inferred
from the architecture of `self.model`."""
language_model_only: InitVar[bool] = False
limit_mm_per_prompt: InitVar[dict[str, int | dict[str, int]] | None] = None
enable_mm_embeds: InitVar[bool | None] = None
media_io_kwargs: InitVar[dict[str, dict[str, Any]] | None] = None
......@@ -411,6 +412,7 @@ class ModelConfig:
def __post_init__(
self,
# Multimodal config init vars
language_model_only: bool,
limit_mm_per_prompt: dict[str, int | dict[str, int]] | None,
enable_mm_embeds: bool | None,
media_io_kwargs: dict[str, dict[str, Any]] | None,
......@@ -576,6 +578,7 @@ class ModelConfig:
mm_encoder_tp_mode = "weights"
mm_config_kwargs = dict(
language_model_only=language_model_only,
limit_per_prompt=limit_mm_per_prompt,
enable_mm_embeds=enable_mm_embeds,
media_io_kwargs=media_io_kwargs,
......
......@@ -54,6 +54,10 @@ DummyOptions: TypeAlias = (
class MultiModalConfig:
"""Controls the behavior of multimodal models."""
language_model_only: bool = False
"""If True, disables all multimodal inputs by setting all modality limits
to 0. Equivalent to setting --limit-mm-per-prompt to 0 for every
modality."""
limit_per_prompt: dict[str, DummyOptions] = Field(default_factory=dict)
"""The maximum number of input items and options allowed per
prompt for each modality.
......@@ -215,6 +219,7 @@ class MultiModalConfig:
the final hidden states.
"""
factors: list[Any] = [
self.language_model_only,
self.mm_encoder_attn_backend.name
if self.mm_encoder_attn_backend is not None
else None,
......@@ -228,6 +233,9 @@ class MultiModalConfig:
Get the maximum number of input items allowed per prompt
for the given modality (backward compatible).
"""
if self.language_model_only:
return 0
limit_data = self.limit_per_prompt.get(modality)
if limit_data is None:
......
......@@ -454,6 +454,7 @@ class EngineArgs:
allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
enforce_eager: bool = ModelConfig.enforce_eager
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
language_model_only: bool = MultiModalConfig.language_model_only
limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
MultiModalConfig, "limit_per_prompt"
)
......@@ -975,6 +976,9 @@ class EngineArgs:
title="MultiModalConfig",
description=MultiModalConfig.__doc__,
)
multimodal_group.add_argument(
"--language-model-only", **multimodal_kwargs["language_model_only"]
)
multimodal_group.add_argument(
"--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
)
......@@ -1291,6 +1295,7 @@ class EngineArgs:
skip_tokenizer_init=self.skip_tokenizer_init,
enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name,
language_model_only=self.language_model_only,
limit_mm_per_prompt=self.limit_mm_per_prompt,
enable_mm_embeds=self.enable_mm_embeds,
interleave_mm_strings=self.interleave_mm_strings,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment