Unverified Commit 27f4c2fd authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Renderer] Separate out `RendererConfig` from `ModelConfig` (#30145)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent a49d813f
......@@ -18,7 +18,7 @@ from transformers.models.gemma3n import (
)
from transformers.models.siglip import SiglipImageProcessorFast
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config import RendererConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
......@@ -760,7 +760,7 @@ class Gemma3nForConditionalGeneration(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
renderer_config: RendererConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
......@@ -798,7 +798,9 @@ class Gemma3nForConditionalGeneration(
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
cls,
renderer_config: RendererConfig,
task_type: str,
) -> SpeechToTextConfig:
return SpeechToTextConfig(
# Let's set this to 30 as suggested in the docs for now, although
......
......@@ -34,7 +34,7 @@ import torch.nn.functional as F
from torch import nn
from transformers import BatchFeature, PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config import CacheConfig, RendererConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
......@@ -840,7 +840,7 @@ class GraniteSpeechForConditionalGeneration(
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
renderer_config: RendererConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
......@@ -861,7 +861,7 @@ class GraniteSpeechForConditionalGeneration(
else:
raise ValueError(f"Unsupported task type {task_type}")
tokenizer = cached_tokenizer_from_config(model_config)
tokenizer = cached_tokenizer_from_config(renderer_config)
chat = [dict(role="user", content=user_prompt)]
prompt = tokenizer.apply_chat_template(
chat,
......@@ -882,10 +882,10 @@ class GraniteSpeechForConditionalGeneration(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
renderer_config: RendererConfig,
) -> int | None:
"""Get the number of audio tokens for an audio duration in sec."""
processor = cached_processor_from_config(model_config)
processor = cached_processor_from_config(renderer_config)
hop_length = processor.audio_processor.melspec_kwargs["hop_length"]
proj_win_size = processor.audio_processor.projector_window_size
ds_rate = processor.audio_processor.projector_downsample_rate
......@@ -903,7 +903,9 @@ class GraniteSpeechForConditionalGeneration(
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
cls,
renderer_config: RendererConfig,
task_type: str,
) -> SpeechToTextConfig:
"""Get the stt config for this model."""
# Default settings are reasonable for this model and we don't currently
......
......@@ -6,7 +6,7 @@ import numpy as np
import torch
import torch.nn as nn
from vllm.config import ModelConfig, VllmConfig
from vllm.config import RendererConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (
DispatchPooler,
......@@ -29,12 +29,12 @@ logger = init_logger(__name__)
class GritLMMeanPool(nn.Module):
"""As `MeanPool`, but only includes non-instruction tokens."""
def __init__(self, model_config: ModelConfig):
def __init__(self, renderer_config: RendererConfig):
super().__init__()
self.model_config = model_config
self.renderer_config = renderer_config
tokenizer = cached_tokenizer_from_config(self.model_config)
tokenizer = cached_tokenizer_from_config(self.renderer_config)
# Collect the tokens needed for pattern matching.
# "▁<" is different from "_<". The former uses "▁" to indicate that
......@@ -174,10 +174,10 @@ class GritLMMeanPool(nn.Module):
class GritLMPooler(Pooler):
def __init__(self, model_config: ModelConfig):
def __init__(self, renderer_config: RendererConfig):
super().__init__()
self.pooling = GritLMMeanPool(model_config)
self.pooling = GritLMMeanPool(renderer_config)
self.head = PoolerHead(PoolerNormalize())
def get_supported_tasks(self) -> Set[PoolingTask]:
......@@ -238,6 +238,6 @@ class GritLM(LlamaForCausalLM):
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": GritLMPooler(vllm_config.model_config),
"embed": GritLMPooler(vllm_config.renderer_config),
}
)
......@@ -19,7 +19,7 @@ from torch import Tensor
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from typing_extensions import Self, TypeIs
from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.config import RendererConfig, SpeechToTextConfig
from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
......@@ -887,7 +887,7 @@ class SupportsTranscription(Protocol):
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
renderer_config: RendererConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
......@@ -930,7 +930,9 @@ class SupportsTranscription(Protocol):
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"]
cls,
renderer_config: RendererConfig,
task_type: Literal["transcribe", "translate"],
) -> SpeechToTextConfig:
"""Get the speech to text config for the ASR model."""
...
......@@ -940,7 +942,7 @@ class SupportsTranscription(Protocol):
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
renderer_config: RendererConfig,
) -> int | None:
"""
Map from audio duration to number of audio tokens produced by the ASR
......
......@@ -182,7 +182,7 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
hf_processor = self.ctx.get_hf_processor(InternVLProcessor, **kwargs)
hf_processor.video_processor = cached_video_processor_from_config(
self.ctx.model_config,
self.ctx.renderer_config,
processor_cls=InternVLVideoProcessor,
size=hf_processor.image_processor.size,
**kwargs,
......
......@@ -1169,16 +1169,17 @@ class NemotronH_Nano_VL_V2(
self.mlp1 = self.mlp1.to(self.language_model.config.dtype)
self.config = config
self.model_config = vllm_config.model_config
# Pre-tokenize special tokens for video processing
# to avoid repeated tokenization
tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self._img_start_token_ids = tokenizer.encode(
self._tokenizer = cached_tokenizer_from_config(vllm_config.renderer_config)
self._img_start_token_ids = self._tokenizer.encode(
IMG_START, add_special_tokens=False
)
self._img_end_token_ids = tokenizer.encode(IMG_END, add_special_tokens=False)
self._img_context_token_ids = tokenizer.encode(
self._img_end_token_ids = self._tokenizer.encode(
IMG_END, add_special_tokens=False
)
self._img_context_token_ids = self._tokenizer.encode(
IMG_CONTEXT, add_special_tokens=False
)
......@@ -1364,7 +1365,7 @@ class NemotronH_Nano_VL_V2(
input_embeds for the LLM.
"""
device = video_embeddings.device
tokenizer = cached_tokenizer_from_config(self.model_config)
tokenizer = self._tokenizer
# Generate video replacement token IDs using get_video_repl
# This tokenizes each frame separator independently, then uses pre-tokenized
......
......@@ -347,7 +347,7 @@ class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
def get_image_processor(self, **kwargs: object):
return cached_image_processor_from_config(
self.ctx.model_config,
self.ctx.renderer_config,
**kwargs,
)
......
......@@ -193,7 +193,7 @@ class PixtralProcessorAdapter:
class PixtralProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
tokenizer = cached_tokenizer_from_config(self.ctx.renderer_config)
if not isinstance(tokenizer, MistralTokenizer):
raise ValueError("This model requires `--tokenizer-mode mistral`")
......
......@@ -20,7 +20,7 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config import RendererConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
......@@ -176,7 +176,7 @@ class VoxtralProcessorAdapter:
class VoxtralProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
tokenizer = cached_tokenizer_from_config(self.ctx.renderer_config)
if not isinstance(tokenizer, MistralTokenizer):
raise ValueError("This model requires `--tokenizer-mode mistral`")
......@@ -339,7 +339,7 @@ class VoxtralForConditionalGeneration(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self.tokenizer = cached_tokenizer_from_config(vllm_config.renderer_config)
# update quant config to so that ignored module and target module names
# match the vLLM model names
......@@ -450,9 +450,11 @@ class VoxtralForConditionalGeneration(
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
cls,
renderer_config: RendererConfig,
task_type: str,
) -> SpeechToTextConfig:
tokenizer = cached_tokenizer_from_config(model_config)
tokenizer = cached_tokenizer_from_config(renderer_config)
audio_config = tokenizer.instruct.audio_encoder.audio_config
max_audio_clip_s = audio_config.chunk_length_s
sample_rate = audio_config.sampling_rate
......@@ -468,17 +470,17 @@ class VoxtralForConditionalGeneration(
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
renderer_config: RendererConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
tokenizer = cached_tokenizer_from_config(model_config)
tokenizer = cached_tokenizer_from_config(renderer_config)
audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless
req = TranscriptionRequest(
model=model_config.model,
model=renderer_config.model_config.model,
audio=RawAudio.from_audio(audio),
language=language,
)
......@@ -494,14 +496,14 @@ class VoxtralForConditionalGeneration(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
renderer_config: RendererConfig,
) -> int | None:
"""
Map from audio duration to number of audio tokens produced by the ASR
model, without running a forward pass.
This is used for estimating the amount of processing for this audio.
"""
tokenizer = cached_tokenizer_from_config(model_config)
tokenizer = cached_tokenizer_from_config(renderer_config)
adapter = VoxtralProcessorAdapter(tokenizer)
return adapter.get_num_audio_tokens(
int(audio_duration_s * stt_config.sample_rate)
......
......@@ -19,7 +19,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config import CacheConfig, RendererConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
......@@ -811,7 +811,7 @@ class WhisperForConditionalGeneration(
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
renderer_config: RendererConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
......@@ -847,9 +847,11 @@ class WhisperForConditionalGeneration(
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
cls,
renderer_config: RendererConfig,
task_type: str,
) -> SpeechToTextConfig:
processor = cached_processor_from_config(model_config)
processor = cached_processor_from_config(renderer_config)
return SpeechToTextConfig(
max_audio_clip_s=processor.feature_extractor.chunk_length,
......@@ -861,9 +863,9 @@ class WhisperForConditionalGeneration(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
renderer_config: RendererConfig,
) -> int | None:
processor = cached_processor_from_config(model_config)
processor = cached_processor_from_config(renderer_config)
hop_length = processor.feature_extractor.hop_length
assert hop_length is not None
# NOTE(NickLucche) user can't pass encoder
......
......@@ -31,7 +31,7 @@ from .inputs import (
)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.config import ModelConfig, RendererConfig, VllmConfig
from .processing import ResolvedPromptUpdate
from .registry import MultiModalRegistry
......@@ -561,13 +561,13 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
def _enable_processor_cache(
model_config: "ModelConfig",
renderer_config: "RendererConfig",
mm_registry: "MultiModalRegistry",
) -> bool:
if not mm_registry.supports_multimodal_inputs(model_config):
if not mm_registry.supports_multimodal_inputs(renderer_config):
return False
mm_config = model_config.get_multimodal_config()
mm_config = renderer_config.model_config.get_multimodal_config()
return mm_config.mm_processor_cache_gb > 0
......@@ -599,7 +599,7 @@ def processor_cache_from_config(
"""Return a `BaseMultiModalProcessorCache`, if enabled."""
model_config = vllm_config.model_config
if not _enable_processor_cache(model_config, mm_registry):
if not _enable_processor_cache(vllm_config.renderer_config, mm_registry):
return None
if not _enable_ipc_cache(vllm_config):
......@@ -611,14 +611,14 @@ def processor_cache_from_config(
def processor_only_cache_from_config(
model_config: "ModelConfig",
renderer_config: "RendererConfig",
mm_registry: "MultiModalRegistry",
):
"""Return a `MultiModalProcessorOnlyCache`, if enabled."""
if not _enable_processor_cache(model_config, mm_registry):
if not _enable_processor_cache(renderer_config, mm_registry):
return None
return MultiModalProcessorOnlyCache(model_config)
return MultiModalProcessorOnlyCache(renderer_config.model_config)
class BaseMultiModalReceiverCache(
......@@ -787,7 +787,7 @@ def engine_receiver_cache_from_config(
"""
model_config = vllm_config.model_config
if not _enable_processor_cache(model_config, mm_registry):
if not _enable_processor_cache(vllm_config.renderer_config, mm_registry):
return None
if not _enable_ipc_cache(vllm_config):
......@@ -809,9 +809,7 @@ def worker_receiver_cache_from_config(
Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
mm_processor_cache_type=="shm".
"""
model_config = vllm_config.model_config
if not _enable_processor_cache(model_config, mm_registry):
if not _enable_processor_cache(vllm_config.renderer_config, mm_registry):
return None
if not _enable_ipc_cache(vllm_config):
......
......@@ -23,7 +23,7 @@ import torch
from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
......@@ -53,7 +53,7 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig
from vllm.config import ModelConfig, RendererConfig
from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder
......@@ -63,6 +63,7 @@ else:
ProcessorMixin = object
ModelConfig = object
RendererConfig = object
BaseMultiModalProcessorCache = object
......@@ -945,12 +946,29 @@ class InputProcessingContext:
modify the inputs.
"""
model_config: ModelConfig
"""The configuration of the model."""
renderer_config: RendererConfig
"""The configuration of the renderer."""
tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs."""
@classmethod
def from_config(
cls,
renderer_config: RendererConfig,
*,
tokenizer: TokenizerLike | None = None,
):
if tokenizer is None and not renderer_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(renderer_config)
return cls(renderer_config, tokenizer)
@property
def model_config(self) -> ModelConfig:
"""The configuration of the model."""
return self.renderer_config.model_config
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
......@@ -1047,7 +1065,7 @@ class InputProcessingContext:
typ = ProcessorMixin
return cached_processor_from_config(
self.model_config,
self.renderer_config,
processor_cls=typ,
tokenizer=self.tokenizer,
**kwargs,
......
......@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tokenizers import TokenizerLike
from .cache import BaseMultiModalProcessorCache
from .processing import (
......@@ -22,7 +22,7 @@ from .profiling import (
)
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.config import ModelConfig, RendererConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal
logger = init_logger(__name__)
......@@ -114,17 +114,18 @@ class MultiModalRegistry:
return mm_options if len(mm_options) > 0 else None
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
def supports_multimodal_inputs(self, renderer_config: "RendererConfig") -> bool:
"""
Checks if the model supports multimodal inputs.
Returns True if the model is multimodal with any non-zero supported
modalities, otherwise returns False, effectively running in
text-only mode.
"""
model_config = renderer_config.model_config
if not model_config.is_multimodal_model:
return False
info = self._create_processing_info(model_config, tokenizer=None)
info = self._create_processing_info(renderer_config, tokenizer=None)
supported_modalities = info.get_supported_mm_limits()
mm_config = model_config.get_multimodal_config()
......@@ -144,7 +145,7 @@ class MultiModalRegistry:
def get_max_tokens_per_item_by_modality(
self,
model_config: "ModelConfig",
renderer_config: "RendererConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
profiler_limits: Mapping[str, int] | None = None,
......@@ -153,10 +154,11 @@ class MultiModalRegistry:
Get the maximum number of tokens per data item from each modality based
on underlying model configuration.
"""
model_config = renderer_config.model_config
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, cache=cache)
processor = self.create_processor(renderer_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len
......@@ -171,7 +173,7 @@ class MultiModalRegistry:
def get_mm_limits_per_prompt(
self,
model_config: "ModelConfig",
renderer_config: "RendererConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
) -> Mapping[str, int]:
......@@ -179,10 +181,11 @@ class MultiModalRegistry:
Get the maximum number of multi-modal input instances for each modality
that are allowed per prompt for a model class.
"""
model_config = renderer_config.model_config
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, cache=cache)
processor = self.create_processor(renderer_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()
......@@ -228,30 +231,21 @@ class MultiModalRegistry:
assert hasattr(model_cls, "_processor_factory")
return cast("SupportsMultiModal", model_cls)
def _create_processing_ctx(
self,
model_config: "ModelConfig",
tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext:
if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer)
def _create_processing_info(
self,
model_config: "ModelConfig",
renderer_config: "RendererConfig",
*,
tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config)
model_cls = self._get_model_cls(renderer_config.model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)
ctx = InputProcessingContext.from_config(renderer_config, tokenizer=tokenizer)
return factories.info(ctx)
def create_processor(
self,
model_config: "ModelConfig",
renderer_config: "RendererConfig",
*,
tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None,
......@@ -259,19 +253,19 @@ class MultiModalRegistry:
"""
Create a multi-modal processor for a specific model and tokenizer.
"""
model_config = renderer_config.model_config
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)
ctx = InputProcessingContext.from_config(renderer_config, tokenizer=tokenizer)
return factories.build_processor(ctx, cache=cache)
def get_decoder_dummy_data(
self,
model_config: "ModelConfig",
renderer_config: "RendererConfig",
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
*,
......@@ -280,15 +274,15 @@ class MultiModalRegistry:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
The model is identified by `renderer_config`.
"""
processor = self.create_processor(model_config, cache=cache)
processor = self.create_processor(renderer_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config)
mm_options = self._extract_mm_options(renderer_config.model_config)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options)
......@@ -304,7 +298,7 @@ class MultiModalRegistry:
def get_encoder_dummy_data(
self,
model_config: "ModelConfig",
renderer_config: "RendererConfig",
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
*,
......@@ -313,15 +307,15 @@ class MultiModalRegistry:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
The model is identified by `renderer_config`.
"""
processor = self.create_processor(model_config, cache=cache)
processor = self.create_processor(renderer_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config)
mm_options = self._extract_mm_options(renderer_config.model_config)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options)
......@@ -336,13 +330,15 @@ class MultiModalRegistry:
return dummy_data
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
def get_encdec_max_encoder_len(self, renderer_config: "RendererConfig") -> int:
"""
Get the maximum length of the encoder input for encoder-decoder models.
"""
model_config = renderer_config.model_config
if not model_config.is_encoder_decoder:
return 0
max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
max_tokens = self.get_max_tokens_per_item_by_modality(renderer_config)
if not max_tokens:
# TODO - this function assumes encoder-decoder models are
# multimodal. This will need to change when adding support for more
......
......@@ -24,7 +24,7 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
from .protocol import TokenizerLike
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.config import RendererConfig
logger = init_logger(__name__)
......@@ -205,18 +205,18 @@ def get_tokenizer(
cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
def cached_tokenizer_from_config(renderer_config: "RendererConfig", **kwargs):
return cached_get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
renderer_config.tokenizer,
tokenizer_mode=renderer_config.tokenizer_mode,
revision=renderer_config.tokenizer_revision,
trust_remote_code=renderer_config.trust_remote_code,
**kwargs,
)
def init_tokenizer_from_config(model_config: "ModelConfig"):
runner_type = model_config.runner_type
def init_tokenizer_from_config(renderer_config: "RendererConfig"):
runner_type = renderer_config.model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
......@@ -225,9 +225,9 @@ def init_tokenizer_from_config(model_config: "ModelConfig"):
assert_never(runner_type)
return get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
renderer_config.tokenizer,
tokenizer_mode=renderer_config.tokenizer_mode,
trust_remote_code=renderer_config.trust_remote_code,
revision=renderer_config.tokenizer_revision,
truncation_side=truncation_side,
)
......@@ -23,7 +23,7 @@ from vllm.transformers_utils.utils import convert_model_repo_to_path
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.config import ModelConfig, RendererConfig
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
_V = TypeVar("_V", bound=BaseVideoProcessor, default=BaseVideoProcessor)
......@@ -233,17 +233,18 @@ def cached_get_processor_without_dynamic_kwargs(
def cached_processor_from_config(
model_config: "ModelConfig",
renderer_config: "RendererConfig",
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
**kwargs: Any,
) -> _P:
model_config = renderer_config.model_config
if is_gguf(model_config.model):
assert not is_gguf(model_config.tokenizer), (
assert not is_gguf(renderer_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer "
"should be used to correctly load processor."
)
model = model_config.tokenizer
revision = model_config.tokenizer_revision
model = renderer_config.tokenizer
revision = renderer_config.tokenizer_revision
else:
model = model_config.model
revision = model_config.revision
......@@ -297,9 +298,11 @@ cached_get_feature_extractor = lru_cache(get_feature_extractor)
def cached_feature_extractor_from_config(
model_config: "ModelConfig",
renderer_config: "RendererConfig",
**kwargs: Any,
):
model_config = renderer_config.model_config
return cached_get_feature_extractor(
model_config.model,
revision=model_config.revision,
......@@ -348,16 +351,17 @@ cached_get_image_processor = lru_cache(get_image_processor)
def cached_image_processor_from_config(
model_config: "ModelConfig",
renderer_config: "RendererConfig",
**kwargs: Any,
):
model_config = renderer_config.model_config
if is_gguf(model_config.model):
assert not is_gguf(model_config.tokenizer), (
assert not is_gguf(renderer_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer "
"should be used to correctly load image processor."
)
model = model_config.tokenizer
revision = model_config.tokenizer_revision
model = renderer_config.tokenizer
revision = renderer_config.tokenizer_revision
else:
model = model_config.model
revision = model_config.revision
......@@ -411,10 +415,12 @@ cached_get_video_processor = lru_cache(get_video_processor)
def cached_video_processor_from_config(
model_config: "ModelConfig",
renderer_config: "RendererConfig",
processor_cls: type[_V] | None = None,
**kwargs: Any,
):
model_config = renderer_config.model_config
return cached_get_video_processor(
model_config.model,
revision=model_config.revision,
......
......@@ -10,7 +10,7 @@ from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig
from vllm.config import RendererConfig, SchedulerConfig
logger = init_logger(__name__)
......@@ -250,7 +250,7 @@ class EncoderCacheManager:
def compute_encoder_budget(
model_config: "ModelConfig",
renderer_config: "RendererConfig",
scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
......@@ -263,9 +263,9 @@ def compute_encoder_budget(
- Space budget for encoder cache size, measured in number of tokens
from the input sequence.
"""
if mm_registry.supports_multimodal_inputs(model_config):
if mm_registry.supports_multimodal_inputs(renderer_config):
max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
model_config
renderer_config
)
return compute_mm_encoder_budget(
......
......@@ -164,7 +164,7 @@ class Scheduler(SchedulerInterface):
# This can be changed when we make encoder cache for embedding caching
# across requests.
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=vllm_config.model_config,
renderer_config=vllm_config.renderer_config,
scheduler_config=vllm_config.scheduler_config,
mm_registry=mm_registry,
)
......
......@@ -91,6 +91,7 @@ class AsyncLLM(EngineClient):
# Ensure we can serialize custom transformer configs
maybe_register_config_serialize_by_value()
self.renderer_config = vllm_config.renderer_config
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config
......@@ -108,15 +109,15 @@ class AsyncLLM(EngineClient):
"enabling logging without default stat loggers."
)
if self.model_config.skip_tokenizer_init:
if self.renderer_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_config(self.model_config)
tokenizer = init_tokenizer_from_config(self.renderer_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor(
self.vllm_config,
self.model_config.io_processor_plugin,
self.renderer_config.io_processor_plugin,
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
......
......@@ -43,6 +43,7 @@ class InputProcessor:
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None:
self.vllm_config = vllm_config
self.renderer_config = vllm_config.renderer_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
......@@ -54,7 +55,7 @@ class InputProcessor:
self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry)
self.input_preprocessor = InputPreprocessor(
self.model_config,
self.renderer_config,
tokenizer,
mm_registry,
mm_processor_cache=self.mm_processor_cache,
......@@ -252,7 +253,7 @@ class InputProcessor:
if not params.structured_outputs or not self.structured_outputs_config:
return
if self.model_config.skip_tokenizer_init and params.structured_outputs:
if self.renderer_config.skip_tokenizer_init and params.structured_outputs:
raise ValueError(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
)
......@@ -582,7 +583,7 @@ class InputProcessor:
if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config,
self.renderer_config,
tokenizer=tokenizer,
)
assert isinstance(mm_processor, EncDecMultiModalProcessor)
......
......@@ -60,6 +60,7 @@ class LLMEngine:
) -> None:
self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config
self.renderer_config = vllm_config.renderer_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
......@@ -83,15 +84,15 @@ class LLMEngine:
self.dp_group = None
self.should_execute_dummy_batch = False
if self.model_config.skip_tokenizer_init:
if self.renderer_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_config(self.model_config)
tokenizer = init_tokenizer_from_config(self.renderer_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor(
self.vllm_config,
self.model_config.io_processor_plugin,
self.renderer_config.io_processor_plugin,
)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment