Unverified commit 27f4c2fd, authored by Cyrus Leung, committed by GitHub
Browse files

[Renderer] Separate out `RendererConfig` from `ModelConfig` (#30145)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent a49d813f
...@@ -18,7 +18,7 @@ from transformers.models.gemma3n import ( ...@@ -18,7 +18,7 @@ from transformers.models.gemma3n import (
) )
from transformers.models.siglip import SiglipImageProcessorFast from transformers.models.siglip import SiglipImageProcessorFast
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import RendererConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -760,7 +760,7 @@ class Gemma3nForConditionalGeneration( ...@@ -760,7 +760,7 @@ class Gemma3nForConditionalGeneration(
cls, cls,
audio: np.ndarray, audio: np.ndarray,
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
model_config: ModelConfig, renderer_config: RendererConfig,
language: Optional[str], language: Optional[str],
task_type: Literal["transcribe", "translate"], task_type: Literal["transcribe", "translate"],
request_prompt: str, request_prompt: str,
...@@ -798,7 +798,9 @@ class Gemma3nForConditionalGeneration( ...@@ -798,7 +798,9 @@ class Gemma3nForConditionalGeneration(
@classmethod @classmethod
def get_speech_to_text_config( def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str cls,
renderer_config: RendererConfig,
task_type: str,
) -> SpeechToTextConfig: ) -> SpeechToTextConfig:
return SpeechToTextConfig( return SpeechToTextConfig(
# Let's set this to 30 as suggested in the docs for now, although # Let's set this to 30 as suggested in the docs for now, although
......
...@@ -34,7 +34,7 @@ import torch.nn.functional as F ...@@ -34,7 +34,7 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import CacheConfig, RendererConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
...@@ -840,7 +840,7 @@ class GraniteSpeechForConditionalGeneration( ...@@ -840,7 +840,7 @@ class GraniteSpeechForConditionalGeneration(
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, audio: np.ndarray,
model_config: ModelConfig, renderer_config: RendererConfig,
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
language: str | None, language: str | None,
task_type: Literal["transcribe", "translate"], task_type: Literal["transcribe", "translate"],
...@@ -861,7 +861,7 @@ class GraniteSpeechForConditionalGeneration( ...@@ -861,7 +861,7 @@ class GraniteSpeechForConditionalGeneration(
else: else:
raise ValueError(f"Unsupported task type {task_type}") raise ValueError(f"Unsupported task type {task_type}")
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(renderer_config)
chat = [dict(role="user", content=user_prompt)] chat = [dict(role="user", content=user_prompt)]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
chat, chat,
...@@ -882,10 +882,10 @@ class GraniteSpeechForConditionalGeneration( ...@@ -882,10 +882,10 @@ class GraniteSpeechForConditionalGeneration(
cls, cls,
audio_duration_s: float, audio_duration_s: float,
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
model_config: ModelConfig, renderer_config: RendererConfig,
) -> int | None: ) -> int | None:
"""Get the number of audio tokens for an audio duration in sec.""" """Get the number of audio tokens for an audio duration in sec."""
processor = cached_processor_from_config(model_config) processor = cached_processor_from_config(renderer_config)
hop_length = processor.audio_processor.melspec_kwargs["hop_length"] hop_length = processor.audio_processor.melspec_kwargs["hop_length"]
proj_win_size = processor.audio_processor.projector_window_size proj_win_size = processor.audio_processor.projector_window_size
ds_rate = processor.audio_processor.projector_downsample_rate ds_rate = processor.audio_processor.projector_downsample_rate
...@@ -903,7 +903,9 @@ class GraniteSpeechForConditionalGeneration( ...@@ -903,7 +903,9 @@ class GraniteSpeechForConditionalGeneration(
@classmethod @classmethod
def get_speech_to_text_config( def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str cls,
renderer_config: RendererConfig,
task_type: str,
) -> SpeechToTextConfig: ) -> SpeechToTextConfig:
"""Get the stt config for this model.""" """Get the stt config for this model."""
# Default settings are reasonable for this model and we don't currently # Default settings are reasonable for this model and we don't currently
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import ModelConfig, VllmConfig from vllm.config import RendererConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import (
DispatchPooler, DispatchPooler,
...@@ -29,12 +29,12 @@ logger = init_logger(__name__) ...@@ -29,12 +29,12 @@ logger = init_logger(__name__)
class GritLMMeanPool(nn.Module): class GritLMMeanPool(nn.Module):
"""As `MeanPool`, but only includes non-instruction tokens.""" """As `MeanPool`, but only includes non-instruction tokens."""
def __init__(self, model_config: ModelConfig): def __init__(self, renderer_config: RendererConfig):
super().__init__() super().__init__()
self.model_config = model_config self.renderer_config = renderer_config
tokenizer = cached_tokenizer_from_config(self.model_config) tokenizer = cached_tokenizer_from_config(self.renderer_config)
# Collect the tokens needed for pattern matching. # Collect the tokens needed for pattern matching.
# "▁<" is different from "_<". The former uses "▁" to indicate that # "▁<" is different from "_<". The former uses "▁" to indicate that
...@@ -174,10 +174,10 @@ class GritLMMeanPool(nn.Module): ...@@ -174,10 +174,10 @@ class GritLMMeanPool(nn.Module):
class GritLMPooler(Pooler): class GritLMPooler(Pooler):
def __init__(self, model_config: ModelConfig): def __init__(self, renderer_config: RendererConfig):
super().__init__() super().__init__()
self.pooling = GritLMMeanPool(model_config) self.pooling = GritLMMeanPool(renderer_config)
self.head = PoolerHead(PoolerNormalize()) self.head = PoolerHead(PoolerNormalize())
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
...@@ -238,6 +238,6 @@ class GritLM(LlamaForCausalLM): ...@@ -238,6 +238,6 @@ class GritLM(LlamaForCausalLM):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"token_embed": Pooler.for_token_embed(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": GritLMPooler(vllm_config.model_config), "embed": GritLMPooler(vllm_config.renderer_config),
} }
) )
...@@ -19,7 +19,7 @@ from torch import Tensor ...@@ -19,7 +19,7 @@ from torch import Tensor
from transformers.models.whisper.tokenization_whisper import LANGUAGES from transformers.models.whisper.tokenization_whisper import LANGUAGES
from typing_extensions import Self, TypeIs from typing_extensions import Self, TypeIs
from vllm.config import ModelConfig, SpeechToTextConfig from vllm.config import RendererConfig, SpeechToTextConfig
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -887,7 +887,7 @@ class SupportsTranscription(Protocol): ...@@ -887,7 +887,7 @@ class SupportsTranscription(Protocol):
cls, cls,
audio: np.ndarray, audio: np.ndarray,
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
model_config: ModelConfig, renderer_config: RendererConfig,
language: str | None, language: str | None,
task_type: Literal["transcribe", "translate"], task_type: Literal["transcribe", "translate"],
request_prompt: str, request_prompt: str,
...@@ -930,7 +930,9 @@ class SupportsTranscription(Protocol): ...@@ -930,7 +930,9 @@ class SupportsTranscription(Protocol):
@classmethod @classmethod
def get_speech_to_text_config( def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"] cls,
renderer_config: RendererConfig,
task_type: Literal["transcribe", "translate"],
) -> SpeechToTextConfig: ) -> SpeechToTextConfig:
"""Get the speech to text config for the ASR model.""" """Get the speech to text config for the ASR model."""
... ...
...@@ -940,7 +942,7 @@ class SupportsTranscription(Protocol): ...@@ -940,7 +942,7 @@ class SupportsTranscription(Protocol):
cls, cls,
audio_duration_s: float, audio_duration_s: float,
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
model_config: ModelConfig, renderer_config: RendererConfig,
) -> int | None: ) -> int | None:
""" """
Map from audio duration to number of audio tokens produced by the ASR Map from audio duration to number of audio tokens produced by the ASR
......
...@@ -182,7 +182,7 @@ class InternS1ProcessingInfo(BaseProcessingInfo): ...@@ -182,7 +182,7 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> InternVLProcessor: def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
hf_processor = self.ctx.get_hf_processor(InternVLProcessor, **kwargs) hf_processor = self.ctx.get_hf_processor(InternVLProcessor, **kwargs)
hf_processor.video_processor = cached_video_processor_from_config( hf_processor.video_processor = cached_video_processor_from_config(
self.ctx.model_config, self.ctx.renderer_config,
processor_cls=InternVLVideoProcessor, processor_cls=InternVLVideoProcessor,
size=hf_processor.image_processor.size, size=hf_processor.image_processor.size,
**kwargs, **kwargs,
......
...@@ -1169,16 +1169,17 @@ class NemotronH_Nano_VL_V2( ...@@ -1169,16 +1169,17 @@ class NemotronH_Nano_VL_V2(
self.mlp1 = self.mlp1.to(self.language_model.config.dtype) self.mlp1 = self.mlp1.to(self.language_model.config.dtype)
self.config = config self.config = config
self.model_config = vllm_config.model_config
# Pre-tokenize special tokens for video processing # Pre-tokenize special tokens for video processing
# to avoid repeated tokenization # to avoid repeated tokenization
tokenizer = cached_tokenizer_from_config(vllm_config.model_config) self._tokenizer = cached_tokenizer_from_config(vllm_config.renderer_config)
self._img_start_token_ids = tokenizer.encode( self._img_start_token_ids = self._tokenizer.encode(
IMG_START, add_special_tokens=False IMG_START, add_special_tokens=False
) )
self._img_end_token_ids = tokenizer.encode(IMG_END, add_special_tokens=False) self._img_end_token_ids = self._tokenizer.encode(
self._img_context_token_ids = tokenizer.encode( IMG_END, add_special_tokens=False
)
self._img_context_token_ids = self._tokenizer.encode(
IMG_CONTEXT, add_special_tokens=False IMG_CONTEXT, add_special_tokens=False
) )
...@@ -1364,7 +1365,7 @@ class NemotronH_Nano_VL_V2( ...@@ -1364,7 +1365,7 @@ class NemotronH_Nano_VL_V2(
input_embeds for the LLM. input_embeds for the LLM.
""" """
device = video_embeddings.device device = video_embeddings.device
tokenizer = cached_tokenizer_from_config(self.model_config) tokenizer = self._tokenizer
# Generate video replacement token IDs using get_video_repl # Generate video replacement token IDs using get_video_repl
# This tokenizes each frame separator independently, then uses pre-tokenized # This tokenizes each frame separator independently, then uses pre-tokenized
......
...@@ -347,7 +347,7 @@ class NemotronVLProcessingInfo(BaseInternVLProcessingInfo): ...@@ -347,7 +347,7 @@ class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
def get_image_processor(self, **kwargs: object): def get_image_processor(self, **kwargs: object):
return cached_image_processor_from_config( return cached_image_processor_from_config(
self.ctx.model_config, self.ctx.renderer_config,
**kwargs, **kwargs,
) )
......
...@@ -193,7 +193,7 @@ class PixtralProcessorAdapter: ...@@ -193,7 +193,7 @@ class PixtralProcessorAdapter:
class PixtralProcessingInfo(BaseProcessingInfo): class PixtralProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> MistralTokenizer: def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config) tokenizer = cached_tokenizer_from_config(self.ctx.renderer_config)
if not isinstance(tokenizer, MistralTokenizer): if not isinstance(tokenizer, MistralTokenizer):
raise ValueError("This model requires `--tokenizer-mode mistral`") raise ValueError("This model requires `--tokenizer-mode mistral`")
......
...@@ -20,7 +20,7 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder ...@@ -20,7 +20,7 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
from transformers import BatchFeature, TensorType, WhisperConfig from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import RendererConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -176,7 +176,7 @@ class VoxtralProcessorAdapter: ...@@ -176,7 +176,7 @@ class VoxtralProcessorAdapter:
class VoxtralProcessingInfo(BaseProcessingInfo): class VoxtralProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> MistralTokenizer: def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config) tokenizer = cached_tokenizer_from_config(self.ctx.renderer_config)
if not isinstance(tokenizer, MistralTokenizer): if not isinstance(tokenizer, MistralTokenizer):
raise ValueError("This model requires `--tokenizer-mode mistral`") raise ValueError("This model requires `--tokenizer-mode mistral`")
...@@ -339,7 +339,7 @@ class VoxtralForConditionalGeneration( ...@@ -339,7 +339,7 @@ class VoxtralForConditionalGeneration(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config) self.tokenizer = cached_tokenizer_from_config(vllm_config.renderer_config)
# update quant config to so that ignored module and target module names # update quant config to so that ignored module and target module names
# match the vLLM model names # match the vLLM model names
...@@ -450,9 +450,11 @@ class VoxtralForConditionalGeneration( ...@@ -450,9 +450,11 @@ class VoxtralForConditionalGeneration(
@classmethod @classmethod
def get_speech_to_text_config( def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str cls,
renderer_config: RendererConfig,
task_type: str,
) -> SpeechToTextConfig: ) -> SpeechToTextConfig:
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(renderer_config)
audio_config = tokenizer.instruct.audio_encoder.audio_config audio_config = tokenizer.instruct.audio_encoder.audio_config
max_audio_clip_s = audio_config.chunk_length_s max_audio_clip_s = audio_config.chunk_length_s
sample_rate = audio_config.sampling_rate sample_rate = audio_config.sampling_rate
...@@ -468,17 +470,17 @@ class VoxtralForConditionalGeneration( ...@@ -468,17 +470,17 @@ class VoxtralForConditionalGeneration(
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, audio: np.ndarray,
model_config: ModelConfig, renderer_config: RendererConfig, # not needed here
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
language: str | None, language: str | None,
task_type: Literal["transcribe", "translate"], task_type: Literal["transcribe", "translate"],
request_prompt: str, request_prompt: str,
to_language: str | None, to_language: str | None,
) -> PromptType: ) -> PromptType:
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(renderer_config)
audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless
req = TranscriptionRequest( req = TranscriptionRequest(
model=model_config.model, model=renderer_config.model_config.model,
audio=RawAudio.from_audio(audio), audio=RawAudio.from_audio(audio),
language=language, language=language,
) )
...@@ -494,14 +496,14 @@ class VoxtralForConditionalGeneration( ...@@ -494,14 +496,14 @@ class VoxtralForConditionalGeneration(
cls, cls,
audio_duration_s: float, audio_duration_s: float,
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
model_config: ModelConfig, renderer_config: RendererConfig,
) -> int | None: ) -> int | None:
""" """
Map from audio duration to number of audio tokens produced by the ASR Map from audio duration to number of audio tokens produced by the ASR
model, without running a forward pass. model, without running a forward pass.
This is used for estimating the amount of processing for this audio. This is used for estimating the amount of processing for this audio.
""" """
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(renderer_config)
adapter = VoxtralProcessorAdapter(tokenizer) adapter = VoxtralProcessorAdapter(tokenizer)
return adapter.get_num_audio_tokens( return adapter.get_num_audio_tokens(
int(audio_duration_s * stt_config.sample_rate) int(audio_duration_s * stt_config.sample_rate)
......
...@@ -19,7 +19,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids ...@@ -19,7 +19,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention, MultiHeadAttention from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import CacheConfig, RendererConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
...@@ -811,7 +811,7 @@ class WhisperForConditionalGeneration( ...@@ -811,7 +811,7 @@ class WhisperForConditionalGeneration(
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, audio: np.ndarray,
model_config: ModelConfig, # not needed here renderer_config: RendererConfig, # not needed here
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
language: str | None, language: str | None,
task_type: Literal["transcribe", "translate"], task_type: Literal["transcribe", "translate"],
...@@ -847,9 +847,11 @@ class WhisperForConditionalGeneration( ...@@ -847,9 +847,11 @@ class WhisperForConditionalGeneration(
@classmethod @classmethod
def get_speech_to_text_config( def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str cls,
renderer_config: RendererConfig,
task_type: str,
) -> SpeechToTextConfig: ) -> SpeechToTextConfig:
processor = cached_processor_from_config(model_config) processor = cached_processor_from_config(renderer_config)
return SpeechToTextConfig( return SpeechToTextConfig(
max_audio_clip_s=processor.feature_extractor.chunk_length, max_audio_clip_s=processor.feature_extractor.chunk_length,
...@@ -861,9 +863,9 @@ class WhisperForConditionalGeneration( ...@@ -861,9 +863,9 @@ class WhisperForConditionalGeneration(
cls, cls,
audio_duration_s: float, audio_duration_s: float,
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
model_config: ModelConfig, renderer_config: RendererConfig,
) -> int | None: ) -> int | None:
processor = cached_processor_from_config(model_config) processor = cached_processor_from_config(renderer_config)
hop_length = processor.feature_extractor.hop_length hop_length = processor.feature_extractor.hop_length
assert hop_length is not None assert hop_length is not None
# NOTE(NickLucche) user can't pass encoder # NOTE(NickLucche) user can't pass encoder
......
...@@ -31,7 +31,7 @@ from .inputs import ( ...@@ -31,7 +31,7 @@ from .inputs import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, RendererConfig, VllmConfig
from .processing import ResolvedPromptUpdate from .processing import ResolvedPromptUpdate
from .registry import MultiModalRegistry from .registry import MultiModalRegistry
...@@ -561,13 +561,13 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): ...@@ -561,13 +561,13 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
def _enable_processor_cache( def _enable_processor_cache(
model_config: "ModelConfig", renderer_config: "RendererConfig",
mm_registry: "MultiModalRegistry", mm_registry: "MultiModalRegistry",
) -> bool: ) -> bool:
if not mm_registry.supports_multimodal_inputs(model_config): if not mm_registry.supports_multimodal_inputs(renderer_config):
return False return False
mm_config = model_config.get_multimodal_config() mm_config = renderer_config.model_config.get_multimodal_config()
return mm_config.mm_processor_cache_gb > 0 return mm_config.mm_processor_cache_gb > 0
...@@ -599,7 +599,7 @@ def processor_cache_from_config( ...@@ -599,7 +599,7 @@ def processor_cache_from_config(
"""Return a `BaseMultiModalProcessorCache`, if enabled.""" """Return a `BaseMultiModalProcessorCache`, if enabled."""
model_config = vllm_config.model_config model_config = vllm_config.model_config
if not _enable_processor_cache(model_config, mm_registry): if not _enable_processor_cache(vllm_config.renderer_config, mm_registry):
return None return None
if not _enable_ipc_cache(vllm_config): if not _enable_ipc_cache(vllm_config):
...@@ -611,14 +611,14 @@ def processor_cache_from_config( ...@@ -611,14 +611,14 @@ def processor_cache_from_config(
def processor_only_cache_from_config( def processor_only_cache_from_config(
model_config: "ModelConfig", renderer_config: "RendererConfig",
mm_registry: "MultiModalRegistry", mm_registry: "MultiModalRegistry",
): ):
"""Return a `MultiModalProcessorOnlyCache`, if enabled.""" """Return a `MultiModalProcessorOnlyCache`, if enabled."""
if not _enable_processor_cache(model_config, mm_registry): if not _enable_processor_cache(renderer_config, mm_registry):
return None return None
return MultiModalProcessorOnlyCache(model_config) return MultiModalProcessorOnlyCache(renderer_config.model_config)
class BaseMultiModalReceiverCache( class BaseMultiModalReceiverCache(
...@@ -787,7 +787,7 @@ def engine_receiver_cache_from_config( ...@@ -787,7 +787,7 @@ def engine_receiver_cache_from_config(
""" """
model_config = vllm_config.model_config model_config = vllm_config.model_config
if not _enable_processor_cache(model_config, mm_registry): if not _enable_processor_cache(vllm_config.renderer_config, mm_registry):
return None return None
if not _enable_ipc_cache(vllm_config): if not _enable_ipc_cache(vllm_config):
...@@ -809,9 +809,7 @@ def worker_receiver_cache_from_config( ...@@ -809,9 +809,7 @@ def worker_receiver_cache_from_config(
Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
mm_processor_cache_type=="shm". mm_processor_cache_type=="shm".
""" """
model_config = vllm_config.model_config if not _enable_processor_cache(vllm_config.renderer_config, mm_registry):
if not _enable_processor_cache(model_config, mm_registry):
return None return None
if not _enable_ipc_cache(vllm_config): if not _enable_ipc_cache(vllm_config):
......
...@@ -23,7 +23,7 @@ import torch ...@@ -23,7 +23,7 @@ import torch
from typing_extensions import TypeVar, assert_never from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
...@@ -53,7 +53,7 @@ if TYPE_CHECKING: ...@@ -53,7 +53,7 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig from vllm.config import ModelConfig, RendererConfig
from .cache import BaseMultiModalProcessorCache from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder from .profiling import BaseDummyInputsBuilder
...@@ -63,6 +63,7 @@ else: ...@@ -63,6 +63,7 @@ else:
ProcessorMixin = object ProcessorMixin = object
ModelConfig = object ModelConfig = object
RendererConfig = object
BaseMultiModalProcessorCache = object BaseMultiModalProcessorCache = object
...@@ -945,12 +946,29 @@ class InputProcessingContext: ...@@ -945,12 +946,29 @@ class InputProcessingContext:
modify the inputs. modify the inputs.
""" """
model_config: ModelConfig renderer_config: RendererConfig
"""The configuration of the model.""" """The configuration of the renderer."""
tokenizer: TokenizerLike | None tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs.""" """The tokenizer used to tokenize the inputs."""
@classmethod
def from_config(
cls,
renderer_config: RendererConfig,
*,
tokenizer: TokenizerLike | None = None,
):
if tokenizer is None and not renderer_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(renderer_config)
return cls(renderer_config, tokenizer)
@property
def model_config(self) -> ModelConfig:
"""The configuration of the model."""
return self.renderer_config.model_config
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError( raise ValueError(
...@@ -1047,7 +1065,7 @@ class InputProcessingContext: ...@@ -1047,7 +1065,7 @@ class InputProcessingContext:
typ = ProcessorMixin typ = ProcessorMixin
return cached_processor_from_config( return cached_processor_from_config(
self.model_config, self.renderer_config,
processor_cls=typ, processor_cls=typ,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
**kwargs, **kwargs,
......
...@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast ...@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike
from .cache import BaseMultiModalProcessorCache from .cache import BaseMultiModalProcessorCache
from .processing import ( from .processing import (
...@@ -22,7 +22,7 @@ from .profiling import ( ...@@ -22,7 +22,7 @@ from .profiling import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig, RendererConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -114,17 +114,18 @@ class MultiModalRegistry: ...@@ -114,17 +114,18 @@ class MultiModalRegistry:
return mm_options if len(mm_options) > 0 else None return mm_options if len(mm_options) > 0 else None
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: def supports_multimodal_inputs(self, renderer_config: "RendererConfig") -> bool:
""" """
Checks if the model supports multimodal inputs. Checks if the model supports multimodal inputs.
Returns True if the model is multimodal with any non-zero supported Returns True if the model is multimodal with any non-zero supported
modalities, otherwise returns False, effectively running in modalities, otherwise returns False, effectively running in
text-only mode. text-only mode.
""" """
model_config = renderer_config.model_config
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
return False return False
info = self._create_processing_info(model_config, tokenizer=None) info = self._create_processing_info(renderer_config, tokenizer=None)
supported_modalities = info.get_supported_mm_limits() supported_modalities = info.get_supported_mm_limits()
mm_config = model_config.get_multimodal_config() mm_config = model_config.get_multimodal_config()
...@@ -144,7 +145,7 @@ class MultiModalRegistry: ...@@ -144,7 +145,7 @@ class MultiModalRegistry:
def get_max_tokens_per_item_by_modality( def get_max_tokens_per_item_by_modality(
self, self,
model_config: "ModelConfig", renderer_config: "RendererConfig",
*, *,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
profiler_limits: Mapping[str, int] | None = None, profiler_limits: Mapping[str, int] | None = None,
...@@ -153,10 +154,11 @@ class MultiModalRegistry: ...@@ -153,10 +154,11 @@ class MultiModalRegistry:
Get the maximum number of tokens per data item from each modality based Get the maximum number of tokens per data item from each modality based
on underlying model configuration. on underlying model configuration.
""" """
model_config = renderer_config.model_config
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
return {} return {}
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(renderer_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
...@@ -171,7 +173,7 @@ class MultiModalRegistry: ...@@ -171,7 +173,7 @@ class MultiModalRegistry:
def get_mm_limits_per_prompt( def get_mm_limits_per_prompt(
self, self,
model_config: "ModelConfig", renderer_config: "RendererConfig",
*, *,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
) -> Mapping[str, int]: ) -> Mapping[str, int]:
...@@ -179,10 +181,11 @@ class MultiModalRegistry: ...@@ -179,10 +181,11 @@ class MultiModalRegistry:
Get the maximum number of multi-modal input instances for each modality Get the maximum number of multi-modal input instances for each modality
that are allowed per prompt for a model class. that are allowed per prompt for a model class.
""" """
model_config = renderer_config.model_config
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
return {} return {}
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(renderer_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
return profiler.get_mm_limits() return profiler.get_mm_limits()
...@@ -228,30 +231,21 @@ class MultiModalRegistry: ...@@ -228,30 +231,21 @@ class MultiModalRegistry:
assert hasattr(model_cls, "_processor_factory") assert hasattr(model_cls, "_processor_factory")
return cast("SupportsMultiModal", model_cls) return cast("SupportsMultiModal", model_cls)
def _create_processing_ctx(
self,
model_config: "ModelConfig",
tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext:
if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer)
def _create_processing_info( def _create_processing_info(
self, self,
model_config: "ModelConfig", renderer_config: "RendererConfig",
*, *,
tokenizer: TokenizerLike | None = None, tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo: ) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(renderer_config.model_config)
factories = model_cls._processor_factory factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)
ctx = InputProcessingContext.from_config(renderer_config, tokenizer=tokenizer)
return factories.info(ctx) return factories.info(ctx)
def create_processor( def create_processor(
self, self,
model_config: "ModelConfig", renderer_config: "RendererConfig",
*, *,
tokenizer: TokenizerLike | None = None, tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
...@@ -259,19 +253,19 @@ class MultiModalRegistry: ...@@ -259,19 +253,19 @@ class MultiModalRegistry:
""" """
Create a multi-modal processor for a specific model and tokenizer. Create a multi-modal processor for a specific model and tokenizer.
""" """
model_config = renderer_config.model_config
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model") raise ValueError(f"{model_config.model} is not a multimodal model")
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer) ctx = InputProcessingContext.from_config(renderer_config, tokenizer=tokenizer)
return factories.build_processor(ctx, cache=cache) return factories.build_processor(ctx, cache=cache)
def get_decoder_dummy_data( def get_decoder_dummy_data(
self, self,
model_config: "ModelConfig", renderer_config: "RendererConfig",
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int] | None = None, mm_counts: Mapping[str, int] | None = None,
*, *,
...@@ -280,15 +274,15 @@ class MultiModalRegistry: ...@@ -280,15 +274,15 @@ class MultiModalRegistry:
""" """
Create dummy data for profiling the memory usage of a model. Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`. The model is identified by `renderer_config`.
""" """
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(renderer_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config. # Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy # Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged. # count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config) mm_options = self._extract_mm_options(renderer_config.model_config)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options) dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options)
...@@ -304,7 +298,7 @@ class MultiModalRegistry: ...@@ -304,7 +298,7 @@ class MultiModalRegistry:
def get_encoder_dummy_data( def get_encoder_dummy_data(
self, self,
model_config: "ModelConfig", renderer_config: "RendererConfig",
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int] | None = None, mm_counts: Mapping[str, int] | None = None,
*, *,
...@@ -313,15 +307,15 @@ class MultiModalRegistry: ...@@ -313,15 +307,15 @@ class MultiModalRegistry:
""" """
Create dummy data for profiling the memory usage of a model. Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`. The model is identified by `renderer_config`.
""" """
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(renderer_config, cache=cache)
profiler: MultiModalProfiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config. # Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy # Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged. # count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config) mm_options = self._extract_mm_options(renderer_config.model_config)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options) dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options)
...@@ -336,13 +330,15 @@ class MultiModalRegistry: ...@@ -336,13 +330,15 @@ class MultiModalRegistry:
return dummy_data return dummy_data
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: def get_encdec_max_encoder_len(self, renderer_config: "RendererConfig") -> int:
""" """
Get the maximum length of the encoder input for encoder-decoder models. Get the maximum length of the encoder input for encoder-decoder models.
""" """
model_config = renderer_config.model_config
if not model_config.is_encoder_decoder: if not model_config.is_encoder_decoder:
return 0 return 0
max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
max_tokens = self.get_max_tokens_per_item_by_modality(renderer_config)
if not max_tokens: if not max_tokens:
# TODO - this function assumes encoder-decoder models are # TODO - this function assumes encoder-decoder models are
# multimodal. This will need to change when adding support for more # multimodal. This will need to change when adding support for more
......
...@@ -24,7 +24,7 @@ from vllm.utils.import_utils import resolve_obj_by_qualname ...@@ -24,7 +24,7 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
from .protocol import TokenizerLike from .protocol import TokenizerLike
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import RendererConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -205,18 +205,18 @@ def get_tokenizer( ...@@ -205,18 +205,18 @@ def get_tokenizer(
cached_get_tokenizer = lru_cache(get_tokenizer) cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs): def cached_tokenizer_from_config(renderer_config: "RendererConfig", **kwargs):
return cached_get_tokenizer( return cached_get_tokenizer(
model_config.tokenizer, renderer_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode, tokenizer_mode=renderer_config.tokenizer_mode,
revision=model_config.tokenizer_revision, revision=renderer_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=renderer_config.trust_remote_code,
**kwargs, **kwargs,
) )
def init_tokenizer_from_config(model_config: "ModelConfig"): def init_tokenizer_from_config(renderer_config: "RendererConfig"):
runner_type = model_config.runner_type runner_type = renderer_config.model_config.runner_type
if runner_type == "generate" or runner_type == "draft": if runner_type == "generate" or runner_type == "draft":
truncation_side = "left" truncation_side = "left"
elif runner_type == "pooling": elif runner_type == "pooling":
...@@ -225,9 +225,9 @@ def init_tokenizer_from_config(model_config: "ModelConfig"): ...@@ -225,9 +225,9 @@ def init_tokenizer_from_config(model_config: "ModelConfig"):
assert_never(runner_type) assert_never(runner_type)
return get_tokenizer( return get_tokenizer(
model_config.tokenizer, renderer_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode, tokenizer_mode=renderer_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=renderer_config.trust_remote_code,
revision=model_config.tokenizer_revision, revision=renderer_config.tokenizer_revision,
truncation_side=truncation_side, truncation_side=truncation_side,
) )
...@@ -23,7 +23,7 @@ from vllm.transformers_utils.utils import convert_model_repo_to_path ...@@ -23,7 +23,7 @@ from vllm.transformers_utils.utils import convert_model_repo_to_path
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig, RendererConfig
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) _P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
_V = TypeVar("_V", bound=BaseVideoProcessor, default=BaseVideoProcessor) _V = TypeVar("_V", bound=BaseVideoProcessor, default=BaseVideoProcessor)
...@@ -233,17 +233,18 @@ def cached_get_processor_without_dynamic_kwargs( ...@@ -233,17 +233,18 @@ def cached_get_processor_without_dynamic_kwargs(
def cached_processor_from_config( def cached_processor_from_config(
model_config: "ModelConfig", renderer_config: "RendererConfig",
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin, processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
**kwargs: Any, **kwargs: Any,
) -> _P: ) -> _P:
model_config = renderer_config.model_config
if is_gguf(model_config.model): if is_gguf(model_config.model):
assert not is_gguf(model_config.tokenizer), ( assert not is_gguf(renderer_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer " "For multimodal GGUF models, the original tokenizer "
"should be used to correctly load processor." "should be used to correctly load processor."
) )
model = model_config.tokenizer model = renderer_config.tokenizer
revision = model_config.tokenizer_revision revision = renderer_config.tokenizer_revision
else: else:
model = model_config.model model = model_config.model
revision = model_config.revision revision = model_config.revision
...@@ -297,9 +298,11 @@ cached_get_feature_extractor = lru_cache(get_feature_extractor) ...@@ -297,9 +298,11 @@ cached_get_feature_extractor = lru_cache(get_feature_extractor)
def cached_feature_extractor_from_config( def cached_feature_extractor_from_config(
model_config: "ModelConfig", renderer_config: "RendererConfig",
**kwargs: Any, **kwargs: Any,
): ):
model_config = renderer_config.model_config
return cached_get_feature_extractor( return cached_get_feature_extractor(
model_config.model, model_config.model,
revision=model_config.revision, revision=model_config.revision,
...@@ -348,16 +351,17 @@ cached_get_image_processor = lru_cache(get_image_processor) ...@@ -348,16 +351,17 @@ cached_get_image_processor = lru_cache(get_image_processor)
def cached_image_processor_from_config( def cached_image_processor_from_config(
model_config: "ModelConfig", renderer_config: "RendererConfig",
**kwargs: Any, **kwargs: Any,
): ):
model_config = renderer_config.model_config
if is_gguf(model_config.model): if is_gguf(model_config.model):
assert not is_gguf(model_config.tokenizer), ( assert not is_gguf(renderer_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer " "For multimodal GGUF models, the original tokenizer "
"should be used to correctly load image processor." "should be used to correctly load image processor."
) )
model = model_config.tokenizer model = renderer_config.tokenizer
revision = model_config.tokenizer_revision revision = renderer_config.tokenizer_revision
else: else:
model = model_config.model model = model_config.model
revision = model_config.revision revision = model_config.revision
...@@ -411,10 +415,12 @@ cached_get_video_processor = lru_cache(get_video_processor) ...@@ -411,10 +415,12 @@ cached_get_video_processor = lru_cache(get_video_processor)
def cached_video_processor_from_config( def cached_video_processor_from_config(
model_config: "ModelConfig", renderer_config: "RendererConfig",
processor_cls: type[_V] | None = None, processor_cls: type[_V] | None = None,
**kwargs: Any, **kwargs: Any,
): ):
model_config = renderer_config.model_config
return cached_get_video_processor( return cached_get_video_processor(
model_config.model, model_config.model,
revision=model_config.revision, revision=model_config.revision,
......
...@@ -10,7 +10,7 @@ from vllm.multimodal import MultiModalRegistry ...@@ -10,7 +10,7 @@ from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request from vllm.v1.request import Request
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig from vllm.config import RendererConfig, SchedulerConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -250,7 +250,7 @@ class EncoderCacheManager: ...@@ -250,7 +250,7 @@ class EncoderCacheManager:
def compute_encoder_budget( def compute_encoder_budget(
model_config: "ModelConfig", renderer_config: "RendererConfig",
scheduler_config: "SchedulerConfig", scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry, mm_registry: MultiModalRegistry,
) -> tuple[int, int]: ) -> tuple[int, int]:
...@@ -263,9 +263,9 @@ def compute_encoder_budget( ...@@ -263,9 +263,9 @@ def compute_encoder_budget(
- Space budget for encoder cache size, measured in number of tokens - Space budget for encoder cache size, measured in number of tokens
from the input sequence. from the input sequence.
""" """
if mm_registry.supports_multimodal_inputs(model_config): if mm_registry.supports_multimodal_inputs(renderer_config):
max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality( max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
model_config renderer_config
) )
return compute_mm_encoder_budget( return compute_mm_encoder_budget(
......
...@@ -164,7 +164,7 @@ class Scheduler(SchedulerInterface): ...@@ -164,7 +164,7 @@ class Scheduler(SchedulerInterface):
# This can be changed when we make encoder cache for embedding caching # This can be changed when we make encoder cache for embedding caching
# across requests. # across requests.
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=vllm_config.model_config, renderer_config=vllm_config.renderer_config,
scheduler_config=vllm_config.scheduler_config, scheduler_config=vllm_config.scheduler_config,
mm_registry=mm_registry, mm_registry=mm_registry,
) )
......
...@@ -91,6 +91,7 @@ class AsyncLLM(EngineClient): ...@@ -91,6 +91,7 @@ class AsyncLLM(EngineClient):
# Ensure we can serialize custom transformer configs # Ensure we can serialize custom transformer configs
maybe_register_config_serialize_by_value() maybe_register_config_serialize_by_value()
self.renderer_config = vllm_config.renderer_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
...@@ -108,15 +109,15 @@ class AsyncLLM(EngineClient): ...@@ -108,15 +109,15 @@ class AsyncLLM(EngineClient):
"enabling logging without default stat loggers." "enabling logging without default stat loggers."
) )
if self.model_config.skip_tokenizer_init: if self.renderer_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = init_tokenizer_from_config(self.model_config) tokenizer = init_tokenizer_from_config(self.renderer_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor( self.io_processor = get_io_processor(
self.vllm_config, self.vllm_config,
self.model_config.io_processor_plugin, self.renderer_config.io_processor_plugin,
) )
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput). # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
......
...@@ -43,6 +43,7 @@ class InputProcessor: ...@@ -43,6 +43,7 @@ class InputProcessor:
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None: ) -> None:
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.renderer_config = vllm_config.renderer_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config self.lora_config = vllm_config.lora_config
...@@ -54,7 +55,7 @@ class InputProcessor: ...@@ -54,7 +55,7 @@ class InputProcessor:
self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry) self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry)
self.input_preprocessor = InputPreprocessor( self.input_preprocessor = InputPreprocessor(
self.model_config, self.renderer_config,
tokenizer, tokenizer,
mm_registry, mm_registry,
mm_processor_cache=self.mm_processor_cache, mm_processor_cache=self.mm_processor_cache,
...@@ -252,7 +253,7 @@ class InputProcessor: ...@@ -252,7 +253,7 @@ class InputProcessor:
if not params.structured_outputs or not self.structured_outputs_config: if not params.structured_outputs or not self.structured_outputs_config:
return return
if self.model_config.skip_tokenizer_init and params.structured_outputs: if self.renderer_config.skip_tokenizer_init and params.structured_outputs:
raise ValueError( raise ValueError(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501 "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
) )
...@@ -582,7 +583,7 @@ class InputProcessor: ...@@ -582,7 +583,7 @@ class InputProcessor:
if prompt_type == "encoder" and model_config.is_multimodal_model: if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor( mm_processor = mm_registry.create_processor(
model_config, self.renderer_config,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
assert isinstance(mm_processor, EncDecMultiModalProcessor) assert isinstance(mm_processor, EncDecMultiModalProcessor)
......
...@@ -60,6 +60,7 @@ class LLMEngine: ...@@ -60,6 +60,7 @@ class LLMEngine:
) -> None: ) -> None:
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self.renderer_config = vllm_config.renderer_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
...@@ -83,15 +84,15 @@ class LLMEngine: ...@@ -83,15 +84,15 @@ class LLMEngine:
self.dp_group = None self.dp_group = None
self.should_execute_dummy_batch = False self.should_execute_dummy_batch = False
if self.model_config.skip_tokenizer_init: if self.renderer_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = init_tokenizer_from_config(self.model_config) tokenizer = init_tokenizer_from_config(self.renderer_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor( self.io_processor = get_io_processor(
self.vllm_config, self.vllm_config,
self.model_config.io_processor_plugin, self.renderer_config.io_processor_plugin,
) )
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput). # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment