Unverified commit 27f4c2fd, authored by Cyrus Leung, committed by GitHub
Browse files

[Renderer] Separate out `RendererConfig` from `ModelConfig` (#30145)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent a49d813f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal
from pydantic import Field, SkipValidation
from pydantic.dataclasses import dataclass
from vllm.config.model import ModelConfig
from vllm.config.utils import config
from vllm.transformers_utils.gguf_utils import is_gguf
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect
# Tokenizer modes that are recognised by name; the `tokenizer_mode` field of
# `RendererConfig` documents what each one selects.
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
@config
@dataclass
class RendererConfig:
    """Configuration for the renderer."""

    # NOTE(review): the string literal that follows each field looks like it is
    # consumed as CLI help text -- confirm before rewording any of them.

    # NOTE: Callers are expected to always pass a real `ModelConfig`. The
    # `None` default only exists so that CLI arguments can be generated from
    # this class without constructing a model config first.
    model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """Provides model context to the renderer."""

    tokenizer: str = ""
    """Name or path of the Hugging Face tokenizer to use. If unspecified, model
    name or path will be used."""

    tokenizer_mode: TokenizerMode | str = "auto"
    """Tokenizer mode:\n
    - "auto" will use the tokenizer from `mistral_common` for Mistral models
    if available, otherwise it will use the "hf" tokenizer.\n
    - "hf" will use the fast tokenizer if available.\n
    - "slow" will always use the slow tokenizer.\n
    - "mistral" will always use the tokenizer from `mistral_common`.\n
    - "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
    - Other custom values can be supported via plugins."""

    tokenizer_revision: str | None = None
    """The specific revision to use for the tokenizer on the Hugging Face Hub.
    It can be a branch name, a tag name, or a commit id. If unspecified, will
    use the default version."""

    skip_tokenizer_init: bool = False
    """Skip initialization of tokenizer and detokenizer. Expects valid
    `prompt_token_ids` and `None` for prompt from the input. The generated
    output will contain token ids."""

    io_processor_plugin: str | None = None
    """IOProcessor plugin name to load at model startup."""

    media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict)
    """Additional args passed to process media inputs, keyed by modalities.
    For example, to set num_frames for video, set
    `--media-io-kwargs '{"video": {"num_frames": 40} }'`"""

    allowed_local_media_path: str = ""
    """Allowing API requests to read local images or videos from directories
    specified by the server file system. This is a security risk. Should only
    be enabled in trusted environments."""

    allowed_media_domains: list[str] | None = None
    """If set, only media URLs that belong to this domain can be used for
    multi-modal inputs. """

    @property
    def trust_remote_code(self) -> bool:
        """Forward the `trust_remote_code` flag of the attached model config."""
        return self.model_config.trust_remote_code

    def __post_init__(self) -> None:
        """Fill in tokenizer defaults and reject unsupported GGUF setups.

        Raises:
            ValueError: If the resolved tokenizer is a GGUF file while the
                model is multimodal.
        """
        attached = self.model_config
        has_model = attached is not None
        # Without an attached instance, the `ModelConfig` class itself stands
        # in and its class-level attributes are read instead.
        source = attached if has_model else ModelConfig

        # An unset tokenizer (or revision) follows the model.
        if not self.tokenizer:
            if has_model:
                self.tokenizer = attached.original_model
            else:
                self.tokenizer = ModelConfig.model

        if not self.tokenizer_revision:
            self.tokenizer_revision = source.revision

        # Remember what was requested before any redirect or download
        # replaces `self.tokenizer`.
        self.original_tokenizer = self.tokenizer
        self.tokenizer = maybe_model_redirect(self.original_tokenizer)
        self.maybe_pull_tokenizer_for_runai(self.tokenizer)

        # Multimodal processing cannot work from a GGUF tokenizer; the
        # original repo has to be supplied instead.
        is_multimodal_model = source.is_multimodal_model
        if not is_gguf(self.tokenizer):
            return
        if is_multimodal_model:
            raise ValueError(
                "Loading a multimodal GGUF model needs to use original "
                "tokenizer. Please specify the unquantized hf model's "
                "repo name or path using the --tokenizer argument."
            )

    def maybe_pull_tokenizer_for_runai(self, tokenizer: str) -> None:
        """Fetch the tokenizer files when `tokenizer` is an object-storage URI.

        Does nothing for any other kind of name or path. Otherwise the files
        are pulled (patterns matching weight files are ignored) and
        `self.tokenizer` is repointed at the directory they were pulled into.
        """
        if is_runai_obj_uri(tokenizer):
            remote_tokenizer = ObjectStorageModel(url=tokenizer)
            weight_file_patterns = [
                "*.pt",
                "*.safetensors",
                "*.bin",
                "*.tensors",
                "*.pth",
            ]
            remote_tokenizer.pull_files(
                tokenizer,
                ignore_pattern=weight_file_patterns,
            )
            self.tokenizer = remote_tokenizer.dir
...@@ -322,16 +322,11 @@ class SpeculativeConfig: ...@@ -322,16 +322,11 @@ class SpeculativeConfig:
self.draft_model_config = ModelConfig( self.draft_model_config = ModelConfig(
model=self.model, model=self.model,
runner="draft", runner="draft",
tokenizer=self.target_model_config.tokenizer,
tokenizer_mode=self.target_model_config.tokenizer_mode,
trust_remote_code=self.target_model_config.trust_remote_code, trust_remote_code=self.target_model_config.trust_remote_code,
allowed_local_media_path=self.target_model_config.allowed_local_media_path,
allowed_media_domains=self.target_model_config.allowed_media_domains,
dtype=self.target_model_config.dtype, dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed, seed=self.target_model_config.seed,
revision=self.revision, revision=self.revision,
code_revision=self.code_revision, code_revision=self.code_revision,
tokenizer_revision=self.target_model_config.tokenizer_revision,
spec_target_max_model_len=self.target_model_config.max_model_len, spec_target_max_model_len=self.target_model_config.max_model_len,
quantization=self.quantization, quantization=self.quantization,
enforce_eager=self.target_model_config.enforce_eager, enforce_eager=self.target_model_config.enforce_eager,
......
...@@ -39,6 +39,7 @@ from .lora import LoRAConfig ...@@ -39,6 +39,7 @@ from .lora import LoRAConfig
from .model import ModelConfig from .model import ModelConfig
from .observability import ObservabilityConfig from .observability import ObservabilityConfig
from .parallel import ParallelConfig from .parallel import ParallelConfig
from .renderer import RendererConfig
from .scheduler import SchedulerConfig from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig from .structured_outputs import StructuredOutputsConfig
...@@ -181,6 +182,8 @@ class VllmConfig: ...@@ -181,6 +182,8 @@ class VllmConfig:
# try to download a model # try to download a model
model_config: ModelConfig = Field(default=None) model_config: ModelConfig = Field(default=None)
"""Model configuration.""" """Model configuration."""
renderer_config: RendererConfig = Field(default_factory=RendererConfig)
"""Renderer configuration."""
cache_config: CacheConfig = Field(default_factory=CacheConfig) cache_config: CacheConfig = Field(default_factory=CacheConfig)
"""Cache configuration.""" """Cache configuration."""
parallel_config: ParallelConfig = Field(default_factory=ParallelConfig) parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
...@@ -741,7 +744,7 @@ class VllmConfig: ...@@ -741,7 +744,7 @@ class VllmConfig:
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
self.scheduler_config.max_num_encoder_input_tokens = ( self.scheduler_config.max_num_encoder_input_tokens = (
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.renderer_config)
) )
logger.debug( logger.debug(
"Encoder-decoder model detected: setting " "Encoder-decoder model detected: setting "
...@@ -1186,11 +1189,13 @@ class VllmConfig: ...@@ -1186,11 +1189,13 @@ class VllmConfig:
computed_compile_ranges_split_points computed_compile_ranges_split_points
) )
def recalculate_max_model_len(self, max_model_len: int): def recalculate_max_model_len(self, original_max_model_len: int | None) -> None:
# Can only be called in try_verify_and_update_config # Can only be called during try_verify_and_update_config
model_config = self.model_config self.model_config.recalculate_max_model_len(
max_model_len = model_config.get_and_verify_max_len(max_model_len) original_max_model_len,
self.model_config.max_model_len = max_model_len tokenizer=self.renderer_config.tokenizer,
tokenizer_revision=self.renderer_config.tokenizer_revision,
)
def try_verify_and_update_config(self): def try_verify_and_update_config(self):
if self.model_config is None: if self.model_config is None:
...@@ -1264,11 +1269,11 @@ class VllmConfig: ...@@ -1264,11 +1269,11 @@ class VllmConfig:
return ( return (
f"model={self.model_config.model!r}, " f"model={self.model_config.model!r}, "
f"speculative_config={self.speculative_config!r}, " f"speculative_config={self.speculative_config!r}, "
f"tokenizer={self.model_config.tokenizer!r}, " f"tokenizer={self.renderer_config.tokenizer!r}, "
f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, " f"skip_tokenizer_init={self.renderer_config.skip_tokenizer_init}, "
f"tokenizer_mode={self.model_config.tokenizer_mode}, " f"tokenizer_mode={self.renderer_config.tokenizer_mode}, "
f"revision={self.model_config.revision}, " f"revision={self.model_config.revision}, "
f"tokenizer_revision={self.model_config.tokenizer_revision}, " f"tokenizer_revision={self.renderer_config.tokenizer_revision}, "
f"trust_remote_code={self.model_config.trust_remote_code}, " f"trust_remote_code={self.model_config.trust_remote_code}, "
f"dtype={self.model_config.dtype}, " f"dtype={self.model_config.dtype}, "
f"max_seq_len={self.model_config.max_model_len}, " f"max_seq_len={self.model_config.max_model_len}, "
......
...@@ -71,11 +71,11 @@ from vllm.config.model import ( ...@@ -71,11 +71,11 @@ from vllm.config.model import (
ModelDType, ModelDType,
RunnerOption, RunnerOption,
TaskOption, TaskOption,
TokenizerMode,
) )
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.observability import DetailedTraceModules from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.renderer import RendererConfig, TokenizerMode
from vllm.config.scheduler import SchedulerPolicy from vllm.config.scheduler import SchedulerPolicy
from vllm.config.utils import get_field from vllm.config.utils import get_field
from vllm.config.vllm import OptimizationLevel from vllm.config.vllm import OptimizationLevel
...@@ -355,17 +355,12 @@ class EngineArgs: ...@@ -355,17 +355,12 @@ class EngineArgs:
model: str = ModelConfig.model model: str = ModelConfig.model
served_model_name: str | list[str] | None = ModelConfig.served_model_name served_model_name: str | list[str] | None = ModelConfig.served_model_name
tokenizer: str | None = ModelConfig.tokenizer
hf_config_path: str | None = ModelConfig.hf_config_path hf_config_path: str | None = ModelConfig.hf_config_path
runner: RunnerOption = ModelConfig.runner runner: RunnerOption = ModelConfig.runner
convert: ConvertOption = ModelConfig.convert convert: ConvertOption = ModelConfig.convert
task: TaskOption | None = ModelConfig.task task: TaskOption | None = ModelConfig.task
skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
trust_remote_code: bool = ModelConfig.trust_remote_code trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
download_dir: str | None = LoadConfig.download_dir download_dir: str | None = LoadConfig.download_dir
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
load_format: str | LoadFormats = LoadConfig.load_format load_format: str | LoadFormats = LoadConfig.load_format
...@@ -449,7 +444,6 @@ class EngineArgs: ...@@ -449,7 +444,6 @@ class EngineArgs:
code_revision: str | None = ModelConfig.code_revision code_revision: str | None = ModelConfig.code_revision
hf_token: bool | str | None = ModelConfig.hf_token hf_token: bool | str | None = ModelConfig.hf_token
hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides") hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
tokenizer_revision: str | None = ModelConfig.tokenizer_revision
quantization: QuantizationMethods | None = ModelConfig.quantization quantization: QuantizationMethods | None = ModelConfig.quantization
enforce_eager: bool = ModelConfig.enforce_eager enforce_eager: bool = ModelConfig.enforce_eager
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
...@@ -458,9 +452,6 @@ class EngineArgs: ...@@ -458,9 +452,6 @@ class EngineArgs:
) )
enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
media_io_kwargs: dict[str, dict[str, Any]] = get_field(
MultiModalConfig, "media_io_kwargs"
)
mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
...@@ -474,9 +465,19 @@ class EngineArgs: ...@@ -474,9 +465,19 @@ class EngineArgs:
mm_encoder_attn_backend: AttentionBackendEnum | str | None = ( mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
MultiModalConfig.mm_encoder_attn_backend MultiModalConfig.mm_encoder_attn_backend
) )
io_processor_plugin: str | None = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
video_pruning_rate: float = MultiModalConfig.video_pruning_rate video_pruning_rate: float = MultiModalConfig.video_pruning_rate
# Renderer fields
tokenizer: str | None = None
tokenizer_mode: TokenizerMode | str = RendererConfig.tokenizer_mode
tokenizer_revision: str | None = RendererConfig.tokenizer_revision
skip_tokenizer_init: bool = RendererConfig.skip_tokenizer_init
io_processor_plugin: str | None = None
media_io_kwargs: dict[str, dict[str, Any]] = get_field(
RendererConfig, "media_io_kwargs"
)
allowed_local_media_path: str = RendererConfig.allowed_local_media_path
allowed_media_domains: list[str] | None = RendererConfig.allowed_media_domains
# LoRA fields # LoRA fields
enable_lora: bool = False enable_lora: bool = False
max_loras: int = LoRAConfig.max_loras max_loras: int = LoRAConfig.max_loras
...@@ -627,25 +628,14 @@ class EngineArgs: ...@@ -627,25 +628,14 @@ class EngineArgs:
model_group.add_argument("--runner", **model_kwargs["runner"]) model_group.add_argument("--runner", **model_kwargs["runner"])
model_group.add_argument("--convert", **model_kwargs["convert"]) model_group.add_argument("--convert", **model_kwargs["convert"])
model_group.add_argument("--task", **model_kwargs["task"], deprecated=True) model_group.add_argument("--task", **model_kwargs["task"], deprecated=True)
model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
model_group.add_argument( model_group.add_argument(
"--trust-remote-code", **model_kwargs["trust_remote_code"] "--trust-remote-code", **model_kwargs["trust_remote_code"]
) )
model_group.add_argument("--dtype", **model_kwargs["dtype"]) model_group.add_argument("--dtype", **model_kwargs["dtype"])
model_group.add_argument("--seed", **model_kwargs["seed"]) model_group.add_argument("--seed", **model_kwargs["seed"])
model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"]) model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
model_group.add_argument(
"--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
)
model_group.add_argument(
"--allowed-media-domains", **model_kwargs["allowed_media_domains"]
)
model_group.add_argument("--revision", **model_kwargs["revision"]) model_group.add_argument("--revision", **model_kwargs["revision"])
model_group.add_argument("--code-revision", **model_kwargs["code_revision"]) model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
model_group.add_argument(
"--tokenizer-revision", **model_kwargs["tokenizer_revision"]
)
model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"]) model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"]) model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"]) model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
...@@ -657,9 +647,6 @@ class EngineArgs: ...@@ -657,9 +647,6 @@ class EngineArgs:
model_group.add_argument( model_group.add_argument(
"--disable-cascade-attn", **model_kwargs["disable_cascade_attn"] "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
) )
model_group.add_argument(
"--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
)
model_group.add_argument( model_group.add_argument(
"--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"] "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
) )
...@@ -698,8 +685,34 @@ class EngineArgs: ...@@ -698,8 +685,34 @@ class EngineArgs:
model_group.add_argument( model_group.add_argument(
"--logits-processors", **model_kwargs["logits_processors"] "--logits-processors", **model_kwargs["logits_processors"]
) )
model_group.add_argument(
"--io-processor-plugin", **model_kwargs["io_processor_plugin"] # Renderer arguments
renderer_kwargs = get_kwargs(RendererConfig)
renderer_group = parser.add_argument_group(
title="RendererConfig",
description=RendererConfig.__doc__,
)
renderer_group.add_argument("--tokenizer", **renderer_kwargs["tokenizer"])
renderer_group.add_argument(
"--tokenizer-mode", **renderer_kwargs["tokenizer_mode"]
)
renderer_group.add_argument(
"--tokenizer-revision", **renderer_kwargs["tokenizer_revision"]
)
renderer_group.add_argument(
"--skip-tokenizer-init", **renderer_kwargs["skip_tokenizer_init"]
)
renderer_group.add_argument(
"--media-io-kwargs", **renderer_kwargs["media_io_kwargs"]
)
renderer_group.add_argument(
"--allowed-local-media-path", **renderer_kwargs["allowed_local_media_path"]
)
renderer_group.add_argument(
"--allowed-media-domains", **renderer_kwargs["allowed_media_domains"]
)
renderer_group.add_argument(
"--io-processor-plugin", **renderer_kwargs["io_processor_plugin"]
) )
# Model loading arguments # Model loading arguments
...@@ -949,9 +962,6 @@ class EngineArgs: ...@@ -949,9 +962,6 @@ class EngineArgs:
multimodal_group.add_argument( multimodal_group.add_argument(
"--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"] "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
) )
multimodal_group.add_argument(
"--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
)
multimodal_group.add_argument( multimodal_group.add_argument(
"--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"] "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
) )
...@@ -1255,18 +1265,13 @@ class EngineArgs: ...@@ -1255,18 +1265,13 @@ class EngineArgs:
runner=self.runner, runner=self.runner,
convert=self.convert, convert=self.convert,
task=self.task, task=self.task,
tokenizer=self.tokenizer,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,
allowed_media_domains=self.allowed_media_domains,
dtype=self.dtype, dtype=self.dtype,
seed=self.seed, seed=self.seed,
revision=self.revision, revision=self.revision,
code_revision=self.code_revision, code_revision=self.code_revision,
hf_token=self.hf_token, hf_token=self.hf_token,
hf_overrides=self.hf_overrides, hf_overrides=self.hf_overrides,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
quantization=self.quantization, quantization=self.quantization,
enforce_eager=self.enforce_eager, enforce_eager=self.enforce_eager,
...@@ -1274,13 +1279,11 @@ class EngineArgs: ...@@ -1274,13 +1279,11 @@ class EngineArgs:
logprobs_mode=self.logprobs_mode, logprobs_mode=self.logprobs_mode,
disable_sliding_window=self.disable_sliding_window, disable_sliding_window=self.disable_sliding_window,
disable_cascade_attn=self.disable_cascade_attn, disable_cascade_attn=self.disable_cascade_attn,
skip_tokenizer_init=self.skip_tokenizer_init,
enable_prompt_embeds=self.enable_prompt_embeds, enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
enable_mm_embeds=self.enable_mm_embeds, enable_mm_embeds=self.enable_mm_embeds,
interleave_mm_strings=self.interleave_mm_strings, interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs,
skip_mm_profiling=self.skip_mm_profiling, skip_mm_profiling=self.skip_mm_profiling,
config_format=self.config_format, config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
...@@ -1298,7 +1301,6 @@ class EngineArgs: ...@@ -1298,7 +1301,6 @@ class EngineArgs:
override_attention_dtype=self.override_attention_dtype, override_attention_dtype=self.override_attention_dtype,
logits_processors=self.logits_processors, logits_processors=self.logits_processors,
video_pruning_rate=self.video_pruning_rate, video_pruning_rate=self.video_pruning_rate,
io_processor_plugin=self.io_processor_plugin,
) )
def validate_tensorizer_args(self): def validate_tensorizer_args(self):
...@@ -1394,9 +1396,25 @@ class EngineArgs: ...@@ -1394,9 +1396,25 @@ class EngineArgs:
) )
model_config = self.create_model_config() model_config = self.create_model_config()
self.model = model_config.model renderer_config = RendererConfig(
self.tokenizer = model_config.tokenizer model_config=model_config,
tokenizer=self.tokenizer or "",
tokenizer_mode=self.tokenizer_mode,
tokenizer_revision=self.tokenizer_revision,
skip_tokenizer_init=self.skip_tokenizer_init,
io_processor_plugin=self.io_processor_plugin,
media_io_kwargs=self.media_io_kwargs,
allowed_local_media_path=self.allowed_local_media_path,
allowed_media_domains=self.allowed_media_domains,
)
model_config.recalculate_max_model_len(
model_config.original_max_model_len,
tokenizer=renderer_config.tokenizer,
tokenizer_revision=renderer_config.tokenizer_revision,
)
self.model = model_config.model
self._check_feature_supported(model_config) self._check_feature_supported(model_config)
self._set_default_chunked_prefill_and_prefix_caching_args(model_config) self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
self._set_default_max_num_seqs_and_batched_tokens_args( self._set_default_max_num_seqs_and_batched_tokens_args(
...@@ -1768,6 +1786,7 @@ class EngineArgs: ...@@ -1768,6 +1786,7 @@ class EngineArgs:
) )
config = VllmConfig( config = VllmConfig(
model_config=model_config, model_config=model_config,
renderer_config=renderer_config,
cache_config=cache_config, cache_config=cache_config,
parallel_config=parallel_config, parallel_config=parallel_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
......
...@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod ...@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterable, Mapping from collections.abc import AsyncGenerator, Iterable, Mapping
from typing import Any from typing import Any
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, RendererConfig, VllmConfig
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
...@@ -22,6 +22,7 @@ class EngineClient(ABC): ...@@ -22,6 +22,7 @@ class EngineClient(ABC):
"""Protocol class for Clients to Engine""" """Protocol class for Clients to Engine"""
vllm_config: VllmConfig vllm_config: VllmConfig
renderer_config: RendererConfig
model_config: ModelConfig model_config: ModelConfig
input_processor: InputProcessor input_processor: InputProcessor
io_processor: IOProcessor | None io_processor: IOProcessor | None
......
...@@ -44,7 +44,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, Processor ...@@ -44,7 +44,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, Processor
from typing_extensions import Required, TypedDict from typing_extensions import Required, TypedDict
from vllm import envs from vllm import envs
from vllm.config import ModelConfig from vllm.config import ModelConfig, RendererConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
...@@ -452,9 +452,10 @@ This is needed because `lru_cache` does not cache when an exception happens. ...@@ -452,9 +452,10 @@ This is needed because `lru_cache` does not cache when an exception happens.
def _try_get_processor_chat_template( def _try_get_processor_chat_template(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
model_config: ModelConfig, *,
trust_remote_code: bool,
) -> str | None: ) -> str | None:
cache_key = (tokenizer.name_or_path, model_config.trust_remote_code) cache_key = (tokenizer.name_or_path, trust_remote_code)
if cache_key in _PROCESSOR_CHAT_TEMPLATES: if cache_key in _PROCESSOR_CHAT_TEMPLATES:
return _PROCESSOR_CHAT_TEMPLATES[cache_key] return _PROCESSOR_CHAT_TEMPLATES[cache_key]
...@@ -466,7 +467,7 @@ def _try_get_processor_chat_template( ...@@ -466,7 +467,7 @@ def _try_get_processor_chat_template(
PreTrainedTokenizerFast, PreTrainedTokenizerFast,
ProcessorMixin, ProcessorMixin,
), ),
trust_remote_code=model_config.trust_remote_code, trust_remote_code=trust_remote_code,
) )
if ( if (
isinstance(processor, ProcessorMixin) isinstance(processor, ProcessorMixin)
...@@ -499,7 +500,10 @@ def resolve_hf_chat_template( ...@@ -499,7 +500,10 @@ def resolve_hf_chat_template(
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None: if tools is None:
chat_template = _try_get_processor_chat_template(tokenizer, model_config) chat_template = _try_get_processor_chat_template(
tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
if chat_template is not None: if chat_template is not None:
return chat_template return chat_template
...@@ -513,10 +517,10 @@ def resolve_hf_chat_template( ...@@ -513,10 +517,10 @@ def resolve_hf_chat_template(
exc_info=True, exc_info=True,
) )
# 4th priority: Predefined fallbacks # 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path( path = get_chat_template_fallback_path(
model_type=model_config.hf_config.model_type, model_type=model_config.hf_config.model_type,
tokenizer_name_or_path=model_config.tokenizer, tokenizer_name_or_path=tokenizer.name_or_path,
) )
if path is not None: if path is not None:
logger.info_once( logger.info_once(
...@@ -538,14 +542,14 @@ def _resolve_chat_template_content_format( ...@@ -538,14 +542,14 @@ def _resolve_chat_template_content_format(
tools: list[dict[str, Any]] | None, tools: list[dict[str, Any]] | None,
tokenizer: TokenizerLike | None, tokenizer: TokenizerLike | None,
*, *,
model_config: ModelConfig, renderer_config: RendererConfig,
) -> _ChatTemplateContentFormat: ) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
hf_chat_template = resolve_hf_chat_template( hf_chat_template = resolve_hf_chat_template(
tokenizer, tokenizer,
chat_template=chat_template, chat_template=chat_template,
tools=tools, tools=tools,
model_config=model_config, model_config=renderer_config.model_config,
) )
else: else:
hf_chat_template = None hf_chat_template = None
...@@ -595,7 +599,7 @@ def resolve_chat_template_content_format( ...@@ -595,7 +599,7 @@ def resolve_chat_template_content_format(
given_format: ChatTemplateContentFormatOption, given_format: ChatTemplateContentFormatOption,
tokenizer: TokenizerLike | None, tokenizer: TokenizerLike | None,
*, *,
model_config: ModelConfig, renderer_config: RendererConfig,
) -> _ChatTemplateContentFormat: ) -> _ChatTemplateContentFormat:
if given_format != "auto": if given_format != "auto":
return given_format return given_format
...@@ -604,7 +608,7 @@ def resolve_chat_template_content_format( ...@@ -604,7 +608,7 @@ def resolve_chat_template_content_format(
chat_template, chat_template,
tools, tools,
tokenizer, tokenizer,
model_config=model_config, renderer_config=renderer_config,
) )
_log_chat_template_content_format( _log_chat_template_content_format(
...@@ -627,32 +631,32 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -627,32 +631,32 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
maximum per prompt. maximum per prompt.
""" """
def __init__(self, model_config: ModelConfig): def __init__(self, renderer_config: RendererConfig):
super().__init__() super().__init__()
self._model_config = model_config self._renderer_config = renderer_config
self._items_by_modality = defaultdict[str, list[_T | None]](list) self._items_by_modality = defaultdict[str, list[_T | None]](list)
self._uuids_by_modality = defaultdict[str, list[str | None]](list) self._uuids_by_modality = defaultdict[str, list[str | None]](list)
@property @property
def model_config(self) -> ModelConfig: def renderer_config(self) -> RendererConfig:
return self._model_config return self._renderer_config
@cached_property @cached_property
def model_cls(self) -> type[SupportsMultiModal]: def model_cls(self) -> type[SupportsMultiModal]:
from vllm.model_executor.model_loader import get_model_cls from vllm.model_executor.model_loader import get_model_cls
model_cls = get_model_cls(self.model_config) model_cls = get_model_cls(self.renderer_config.model_config)
return cast(type[SupportsMultiModal], model_cls) return cast(type[SupportsMultiModal], model_cls)
@property @property
def allowed_local_media_path(self): def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path return self._renderer_config.allowed_local_media_path
@property @property
def allowed_media_domains(self): def allowed_media_domains(self):
return self._model_config.allowed_media_domains return self._renderer_config.allowed_media_domains
@property @property
def mm_registry(self): def mm_registry(self):
...@@ -660,7 +664,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -660,7 +664,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
@cached_property @cached_property
def mm_processor(self): def mm_processor(self):
return self.mm_registry.create_processor(self.model_config) return self.mm_registry.create_processor(self.renderer_config)
def add( def add(
self, self,
...@@ -851,19 +855,20 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -851,19 +855,20 @@ class MultiModalContentParser(BaseMultiModalContentParser):
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
multimodal_config = self._tracker.model_config.multimodal_config
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load( self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
envs.VLLM_MEDIA_CONNECTOR, envs.VLLM_MEDIA_CONNECTOR,
media_io_kwargs=media_io_kwargs, media_io_kwargs=self.renderer_config.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
) )
@property
def renderer_config(self) -> RendererConfig:
return self._tracker.renderer_config
@property @property
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
return self._tracker.model_config return self.renderer_config.model_config
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None: def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
image = self._connector.fetch_image(image_url) if image_url else None image = self._connector.fetch_image(image_url) if image_url else None
...@@ -963,18 +968,20 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -963,18 +968,20 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
multimodal_config = self._tracker.model_config.multimodal_config
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load( self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
envs.VLLM_MEDIA_CONNECTOR, envs.VLLM_MEDIA_CONNECTOR,
media_io_kwargs=media_io_kwargs, media_io_kwargs=self.renderer_config.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
) )
@property
def renderer_config(self) -> RendererConfig:
return self._tracker.renderer_config
@property @property
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
return self._tracker.model_config return self.renderer_config.model_config
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None: def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
image_coro = self._connector.fetch_image_async(image_url) if image_url else None image_coro = self._connector.fetch_image_async(image_url) if image_url else None
...@@ -1604,15 +1611,17 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: ...@@ -1604,15 +1611,17 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
def parse_chat_messages( def parse_chat_messages(
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
model_config: ModelConfig, renderer_config: RendererConfig,
content_format: _ChatTemplateContentFormat, content_format: _ChatTemplateContentFormat,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
MultiModalDataDict | None, MultiModalDataDict | None,
MultiModalUUIDDict | None, MultiModalUUIDDict | None,
]: ]:
model_config = renderer_config.model_config
conversation: list[ConversationMessage] = [] conversation: list[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config) mm_tracker = MultiModalItemTracker(renderer_config)
for msg in messages: for msg in messages:
sub_messages = _parse_chat_message_content( sub_messages = _parse_chat_message_content(
...@@ -1635,15 +1644,17 @@ def parse_chat_messages( ...@@ -1635,15 +1644,17 @@ def parse_chat_messages(
def parse_chat_messages_futures( def parse_chat_messages_futures(
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
model_config: ModelConfig, renderer_config: RendererConfig,
content_format: _ChatTemplateContentFormat, content_format: _ChatTemplateContentFormat,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
Awaitable[MultiModalDataDict | None], Awaitable[MultiModalDataDict | None],
MultiModalUUIDDict | None, MultiModalUUIDDict | None,
]: ]:
model_config = renderer_config.model_config
conversation: list[ConversationMessage] = [] conversation: list[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config) mm_tracker = AsyncMultiModalItemTracker(renderer_config)
for msg in messages: for msg in messages:
sub_messages = _parse_chat_message_content( sub_messages = _parse_chat_message_content(
...@@ -1748,14 +1759,14 @@ def apply_hf_chat_template( ...@@ -1748,14 +1759,14 @@ def apply_hf_chat_template(
chat_template: str | None, chat_template: str | None,
tools: list[dict[str, Any]] | None, tools: list[dict[str, Any]] | None,
*, *,
model_config: ModelConfig, renderer_config: RendererConfig,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
hf_chat_template = resolve_hf_chat_template( hf_chat_template = resolve_hf_chat_template(
tokenizer, tokenizer,
chat_template=chat_template, chat_template=chat_template,
tools=tools, tools=tools,
model_config=model_config, model_config=renderer_config.model_config,
) )
if hf_chat_template is None: if hf_chat_template is None:
......
...@@ -29,8 +29,8 @@ from vllm.config.model import ( ...@@ -29,8 +29,8 @@ from vllm.config.model import (
HfOverrides, HfOverrides,
ModelDType, ModelDType,
RunnerOption, RunnerOption,
TokenizerMode,
) )
from vllm.config.renderer import TokenizerMode
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
...@@ -343,6 +343,7 @@ class LLM: ...@@ -343,6 +343,7 @@ class LLM:
logger.info("Supported tasks: %s", supported_tasks) logger.info("Supported tasks: %s", supported_tasks)
self.supported_tasks = supported_tasks self.supported_tasks = supported_tasks
self.renderer_config = self.llm_engine.renderer_config
self.model_config = self.llm_engine.model_config self.model_config = self.llm_engine.model_config
self.input_processor = self.llm_engine.input_processor self.input_processor = self.llm_engine.input_processor
self.io_processor = self.llm_engine.io_processor self.io_processor = self.llm_engine.io_processor
...@@ -808,13 +809,13 @@ class LLM: ...@@ -808,13 +809,13 @@ class LLM:
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)] list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
model_config = self.model_config renderer_config = self.renderer_config
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
chat_template, chat_template,
tools, tools,
chat_template_content_format, chat_template_content_format,
tokenizer, tokenizer,
model_config=model_config, renderer_config=renderer_config,
) )
_chat_template_kwargs: dict[str, Any] = dict( _chat_template_kwargs: dict[str, Any] = dict(
...@@ -833,7 +834,7 @@ class LLM: ...@@ -833,7 +834,7 @@ class LLM:
# the chat message parsing for it. # the chat message parsing for it.
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
msgs, msgs,
model_config, renderer_config,
content_format=resolved_content_format, content_format=resolved_content_format,
) )
...@@ -847,7 +848,7 @@ class LLM: ...@@ -847,7 +848,7 @@ class LLM:
prompt_str = apply_hf_chat_template( prompt_str = apply_hf_chat_template(
tokenizer=tokenizer, tokenizer=tokenizer,
conversation=conversation, conversation=conversation,
model_config=model_config, renderer_config=renderer_config,
**_chat_template_kwargs, **_chat_template_kwargs,
) )
# Special tokens are already included in chat templates so # Special tokens are already included in chat templates so
...@@ -1290,6 +1291,7 @@ class LLM: ...@@ -1290,6 +1291,7 @@ class LLM:
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
renderer_config = self.renderer_config
model_config = self.model_config model_config = self.model_config
if isinstance(tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
...@@ -1317,7 +1319,7 @@ class LLM: ...@@ -1317,7 +1319,7 @@ class LLM:
for q, d in input_pairs: for q, d in input_pairs:
_, engine_prompt = get_score_prompt( _, engine_prompt = get_score_prompt(
model_config=model_config, renderer_config=renderer_config,
data_1=q, data_1=q,
data_2=d, data_2=d,
tokenizer=tokenizer, tokenizer=tokenizer,
......
...@@ -1099,7 +1099,7 @@ async def init_app_state( ...@@ -1099,7 +1099,7 @@ async def init_app_state(
logger.info("Supported tasks: %s", supported_tasks) logger.info("Supported tasks: %s", supported_tasks)
resolved_chat_template = await process_chat_template( resolved_chat_template = await process_chat_template(
args.chat_template, engine_client, vllm_config.model_config args.chat_template, engine_client, vllm_config.renderer_config
) )
if args.tool_server == "demo": if args.tool_server == "demo":
......
...@@ -122,7 +122,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -122,7 +122,7 @@ class OpenAIServingCompletion(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init: if self.renderer_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = await self.engine_client.get_tokenizer() tokenizer = await self.engine_client.get_tokenizer()
......
...@@ -291,6 +291,7 @@ class OpenAIServing: ...@@ -291,6 +291,7 @@ class OpenAIServing:
self.input_processor = self.models.input_processor self.input_processor = self.models.input_processor
self.io_processor = self.models.io_processor self.io_processor = self.models.io_processor
self.renderer_config = self.models.renderer_config
self.model_config = self.models.model_config self.model_config = self.models.model_config
self.max_model_len = self.model_config.max_model_len self.max_model_len = self.model_config.max_model_len
...@@ -1100,18 +1101,18 @@ class OpenAIServing: ...@@ -1100,18 +1101,18 @@ class OpenAIServing:
Sequence[RequestPrompt], Sequence[RequestPrompt],
list[EngineTokensPrompt], list[EngineTokensPrompt],
]: ]:
model_config = self.model_config renderer_config = self.renderer_config
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
chat_template, chat_template,
tool_dicts, tool_dicts,
chat_template_content_format, chat_template_content_format,
tokenizer, tokenizer,
model_config=model_config, renderer_config=renderer_config,
) )
conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
messages, messages,
model_config, renderer_config,
content_format=resolved_content_format, content_format=resolved_content_format,
) )
...@@ -1138,14 +1139,14 @@ class OpenAIServing: ...@@ -1138,14 +1139,14 @@ class OpenAIServing:
request_prompt = tokenizer.apply_chat_template( request_prompt = tokenizer.apply_chat_template(
conversation=conversation, conversation=conversation,
messages=messages, messages=messages,
model_config=model_config, model_config=renderer_config.model_config,
**_chat_template_kwargs, **_chat_template_kwargs,
) )
else: else:
request_prompt = apply_hf_chat_template( request_prompt = apply_hf_chat_template(
tokenizer=tokenizer, tokenizer=tokenizer,
conversation=conversation, conversation=conversation,
model_config=model_config, renderer_config=renderer_config,
**_chat_template_kwargs, **_chat_template_kwargs,
) )
......
...@@ -71,6 +71,7 @@ class OpenAIServingModels: ...@@ -71,6 +71,7 @@ class OpenAIServingModels:
self.input_processor = self.engine_client.input_processor self.input_processor = self.engine_client.input_processor
self.io_processor = self.engine_client.io_processor self.io_processor = self.engine_client.io_processor
self.renderer_config = self.engine_client.renderer_config
self.model_config = self.engine_client.model_config self.model_config = self.engine_client.model_config
self.max_model_len = self.model_config.max_model_len self.max_model_len = self.model_config.max_model_len
......
...@@ -91,7 +91,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -91,7 +91,7 @@ class OpenAISpeechToText(OpenAIServing):
self.task_type = task_type self.task_type = task_type
self.asr_config = self.model_cls.get_speech_to_text_config( self.asr_config = self.model_cls.get_speech_to_text_config(
self.model_config, task_type self.renderer_config, task_type
) )
self.enable_force_include_usage = enable_force_include_usage self.enable_force_include_usage = enable_force_include_usage
...@@ -101,8 +101,8 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -101,8 +101,8 @@ class OpenAISpeechToText(OpenAIServing):
self.tokenizer = cast( self.tokenizer = cast(
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
get_tokenizer( get_tokenizer(
tokenizer_name=self.model_config.tokenizer, tokenizer_name=self.renderer_config.tokenizer,
tokenizer_mode=self.model_config.tokenizer_mode, tokenizer_mode=self.renderer_config.tokenizer_mode,
), ),
) )
...@@ -154,7 +154,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -154,7 +154,7 @@ class OpenAISpeechToText(OpenAIServing):
prompt = self.model_cls.get_generation_prompt( prompt = self.model_cls.get_generation_prompt(
audio=chunk, audio=chunk,
stt_config=self.asr_config, stt_config=self.asr_config,
model_config=self.model_config, renderer_config=self.renderer_config,
language=language, language=language,
task_type=self.task_type, task_type=self.task_type,
request_prompt=request.prompt, request_prompt=request.prompt,
...@@ -428,7 +428,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -428,7 +428,7 @@ class OpenAISpeechToText(OpenAIServing):
if res.prompt_token_ids is not None: if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids) num_prompt_tokens = len(res.prompt_token_ids)
if audio_tokens := self.model_cls.get_num_audio_tokens( if audio_tokens := self.model_cls.get_num_audio_tokens(
audio_duration_s, self.asr_config, self.model_config audio_duration_s, self.asr_config, self.renderer_config
): ):
num_prompt_tokens += audio_tokens num_prompt_tokens += audio_tokens
......
...@@ -94,7 +94,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -94,7 +94,7 @@ class OpenAIServingPooling(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init: if self.renderer_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = await self.engine_client.get_tokenizer() tokenizer = await self.engine_client.get_tokenizer()
......
...@@ -160,10 +160,8 @@ class ServingScores(OpenAIServing): ...@@ -160,10 +160,8 @@ class ServingScores(OpenAIServing):
data_1: str | ScoreContentPartParam, data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam,
) -> tuple[str, TokensPrompt]: ) -> tuple[str, TokensPrompt]:
model_config = self.model_config
full_prompt, engine_prompt = get_score_prompt( full_prompt, engine_prompt = get_score_prompt(
model_config=model_config, renderer_config=self.renderer_config,
data_1=data_1, data_1=data_1,
data_2=data_2, data_2=data_2,
tokenizer=tokenizer, tokenizer=tokenizer,
......
...@@ -5,7 +5,7 @@ from typing import Any, TypeAlias, cast ...@@ -5,7 +5,7 @@ from typing import Any, TypeAlias, cast
from torch.nn import CosineSimilarity from torch.nn import CosineSimilarity
from typing_extensions import Required, TypedDict from typing_extensions import Required, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig, RendererConfig
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
BaseMultiModalItemTracker, BaseMultiModalItemTracker,
ChatCompletionContentPartImageEmbedsParam, ChatCompletionContentPartImageEmbedsParam,
...@@ -88,9 +88,9 @@ def _validate_score_input_lens( ...@@ -88,9 +88,9 @@ def _validate_score_input_lens(
def parse_score_data( def parse_score_data(
data_1: str | ScoreContentPartParam, data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam,
model_config: ModelConfig, renderer_config: RendererConfig,
) -> tuple[str, str, MultiModalDataDict | None]: ) -> tuple[str, str, MultiModalDataDict | None]:
mm_tracker = MultiModalItemTracker(model_config) mm_tracker = MultiModalItemTracker(renderer_config)
content_1 = _parse_score_content(data_1, mm_tracker) content_1 = _parse_score_content(data_1, mm_tracker)
content_2 = _parse_score_content(data_2, mm_tracker) content_2 = _parse_score_content(data_2, mm_tracker)
...@@ -176,7 +176,7 @@ def post_process_tokens( ...@@ -176,7 +176,7 @@ def post_process_tokens(
def get_score_prompt( def get_score_prompt(
model_config: ModelConfig, renderer_config: RendererConfig,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any], tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam, data_1: str | ScoreContentPartParam,
...@@ -185,11 +185,14 @@ def get_score_prompt( ...@@ -185,11 +185,14 @@ def get_score_prompt(
prompt_1, prompt_2, mm_data = parse_score_data( prompt_1, prompt_2, mm_data = parse_score_data(
data_1, data_1,
data_2, data_2,
model_config, renderer_config,
) )
from vllm.model_executor.model_loader import get_model_cls from vllm.model_executor.model_loader import get_model_cls
model_config = renderer_config.model_config
model = get_model_cls(model_config) model = get_model_cls(model_config)
if supports_score_template(model): if supports_score_template(model):
full_prompt = apply_score_template(model_config, prompt_1, prompt_2) full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
......
...@@ -13,7 +13,7 @@ from fastapi import Request ...@@ -13,7 +13,7 @@ from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from starlette.background import BackgroundTask, BackgroundTasks from starlette.background import BackgroundTask, BackgroundTasks
from vllm.config import ModelConfig from vllm.config import RendererConfig
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
...@@ -288,7 +288,7 @@ def process_lora_modules( ...@@ -288,7 +288,7 @@ def process_lora_modules(
async def process_chat_template( async def process_chat_template(
args_chat_template: Path | str | None, args_chat_template: Path | str | None,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig, renderer_config: RendererConfig,
) -> str | None: ) -> str | None:
resolved_chat_template = load_chat_template(args_chat_template) resolved_chat_template = load_chat_template(args_chat_template)
if resolved_chat_template is not None: if resolved_chat_template is not None:
...@@ -305,7 +305,7 @@ async def process_chat_template( ...@@ -305,7 +305,7 @@ async def process_chat_template(
tokenizer=tokenizer, tokenizer=tokenizer,
chat_template=None, chat_template=None,
tools=None, tools=None,
model_config=model_config, model_config=renderer_config.model_config,
) )
if hf_chat_template != resolved_chat_template: if hf_chat_template != resolved_chat_template:
...@@ -314,6 +314,6 @@ async def process_chat_template( ...@@ -314,6 +314,6 @@ async def process_chat_template(
"It is different from official chat template '%s'. " "It is different from official chat template '%s'. "
"This discrepancy may lead to performance degradation.", "This discrepancy may lead to performance degradation.",
resolved_chat_template, resolved_chat_template,
model_config.model, renderer_config.model_config.model,
) )
return resolved_chat_template return resolved_chat_template
...@@ -6,7 +6,7 @@ from typing import Any, cast ...@@ -6,7 +6,7 @@ from typing import Any, cast
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import ModelConfig from vllm.config import RendererConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
...@@ -45,14 +45,15 @@ logger = init_logger(__name__) ...@@ -45,14 +45,15 @@ logger = init_logger(__name__)
class InputPreprocessor: class InputPreprocessor:
def __init__( def __init__(
self, self,
model_config: ModelConfig, renderer_config: RendererConfig,
tokenizer: TokenizerLike | None, tokenizer: TokenizerLike | None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None, mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.model_config = model_config self.renderer_config = renderer_config
self.model_config = renderer_config.model_config
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache self.mm_processor_cache = mm_processor_cache
...@@ -231,7 +232,7 @@ class InputPreprocessor: ...@@ -231,7 +232,7 @@ class InputPreprocessor:
def _get_mm_processor(self) -> BaseMultiModalProcessor: def _get_mm_processor(self) -> BaseMultiModalProcessor:
if not hasattr(self, "_mm_processor"): if not hasattr(self, "_mm_processor"):
self._mm_processor = self.mm_registry.create_processor( self._mm_processor = self.mm_registry.create_processor(
self.model_config, self.renderer_config,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
cache=self.mm_processor_cache, cache=self.mm_processor_cache,
) )
......
...@@ -415,7 +415,7 @@ def load_weights_using_from_2_way_softmax( ...@@ -415,7 +415,7 @@ def load_weights_using_from_2_way_softmax(
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
model_config = model.vllm_config.model_config renderer_config = model.vllm_config.renderer_config
quant_config = model.vllm_config.quant_config quant_config = model.vllm_config.quant_config
text_config = model.config.get_text_config() text_config = model.config.get_text_config()
...@@ -447,10 +447,10 @@ def load_weights_using_from_2_way_softmax( ...@@ -447,10 +447,10 @@ def load_weights_using_from_2_way_softmax(
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
tokenizer = get_tokenizer( tokenizer = get_tokenizer(
model_config.tokenizer, renderer_config.tokenizer,
revision=model_config.tokenizer_revision, revision=renderer_config.tokenizer_revision,
tokenizer_mode=model_config.tokenizer_mode, tokenizer_mode=renderer_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=renderer_config.trust_remote_code,
) )
false_id = tokenizer.convert_tokens_to_ids(tokens[0]) false_id = tokenizer.convert_tokens_to_ids(tokens[0])
...@@ -473,7 +473,7 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te ...@@ -473,7 +473,7 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
model_config = model.vllm_config.model_config renderer_config = model.vllm_config.renderer_config
quant_config = model.vllm_config.quant_config quant_config = model.vllm_config.quant_config
text_config = model.config.get_text_config() text_config = model.config.get_text_config()
...@@ -501,10 +501,10 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te ...@@ -501,10 +501,10 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
tokenizer = get_tokenizer( tokenizer = get_tokenizer(
model_config.tokenizer, renderer_config.tokenizer,
revision=model_config.tokenizer_revision, revision=renderer_config.tokenizer_revision,
tokenizer_mode=model_config.tokenizer_mode, tokenizer_mode=renderer_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=renderer_config.trust_remote_code,
) )
token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
......
...@@ -377,8 +377,8 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -377,8 +377,8 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.projector_config = config.projector_config self.projector_config = config.projector_config
self.text_config = config.text_config self.text_config = config.text_config
model_config = vllm_config.model_config renderer_config = vllm_config.renderer_config
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(renderer_config)
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
self.sam_model = build_sam_vit_b() self.sam_model = build_sam_vit_b()
......
...@@ -370,8 +370,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -370,8 +370,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.projector_config = config.projector_config self.projector_config = config.projector_config
self.text_config = config.text_config self.text_config = config.text_config
model_config = vllm_config.model_config renderer_config = vllm_config.renderer_config
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(renderer_config)
self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN] self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]
self.vision = self._init_vision_module( self.vision = self._init_vision_module(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment