Unverified commit 73391a1b, authored by Cyrus Leung and committed by GitHub
Browse files

[Renderer] Move InputPreprocessor into Renderer (1/2) (#34510)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent b3c14229
......@@ -54,6 +54,7 @@ class MockModelConfig:
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......@@ -67,7 +68,7 @@ class MockVllmConfig:
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
return HfRenderer.from_config(
MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
......
......@@ -53,6 +53,7 @@ class MockModelConfig:
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......@@ -78,7 +79,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
return HfRenderer.from_config(
MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
......
......@@ -52,6 +52,7 @@ class MockModelConfig:
generation_config: str = "auto"
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......@@ -95,7 +96,7 @@ def register_mock_resolver():
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
return HfRenderer.from_config(
MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
......
......@@ -529,6 +529,7 @@ class MockModelConfig:
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......@@ -542,7 +543,7 @@ class MockVllmConfig:
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
return HfRenderer.from_config(
MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
......@@ -756,9 +757,8 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer(
MockVllmConfig(mock_engine.model_config),
tokenizer_kwargs={},
tokenizer=mock_tokenizer,
)
mock_renderer._tokenizer = mock_tokenizer
# Force the Mistral chat template renderer to return token IDs.
# Choose a prompt length that is < max_model_len, but large enough that
# adding max_tokens should exceed the model context window.
......@@ -798,9 +798,8 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer(
MockVllmConfig(mock_engine.model_config),
tokenizer_kwargs={},
tokenizer=mock_tokenizer,
)
mock_renderer._tokenizer = mock_tokenizer
# prompt_token_ids length == max_model_len should be rejected for
# completion-like requests (ChatCompletionRequest).
mock_renderer.render_messages_async = AsyncMock(
......
......@@ -38,6 +38,7 @@ class MockModelConfig:
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
@dataclass
......@@ -78,15 +79,16 @@ def _build_renderer(
renderer = HfRenderer(
MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
tokenizer=(
None
if model_config.skip_tokenizer_init
else DummyTokenizer(
truncation_side=truncation_side,
max_chars_per_token=max_chars_per_token,
)
),
)
if not model_config.skip_tokenizer_init:
renderer._tokenizer = DummyTokenizer(
truncation_side=truncation_side,
max_chars_per_token=max_chars_per_token,
)
return renderer
......@@ -277,7 +279,7 @@ class TestRenderPrompt:
)
# Should not even attempt tokenization
assert renderer._tokenizer._captured_encode_kwargs == {}
assert renderer.tokenizer._captured_encode_kwargs == {}
def test_text_max_length_exceeded_nonobvious(self):
renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)
......@@ -298,8 +300,8 @@ class TestRenderPrompt:
)
# Should only tokenize the first max_total_tokens + 1 tokens
assert renderer._tokenizer._captured_encode_kwargs["truncation"] is True
assert renderer._tokenizer._captured_encode_kwargs["max_length"] == 101
assert renderer.tokenizer._captured_encode_kwargs["truncation"] is True
assert renderer.tokenizer._captured_encode_kwargs["max_length"] == 101
def test_token_max_length_exceeded(self):
renderer = _build_renderer(MockModelConfig())
......
......@@ -36,6 +36,7 @@ class MockModelConfig:
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
@dataclass
......@@ -57,9 +58,8 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
mock_tokenizer.apply_chat_template = mocked_apply_chat_template
mock_renderer = MistralRenderer(
MockVllmConfig(mock_model_config),
tokenizer_kwargs={},
tokenizer=mock_tokenizer,
)
mock_renderer._tokenizer = mock_tokenizer
task = mock_renderer.render_messages_async([], ChatParams())
......
......@@ -19,7 +19,7 @@ import pytest
import pytest_asyncio
from vllm import SamplingParams
from vllm.inputs import StreamingInput
from vllm.engine.protocol import StreamingInput
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
......
......@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.inputs import StreamingInput
from vllm.engine.protocol import StreamingInput
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM
......
......@@ -18,7 +18,7 @@ import dataclasses
import json
import time
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING, Any
import numpy as np
......@@ -28,9 +28,6 @@ from vllm.benchmarks.datasets import (
)
from vllm.benchmarks.throughput import get_requests
from vllm.engine.arg_utils import EngineArgs
from vllm.multimodal.processing.context import (
get_timing_stats_from_engine_client,
)
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.import_utils import PlaceholderModule
......@@ -39,16 +36,103 @@ try:
except ImportError:
pd = PlaceholderModule("pandas")
if TYPE_CHECKING:  # Avoid having to mock during docs build
    from vllm.v1.engine.llm_engine import LLMEngine
else:
    LLMEngine = object


def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, float]]:
    """
    Get all multimodal timing stats from the LLM engine.

    Collects both preprocessing stats (HF processor, hashing, cache lookup,
    prompt update) and encoder forward pass timing, merged by request_id.

    Args:
        llm_engine: The LLM engine. This function reads its
            ``vllm_config.observability_config``, asks its ``renderer`` for
            the multimodal processor, and calls ``collective_rpc`` to fetch
            encoder timing from the workers.

    Returns:
        Dictionary mapping request_id to merged stats dict containing
        both preprocessing and encoder timing metrics. Empty when the
        observability config is missing or ``enable_mm_processor_stats``
        is falsy.

    Example:
        {
            'request-123': {
                'hf_processor_time': 0.45,
                'hashing_time': 0.02,
                'cache_lookup_time': 0.01,
                'prompt_update_time': 0.03,
                'preprocessor_total_time': 0.51,
                'encoder_forward_time': 0.23,
                'num_encoder_calls': 1
            }
        }
    """
    observability_config = llm_engine.vllm_config.observability_config
    if not observability_config or not observability_config.enable_mm_processor_stats:
        return {}

    renderer = llm_engine.renderer
    mm_processor = renderer.get_mm_processor()
    preprocessing_stats = mm_processor.info.ctx.get_all_timing_stats()

    # Metrics reconciled across workers, with the default used when a worker
    # did not report the metric.
    aggregated_metrics = (
        ("encoder_forward_time", 0.0),
        ("num_encoder_calls", 0),
    )

    encoder_stats = dict[str, dict[str, float]]()
    for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"):
        if not worker_stats:
            continue

        for request_id, stats_dict in worker_stats.items():
            seen = encoder_stats.get(request_id)
            if seen is None:
                # First report for this request: copy so the worker's dict
                # is never mutated by later aggregation.
                encoder_stats[request_id] = dict(stats_dict)
                continue

            # Aggregate timing metrics across workers: keep the largest
            # value reported for each metric (values are not summed).
            for metric, default in aggregated_metrics:
                seen[metric] = max(
                    seen.get(metric, default),
                    stats_dict.get(metric, default),
                )

    merged_stats = dict[str, dict[str, float]]()
    for request_id, prep_dict in preprocessing_stats.items():
        merged_stats[request_id] = dict(prep_dict)

    for request_id, enc_dict in encoder_stats.items():
        if request_id in merged_stats:
            merged_stats[request_id].update(enc_dict)
            continue

        # In V1 engine, the request_id in encoder_stats has a suffix
        # appended to the original request_id (which is used in
        # preprocessing_stats).
        # We try to strip the suffix to find the matching request.
        possible_original_id = request_id.rpartition("-")[0]
        if possible_original_id and possible_original_id in merged_stats:
            merged_stats[possible_original_id].update(enc_dict)
        else:
            # No preprocessing entry matches: keep the encoder stats under
            # the encoder-side request id.
            merged_stats[request_id] = dict(enc_dict)

    return merged_stats
def collect_mm_processor_stats(
llm_engine: Any,
llm_engine: LLMEngine,
num_warmup_reqs: int = 0,
) -> dict[str, list[float]]:
"""
Collect multimodal processor timing stats.
Returns a dictionary mapping stage names to lists of timing values (in seconds).
"""
all_stats = get_timing_stats_from_engine_client(llm_engine)
all_stats = get_timing_stats_from_engine(llm_engine)
stat_keys = [
"hf_processor_time",
......
......@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from vllm.config import ModelConfig, VllmConfig
......@@ -10,7 +11,7 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.inputs.data import PromptType, StreamingInput
from vllm.inputs.data import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
......@@ -26,6 +27,18 @@ if TYPE_CHECKING:
from vllm.v1.engine import PauseMode
@dataclass
class StreamingInput:
    """Input data for a streaming generation request.

    This is used with generate() to support multi-turn streaming sessions
    where inputs are provided via an async generator.
    """

    # Prompt carried by this streaming input.
    prompt: PromptType
    # Optional sampling parameters; defaults to None when not supplied.
    # NOTE(review): how generate() treats None is not visible here — confirm.
    sampling_params: SamplingParams | None = None
class EngineClient(ABC):
"""Protocol class for Clients to Engine"""
......
......@@ -72,7 +72,7 @@ from vllm.outputs import (
)
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers import ChatParams, merge_kwargs
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import (
conversation_to_seq,
......@@ -384,7 +384,7 @@ class LLM:
return parallel_config.world_size
def reset_mm_cache(self) -> None:
self.input_processor.clear_mm_cache()
self.renderer.clear_mm_cache()
self.llm_engine.reset_mm_cache()
def get_default_sampling_params(self) -> SamplingParams:
......@@ -876,19 +876,6 @@ class LLM:
return outputs
def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs)
def _preprocess_cmpl(
self,
prompts: Sequence[PromptType],
......@@ -910,20 +897,12 @@ class LLM:
parsed_prompts = [
parse_model_prompt(model_config, prompt) for prompt in prompts
]
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
return renderer.render_cmpl(parsed_prompts, tok_params)
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=False,
).with_kwargs(tokenization_kwargs)
def _preprocess_chat(
self,
conversations: Sequence[list[ChatCompletionMessageParam]],
......@@ -961,7 +940,9 @@ class LLM:
),
),
)
tok_params = self._get_chat_tok_params(tokenization_kwargs)
tok_params = renderer.default_chat_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
_, engine_prompts = renderer.render_chat(
conversations,
......@@ -1653,7 +1634,10 @@ class LLM:
architecture=architecture,
)
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
renderer = self.renderer
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
encode_kwargs = tok_params.get_encode_kwargs()
if model_config.is_cross_encoder:
......@@ -1970,7 +1954,10 @@ class LLM:
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
)
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
renderer = self.renderer
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
......
......@@ -8,11 +8,11 @@ from typing import Literal, cast
import numpy as np
from vllm.engine.protocol import EngineClient
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType, StreamingInput
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime
......
......@@ -12,7 +12,6 @@ from .data import (
PromptType,
SingletonInputs,
SingletonPrompt,
StreamingInput,
TextPrompt,
TokenInputs,
TokensPrompt,
......@@ -36,5 +35,4 @@ __all__ = [
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"StreamingInput",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import torch
from typing_extensions import NotRequired, TypedDict
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
from vllm.multimodal.inputs import (
MultiModalDataDict,
......@@ -299,15 +296,3 @@ which can be passed to
SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
"""The inputs for a single encoder/decoder prompt."""
@dataclass
class StreamingInput:
"""Input data for a streaming generation request.
This is used with generate() to support multi-turn streaming sessions
where inputs are provided via an async generator.
"""
prompt: PromptType
sampling_params: SamplingParams | None = None
......@@ -9,13 +9,11 @@ from typing_extensions import assert_never
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalInputs,
MultiModalUUIDDict,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import (
DecoderDictPrompt,
......@@ -28,8 +26,6 @@ from vllm.renderers.inputs import (
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
from .data import (
DecoderInputs,
......@@ -57,17 +53,12 @@ class InputPreprocessor:
vllm_config: VllmConfig,
renderer: BaseRenderer | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None:
super().__init__()
self.model_config = vllm_config.model_config
self.observability_config = vllm_config.observability_config
self.renderer = renderer or renderer_from_config(vllm_config)
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
@property
def tokenizer(self) -> TokenizerLike | None:
......@@ -124,23 +115,6 @@ class InputPreprocessor:
return decoder_input_ids
def _get_tokenization_kw(
self,
overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
kwargs = dict[str, Any]()
if self.model_config.is_encoder_decoder:
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
kwargs["add_special_tokens"] = False
if overrides:
kwargs.update(overrides)
return kwargs
def _tokenize_prompt(
self,
prompt: str,
......@@ -150,26 +124,18 @@ class InputPreprocessor:
Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs.
"""
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
encoder_config = self.model_config.encoder_config
renderer = self.renderer
if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower()
return tokenizer.encode(prompt, **tokenization_kwargs)
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
def _get_mm_processor(self) -> BaseMultiModalProcessor:
if not hasattr(self, "_mm_processor"):
self._mm_processor = self.mm_registry.create_processor(
self.model_config,
self.observability_config,
tokenizer=self.tokenizer,
cache=self.mm_processor_cache,
)
tok_prompt = renderer.tokenize_prompt(
TextPrompt(prompt=prompt),
tok_params,
)
return self._mm_processor
return tok_prompt["prompt_token_ids"]
def _process_multimodal(
self,
......@@ -184,33 +150,20 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
mm_processor = self._get_mm_processor()
mm_processor = self.renderer.get_mm_processor()
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
mm_items = mm_processor.info.parse_mm_data(mm_data)
mm_input = mm_processor.apply(
return mm_processor.apply(
prompt,
mm_items,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
mm_hashes = mm_input["mm_hashes"]
# Validate that all mm items have a string as their hash
contains_only_strings = all(
isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes)
)
if not contains_only_strings:
raise ValueError(
f"mm_hashes must contain only strings, got: {mm_hashes}. "
"This is likely due to an incorrect custom implementation of "
"MultiModalProcessor.apply method."
)
return mm_input
def _process_embeds(
self,
......@@ -245,19 +198,18 @@ class InputPreprocessor:
def _truncate_inputs(
self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
) -> list[int]:
if (
not tokenization_kwargs
or "truncation" not in tokenization_kwargs
or self.tokenizer is None
):
return inputs
renderer = self.renderer
max_length = tokenization_kwargs["max_length"]
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
if self.tokenizer.truncation_side == "left":
return inputs[-max_length:]
else:
return inputs[:max_length]
tok_prompt = renderer.tokenize_prompt(
TokensPrompt(prompt_token_ids=inputs),
tok_params,
)
return tok_prompt["prompt_token_ids"]
def _process_tokens(
self,
......@@ -539,26 +491,6 @@ class InputPreprocessor:
"""Preprocess the input prompt."""
res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)
if self.mm_processor_cache and self.mm_cache_stats is not None:
delta = self.mm_processor_cache.make_stats(delta=True)
self.mm_cache_stats.requests += 1
self.mm_cache_stats.queries += delta.total
self.mm_cache_stats.hits += delta.hits
self.renderer.update_mm_cache_stats()
return res
def stat_mm_cache(self) -> MultiModalCacheStats | None:
mm_cache_stats = self.mm_cache_stats
if mm_cache_stats is None:
return None
self.mm_cache_stats = MultiModalCacheStats()
return mm_cache_stats
def clear_mm_cache(self) -> None:
if self.mm_processor_cache is not None:
self.mm_processor_cache.clear_cache()
if self.mm_cache_stats is not None:
self.mm_cache_stats.reset = True
......@@ -208,14 +208,23 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
if prompt and mm_items:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"Image-only inputs means passing an image with an empty text "
"prompt."
)
if mm_items:
if isinstance(prompt, str):
if len(prompt) > 0:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt."
)
else:
special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt):
prompt = []
else:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty token prompt."
)
# For multi-modal data, the prompt after processing should
# only contain the dummy image tokens
tokenization_kwargs = {
......
......@@ -42,6 +42,7 @@ from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdateDetails,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......@@ -90,6 +91,9 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
def get_image_processor(self, **kwargs: object) -> Lfm2VlImageProcessorFast:
return self.get_hf_processor(**kwargs).image_processor
def get_default_tok_params(self) -> TokenizeParams:
    """Return the parent class's defaults with ``add_special_tokens=False``."""
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
......
......@@ -66,6 +66,7 @@ from vllm.multimodal.processing import (
PromptUpdate,
PromptUpdateDetails,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......@@ -554,6 +555,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
)
def get_default_tok_params(self) -> TokenizeParams:
    """Return the parent class's defaults with ``add_special_tokens=False``."""
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
# Although vLLM can support more images from an infra capability
# perspective, we do not recommend using >10 images in practice.
......
......@@ -76,6 +76,7 @@ from vllm.multimodal.processing.processor import (
PromptUpdateDetails,
_seq2tokens,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.transformers_utils.configs.radio import RadioConfig
......@@ -1093,6 +1094,9 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
) -> BaseNanoNemotronVLProcessor:
raise NotImplementedError
def get_default_tok_params(self) -> TokenizeParams:
    """Return the parent class's defaults with ``add_special_tokens=False``."""
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
......
......@@ -58,6 +58,7 @@ from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
)
from vllm.renderers import TokenizeParams
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......@@ -608,6 +609,9 @@ class NemotronParseProcessingInfo(BaseProcessingInfo):
**kwargs,
)
def get_default_tok_params(self) -> TokenizeParams:
    """Return the parent class's defaults with ``add_special_tokens=False``."""
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
@property
def skip_prompt_length_check(self) -> bool:
return True # Because the encoder prompt is padded
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment