Unverified Commit 574fe752 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Renderer] Move InputPreprocessor into Renderer (2/2) (#34560)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent c61a98f5
...@@ -195,18 +195,15 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test): ...@@ -195,18 +195,15 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
valid_msg = [{"role": "user", "content": "Hello"}] valid_msg = [{"role": "user", "content": "Hello"}]
long_text = "This is a very long text to test the error " * 50 long_text = "This is a very long text to test the error " * 50
invalid_msg = [{"role": "user", "content": long_text}] invalid_msg = [{"role": "user", "content": long_text}]
batch_1 = [
valid_msg, batch_1 = [valid_msg, valid_msg, invalid_msg]
valid_msg, batch_2 = [valid_msg, valid_msg]
invalid_msg,
]
batch_2 = [
valid_msg,
valid_msg,
]
sampling_params = SamplingParams(temperature=0, max_tokens=10) sampling_params = SamplingParams(temperature=0, max_tokens=10)
with pytest.raises(ValueError, match="context length is only"): with pytest.raises(ValueError, match="context length is only"):
llm.chat(batch_1, sampling_params=sampling_params) llm.chat(batch_1, sampling_params=sampling_params)
assert llm.llm_engine.get_num_unfinished_requests() == 0
outputs_2 = llm.chat(batch_2, sampling_params=sampling_params) outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
assert len(outputs_2) == len(batch_2) assert len(outputs_2) == len(batch_2)
assert llm.llm_engine.get_num_unfinished_requests() == 0 assert llm.llm_engine.get_num_unfinished_requests() == 0
...@@ -489,8 +489,9 @@ def _assert_inputs_equal( ...@@ -489,8 +489,9 @@ def _assert_inputs_equal(
if ignore_mm_keys is None: if ignore_mm_keys is None:
ignore_mm_keys = set() ignore_mm_keys = set()
a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"} ignore_prompt_keys = ("prompt", "mm_kwargs")
b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"} a_rest = {k: v for k, v in a.items() if k not in ignore_prompt_keys}
b_rest = {k: v for k, v in b.items() if k not in ignore_prompt_keys}
assert a_rest == b_rest, msg assert a_rest == b_rest, msg
......
...@@ -6,18 +6,17 @@ import pytest ...@@ -6,18 +6,17 @@ import pytest
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.multimodal import MultiModalUUIDDict from vllm.renderers.hf import HfRenderer
from vllm.sampling_params import SamplingParams from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.input_processor import InputProcessor
cherry_pil_image = ImageAsset("cherry_blossom").pil_image cherry_pil_image = ImageAsset("cherry_blossom").pil_image
stop_pil_image = ImageAsset("stop_sign").pil_image stop_pil_image = ImageAsset("stop_sign").pil_image
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
def _build_input_processor( def _build_renderer(
*, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
) -> InputProcessor: ) -> HfRenderer:
model_config = ModelConfig( model_config = ModelConfig(
model="Qwen/Qwen2.5-VL-3B-Instruct", model="Qwen/Qwen2.5-VL-3B-Instruct",
max_model_len=128, max_model_len=128,
...@@ -29,47 +28,45 @@ def _build_input_processor( ...@@ -29,47 +28,45 @@ def _build_input_processor(
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching), cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
) )
return InputProcessor(vllm_config) _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer.from_config(
vllm_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
def test_multi_modal_uuids_length_mismatch_raises(): def test_multi_modal_uuids_length_mismatch_raises():
input_processor = _build_input_processor() renderer = _build_renderer()
prompt = { mm_data = {"image": [cherry_pil_image, stop_pil_image]}
"prompt": "USER: <image>\nDescribe\nASSISTANT:",
"multi_modal_data": {"image": [cherry_pil_image, stop_pil_image]}, # Mismatch: 2 items but only 1 uuid provided
# Mismatch: 2 items but only 1 uuid provided mm_uuids = {"image": ["hash_cherry"]}
"multi_modal_uuids": {"image": ["hash_cherry"]},
} mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
with pytest.raises(ValueError, match="must have same length as"): with pytest.raises(ValueError, match="must have same length as"):
input_processor.process_inputs( renderer._process_mm_uuids(mm_data, mm_items, mm_uuids, "req-1")
request_id="req-1",
prompt=prompt, # type: ignore[arg-type]
params=SamplingParams(),
)
def test_multi_modal_uuids_missing_modality_raises(): def test_multi_modal_uuids_missing_modality_raises():
input_processor = _build_input_processor() renderer = _build_renderer()
prompt = { mm_data = {
"prompt": "USER: <image><video>\nDescribe\nASSISTANT:", "image": [cherry_pil_image],
# Two modalities provided in data "video": None,
"multi_modal_data": {
"image": [cherry_pil_image],
"video": None,
},
# Only image uuids provided; video missing should raise
"multi_modal_uuids": {"image": ["hash_cherry"]},
} }
# Only image uuids provided; video missing should raise
mm_uuids = {"image": ["hash_cherry"]}
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
with pytest.raises(ValueError, match="is empty but .* is missing"): with pytest.raises(ValueError, match="is empty but .* is missing"):
input_processor.process_inputs( renderer._process_mm_uuids(mm_data, mm_items, mm_uuids, "req-2")
request_id="req-2",
prompt=prompt, # type: ignore[arg-type]
params=SamplingParams(),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -83,92 +80,86 @@ def test_multi_modal_uuids_missing_modality_raises(): ...@@ -83,92 +80,86 @@ def test_multi_modal_uuids_missing_modality_raises():
def test_multi_modal_uuids_accepts_none_and_passes_through( def test_multi_modal_uuids_accepts_none_and_passes_through(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
): ):
input_processor = _build_input_processor( renderer = _build_renderer(
mm_cache_gb=mm_cache_gb, mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching, enable_prefix_caching=enable_prefix_caching,
) )
# Capture the overrides passed to InputPreprocessor.preprocess mm_data = {
captured: dict[str, object] = {} "image": [cherry_pil_image, stop_pil_image],
"video": baby_reading_np_ndarrays,
def fake_preprocess( }
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
):
captured["mm_uuids"] = mm_uuids
# Minimal processed inputs for decoder-only flow
return {"type": "token", "prompt_token_ids": [1]}
# Monkeypatch only the bound preprocess method on this instance
monkeypatch.setattr(
input_processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
)
# Use a consistent two-image scenario across all configurations # Use a consistent two-image scenario across all configurations
mm_uuids = {"image": [None, "hash_stop"], "video": None} mm_uuids = {"image": [None, "hash_stop"], "video": None}
prompt = {
"prompt": "USER: <image><image>\nTwo images\nASSISTANT:",
"multi_modal_data": {
"image": [cherry_pil_image, stop_pil_image],
"video": baby_reading_np_ndarrays,
},
"multi_modal_uuids": mm_uuids,
}
input_processor.process_inputs( mm_processor = renderer.get_mm_processor()
request_id="req-3", mm_items = mm_processor.info.parse_mm_data(mm_data)
prompt=prompt, # type: ignore[arg-type] processed_mm_uuids = renderer._process_mm_uuids(
params=SamplingParams(), mm_data, mm_items, mm_uuids, "req-3"
) )
assert captured["mm_uuids"] == mm_uuids assert processed_mm_uuids == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): @pytest.mark.parametrize(
# When both processor cache is 0 and prefix caching disabled, the "mm_cache_gb, enable_prefix_caching",
# processor builds overrides from request id instead of using user UUIDs. [
input_processor = _build_input_processor( (4.0, True), # default behavior
mm_cache_gb=0.0, enable_prefix_caching=False (4.0, False), # prefix caching disabled
(0.0, True), # processor cache disabled
],
)
def test_multi_modal_uuids_accepts_empty(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
):
renderer = _build_renderer(
mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching,
) )
captured: dict[str, MultiModalUUIDDict] = {} # While None means cached multi-modal input requiring UUIDs
# an empty list means no multi-modal input
mm_data = {"image": [], "video": []} # type: ignore[var-annotated]
mm_uuids = {"image": [], "video": None} # type: ignore[var-annotated]
def fake_preprocess( mm_processor = renderer.get_mm_processor()
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None mm_items = mm_processor.info.parse_mm_data(mm_data)
): processed_mm_uuids = renderer._process_mm_uuids(
captured["mm_uuids"] = mm_uuids mm_data, mm_items, mm_uuids, "req-4"
return {"type": "token", "prompt_token_ids": [1]}
monkeypatch.setattr(
input_processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
) )
assert processed_mm_uuids == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
# When both processor cache is 0 and prefix caching disabled, the
# processor builds overrides from request id instead of using user UUIDs.
renderer = _build_renderer(mm_cache_gb=0.0, enable_prefix_caching=False)
request_id = "req-42" request_id = "req-42"
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]} mm_data = {
prompt = { "image": [cherry_pil_image, stop_pil_image],
"prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:", "video": baby_reading_np_ndarrays,
"multi_modal_data": {
"image": [cherry_pil_image, stop_pil_image],
"video": [baby_reading_np_ndarrays],
},
"multi_modal_uuids": mm_uuids,
} }
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}
input_processor.process_inputs( mm_processor = renderer.get_mm_processor()
request_id=request_id, mm_items = mm_processor.info.parse_mm_data(mm_data)
prompt=prompt, # type: ignore[arg-type] processed_mm_uuids = renderer._process_mm_uuids(
params=SamplingParams(), mm_data, mm_items, mm_uuids, request_id
) )
# Expect request-id-based overrides are passed through # Expect request-id-based overrides are passed through
assert set(mm_uuids.keys()) == {"image", "video"} assert set(mm_uuids.keys()) == {"image", "video"}
assert len(mm_uuids["image"]) == 2 assert len(mm_uuids["image"]) == 2
assert len(mm_uuids["video"]) == 1 assert len(mm_uuids["video"]) == 1
assert captured["mm_uuids"]["image"][0].startswith( assert processed_mm_uuids["image"][0].startswith(
f"{request_id}-image-" f"{request_id}-image-"
) and captured["mm_uuids"]["image"][0].endswith("-0") ) and processed_mm_uuids["image"][0].endswith("-0")
assert captured["mm_uuids"]["image"][1].startswith( assert processed_mm_uuids["image"][1].startswith(
f"{request_id}-image-" f"{request_id}-image-"
) and captured["mm_uuids"]["image"][1].endswith("-1") ) and processed_mm_uuids["image"][1].endswith("-1")
assert captured["mm_uuids"]["video"][0].startswith( assert processed_mm_uuids["video"][0].startswith(
f"{request_id}-video-" f"{request_id}-video-"
) and captured["mm_uuids"]["video"][0].endswith("-0") ) and processed_mm_uuids["video"][0].endswith("-0")
...@@ -20,7 +20,6 @@ MM_BEAM_WIDTHS = [2] ...@@ -20,7 +20,6 @@ MM_BEAM_WIDTHS = [2]
MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"] MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
@pytest.mark.skip_v1 # V1 engine does not yet support beam search
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS) @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
...@@ -62,7 +61,6 @@ def test_beam_search_single_input( ...@@ -62,7 +61,6 @@ def test_beam_search_single_input(
) )
@pytest.mark.skip_v1 # V1 engine does not yet support beam search
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS) @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
......
...@@ -2,13 +2,11 @@ ...@@ -2,13 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from vllm.inputs import TokenInputs, token_inputs
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict
@dataclass @dataclass
...@@ -19,6 +17,8 @@ class BeamSearchSequence: ...@@ -19,6 +17,8 @@ class BeamSearchSequence:
about to be returned to the user. about to be returned to the user.
""" """
orig_prompt: TokenInputs | MultiModalInputs
# The tokens include the prompt. # The tokens include the prompt.
tokens: list[int] tokens: list[int]
logprobs: list[dict[int, Logprob]] logprobs: list[dict[int, Logprob]]
...@@ -27,8 +27,28 @@ class BeamSearchSequence: ...@@ -27,8 +27,28 @@ class BeamSearchSequence:
text: str | None = None text: str | None = None
finish_reason: str | None = None finish_reason: str | None = None
stop_reason: int | str | None = None stop_reason: int | str | None = None
multi_modal_data: "MultiModalDataDict | None" = None
mm_processor_kwargs: dict[str, Any] | None = None def get_prompt(self):
prompt = self.orig_prompt
prompt_text = prompt.get("prompt")
cache_salt = prompt.get("cache_salt")
if prompt["type"] == "token":
return token_inputs(
self.tokens,
prompt=prompt_text,
cache_salt=cache_salt,
)
return mm_inputs(
prompt_token_ids=self.tokens,
mm_kwargs=prompt["mm_kwargs"],
mm_hashes=prompt["mm_hashes"],
mm_placeholders=prompt["mm_placeholders"],
prompt=prompt_text,
cache_salt=cache_salt,
)
@dataclass @dataclass
...@@ -44,14 +64,15 @@ class BeamSearchOutput: ...@@ -44,14 +64,15 @@ class BeamSearchOutput:
class BeamSearchInstance: class BeamSearchInstance:
def __init__( def __init__(
self, self,
prompt_tokens: list[int], prompt: TokenInputs | MultiModalInputs,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
logprobs: list[dict[int, Logprob]] | None = None, logprobs: list[dict[int, Logprob]] | None = None,
**kwargs, **kwargs,
): ):
self.beams: list[BeamSearchSequence] = [ self.beams: list[BeamSearchSequence] = [
BeamSearchSequence( BeamSearchSequence(
tokens=prompt_tokens, orig_prompt=prompt,
tokens=prompt["prompt_token_ids"],
logprobs=[] if logprobs is None else list(logprobs), logprobs=[] if logprobs is None else list(logprobs),
lora_request=lora_request, lora_request=lora_request,
**kwargs, **kwargs,
......
...@@ -11,13 +11,12 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -11,13 +11,12 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest, WeightTransferInitRequest,
WeightTransferUpdateRequest, WeightTransferUpdateRequest,
) )
from vllm.inputs.data import PromptType from vllm.inputs.data import ProcessorInputs, PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
...@@ -35,7 +34,7 @@ class StreamingInput: ...@@ -35,7 +34,7 @@ class StreamingInput:
where inputs are provided via an async generator. where inputs are provided via an async generator.
""" """
prompt: PromptType prompt: ProcessorInputs
sampling_params: SamplingParams | None = None sampling_params: SamplingParams | None = None
...@@ -69,8 +68,7 @@ class EngineClient(ABC): ...@@ -69,8 +68,7 @@ class EngineClient(ABC):
self, self,
prompt: EngineCoreRequest prompt: EngineCoreRequest
| PromptType | PromptType
| DictPrompt | ProcessorInputs
| TokPrompt
| AsyncGenerator[StreamingInput, None], | AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
...@@ -81,6 +79,7 @@ class EngineClient(ABC): ...@@ -81,6 +79,7 @@ class EngineClient(ABC):
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: int | None = None, data_parallel_rank: int | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.""" """Generate outputs for a request."""
... ...
...@@ -88,13 +87,14 @@ class EngineClient(ABC): ...@@ -88,13 +87,14 @@ class EngineClient(ABC):
@abstractmethod @abstractmethod
def encode( def encode(
self, self,
prompt: PromptType | DictPrompt | TokPrompt, prompt: PromptType | ProcessorInputs,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.""" """Generate outputs for a request from a pooling model."""
... ...
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import itertools import itertools
import warnings import warnings
from collections.abc import Callable, Sequence from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
import cloudpickle import cloudpickle
import torch.nn as nn import torch.nn as nn
...@@ -55,6 +55,7 @@ from vllm.entrypoints.pooling.score.utils import ( ...@@ -55,6 +55,7 @@ from vllm.entrypoints.pooling.score.utils import (
from vllm.entrypoints.utils import log_non_default_args from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs.data import ( from vllm.inputs.data import (
DataPrompt, DataPrompt,
ProcessorInputs,
PromptType, PromptType,
SingletonPrompt, SingletonPrompt,
TextPrompt, TextPrompt,
...@@ -73,10 +74,8 @@ from vllm.outputs import ( ...@@ -73,10 +74,8 @@ from vllm.outputs import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, merge_kwargs from vllm.renderers import ChatParams, merge_kwargs
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import ( from vllm.renderers.inputs.preprocess import (
conversation_to_seq, conversation_to_seq,
extract_prompt_components,
parse_model_prompt, parse_model_prompt,
prompt_to_seq, prompt_to_seq,
) )
...@@ -86,6 +85,7 @@ from vllm.tokenizers import TokenizerLike ...@@ -86,6 +85,7 @@ from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter from vllm.utils.counter import Counter
from vllm.utils.tqdm_utils import maybe_tqdm
from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
...@@ -400,7 +400,7 @@ class LLM: ...@@ -400,7 +400,7 @@ class LLM:
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None, sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
*, *,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None, priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
...@@ -462,7 +462,7 @@ class LLM: ...@@ -462,7 +462,7 @@ class LLM:
self, self,
prompts: PromptType | Sequence[PromptType], prompts: PromptType | Sequence[PromptType],
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None, sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None, priority: list[int] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
...@@ -495,34 +495,32 @@ class LLM: ...@@ -495,34 +495,32 @@ class LLM:
# Use the same preprocessing as _run_completion # Use the same preprocessing as _run_completion
seq_prompts = prompt_to_seq(prompts) seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(sampling_params, len(seq_prompts)) seq_params = self._params_to_seq(sampling_params, len(seq_prompts))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
if any(param.truncate_prompt_tokens is not None for param in seq_params): seq_tok_kwargs = [
engine_prompts: Sequence[DictPrompt | TokPrompt] = [ merge_kwargs(
engine_prompt tokenization_kwargs,
for prompt, param in zip(seq_prompts, seq_params) dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
for engine_prompt in self._preprocess_cmpl( )
[prompt], for param in seq_params
tokenization_kwargs=merge_kwargs( ]
tokenization_kwargs, seq_priority = self._priority_to_seq(priority, len(prompts))
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
request_ids = self._render_and_add_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
), ),
seq_tok_kwargs,
) )
]
else:
engine_prompts = self._preprocess_cmpl(
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
)
request_ids = self._validate_and_add_requests(
prompts=engine_prompts,
params=seq_params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
), ),
params=seq_params,
lora_requests=seq_lora_requests,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=priority, priorities=seq_priority,
) )
return request_ids return request_ids
...@@ -545,53 +543,41 @@ class LLM: ...@@ -545,53 +543,41 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput) return self.engine_class.validate_outputs(outputs, RequestOutput)
def _get_modality_specific_lora_reqs( def _resolve_lora_reqs(
self, self,
prompts: Sequence[DictPrompt | TokPrompt], prompts: Sequence[ProcessorInputs],
lora_request: list[LoRARequest] | LoRARequest | None, lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
): ):
# Grab the lora config off the vllm config on the engine,
# since this is the same for both v0 & v1.
lora_config = self.llm_engine.vllm_config.lora_config lora_config = self.llm_engine.vllm_config.lora_config
seq_lora_requests = self._lora_request_to_seq(lora_request, len(prompts))
# If there's no lora config / default_mm_loras, or the model
# isn't multimodal, leave the lora as is.
if ( if (
lora_config is None lora_config is None
or not self.model_config.is_multimodal_model or not self.model_config.is_multimodal_model
or (lora_config and lora_config.default_mm_loras is None) or (lora_config and lora_config.default_mm_loras is None)
): ):
return lora_request return seq_lora_requests
optional_loras = (
[lora_request] * len(prompts)
if not isinstance(lora_request, Sequence)
else lora_request
)
return [ return [
self._resolve_single_prompt_mm_lora( self._resolve_single_prompt_mm_lora(
prompt, prompt,
opt_lora_req, lora_req,
lora_config.default_mm_loras, lora_config.default_mm_loras,
) )
for prompt, opt_lora_req in zip(prompts, optional_loras) for prompt, lora_req in zip(prompts, seq_lora_requests)
] ]
def _resolve_single_prompt_mm_lora( def _resolve_single_prompt_mm_lora(
self, self,
prompt: DictPrompt | TokPrompt, prompt: ProcessorInputs,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
default_mm_loras: dict[str, str] | None, default_mm_loras: dict[str, str] | None,
): ):
if not default_mm_loras or not ( if not default_mm_loras or prompt["type"] != "multimodal":
mm_data := prompt.get("multi_modal_data") or {}
):
return lora_request return lora_request
intersection = set( prompt_modalities = prompt["mm_placeholders"].keys()
mm_data.keys() # type: ignore intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
).intersection(default_mm_loras.keys())
if not intersection: if not intersection:
return lora_request return lora_request
if len(intersection) > 1: if len(intersection) > 1:
...@@ -674,22 +660,6 @@ class LLM: ...@@ -674,22 +660,6 @@ class LLM:
""" """
return self.llm_engine.apply_model(func) return self.llm_engine.apply_model(func)
def _get_beam_search_lora_requests(
self,
lora_request: list[LoRARequest] | LoRARequest | None,
prompts: list[TokensPrompt | TextPrompt],
) -> list[LoRARequest | None]:
"""Get the optional lora request corresponding to each prompt."""
if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts):
raise ValueError(
"Lora request list should be the same length as the prompts"
)
if lora_request is None or isinstance(lora_request, LoRARequest):
return [lora_request] * len(prompts)
raise TypeError(f"Invalid lora_request type {type(lora_request)}")
def beam_search( def beam_search(
self, self,
prompts: list[TokensPrompt | TextPrompt], prompts: list[TokensPrompt | TextPrompt],
...@@ -718,13 +688,12 @@ class LLM: ...@@ -718,13 +688,12 @@ class LLM:
ignore_eos = params.ignore_eos ignore_eos = params.ignore_eos
length_penalty = params.length_penalty length_penalty = params.length_penalty
lora_requests = self._get_beam_search_lora_requests(lora_request, prompts) tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
tokenizer = self.get_tokenizer() engine_prompts = self._preprocess_cmpl(prompts)
sort_beams_key = create_sort_beams_key_function( lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))
tokenizer.eos_token_id,
length_penalty,
)
if use_tqdm and concurrency_limit is not None: if use_tqdm and concurrency_limit is not None:
logger.warning( logger.warning(
...@@ -734,21 +703,12 @@ class LLM: ...@@ -734,21 +703,12 @@ class LLM:
use_tqdm = False use_tqdm = False
if concurrency_limit is None: if concurrency_limit is None:
concurrency_limit = len(prompts) concurrency_limit = len(engine_prompts)
def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
if beam.multi_modal_data is not None:
token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
if beam.mm_processor_kwargs is not None:
token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
return TokensPrompt(**token_prompt_kwargs)
# generate 2 * beam_width candidates at each step # generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation # following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params = SamplingParams( sampling_params = SamplingParams(
logprobs=2 * beam_width, logprobs=2 * beam_width,
max_tokens=1, max_tokens=1,
temperature=temperature, temperature=temperature,
...@@ -756,30 +716,25 @@ class LLM: ...@@ -756,30 +716,25 @@ class LLM:
) )
instances: list[BeamSearchInstance] = [] instances: list[BeamSearchInstance] = []
for lora_req, prompt in zip(lora_requests, prompts): for lora_req, prompt in zip(lora_requests, engine_prompts):
# Add multimodal processor kwargs & data if prompt["type"] == "embeds":
mm_kwargs = {} raise NotImplementedError(
if "multi_modal_data" in prompt: "Embedding prompt not supported for beam search"
mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"] )
if "mm_processor_kwargs" in prompt: if prompt["type"] == "enc_dec":
mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"] raise NotImplementedError(
"Encoder-decoder prompt not supported for beam search"
if "prompt_token_ids" in prompt: )
prompt = cast(TokensPrompt, prompt) # Needed for mypy
prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append( instances.append(
BeamSearchInstance( BeamSearchInstance(
prompt_tokens, prompt,
lora_request=lora_req, lora_request=lora_req,
logprobs=None, logprobs=None,
**mm_kwargs,
), ),
) )
for prompt_start in range(0, len(prompts), concurrency_limit): for prompt_start in range(0, len(instances), concurrency_limit):
instances_batch = instances[prompt_start : prompt_start + concurrency_limit] instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
token_iter = range(max_tokens) token_iter = range(max_tokens)
...@@ -808,22 +763,15 @@ class LLM: ...@@ -808,22 +763,15 @@ class LLM:
if len(all_beams) == 0: if len(all_beams) == 0:
break break
# create corresponding batch entries for prompt & optional lora
prompts_batch, lora_req_batch = zip(
*[
(create_tokens_prompt_from_beam(beam), beam.lora_request)
for beam in all_beams
]
)
# only runs for one step # only runs for one step
# we don't need to use tqdm here # we don't need to use tqdm here
output = self.generate( raw_output = self._render_and_run_requests(
prompts_batch, prompts=(beam.get_prompt() for beam in all_beams),
sampling_params=beam_search_params, params=self._params_to_seq(sampling_params, len(all_beams)),
lora_requests=[beam.lora_request for beam in all_beams],
use_tqdm=False, use_tqdm=False,
lora_request=lora_req_batch,
) )
output = self.engine_class.validate_outputs(raw_output, RequestOutput)
for (start, end), instance in zip( for (start, end), instance in zip(
instance_start_and_end, instances_batch instance_start_and_end, instances_batch
...@@ -841,19 +789,15 @@ class LLM: ...@@ -841,19 +789,15 @@ class LLM:
logprobs = result.outputs[0].logprobs[0] logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items(): for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence( new_beam = BeamSearchSequence(
current_beam.orig_prompt,
tokens=current_beam.tokens + [token_id], tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs], logprobs=current_beam.logprobs + [logprobs],
lora_request=current_beam.lora_request, lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob cum_logprob=current_beam.cum_logprob
+ logprob_obj.logprob, + logprob_obj.logprob,
multi_modal_data=current_beam.multi_modal_data,
mm_processor_kwargs=current_beam.mm_processor_kwargs,
) )
if ( if token_id == eos_token_id and not ignore_eos:
token_id == tokenizer.eos_token_id
and not ignore_eos
):
instance.completed.append(new_beam) instance.completed.append(new_beam)
else: else:
instance_new_beams.append(new_beam) instance_new_beams.append(new_beam)
...@@ -872,6 +816,7 @@ class LLM: ...@@ -872,6 +816,7 @@ class LLM:
for beam in best_beams: for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens) beam.text = tokenizer.decode(beam.tokens)
outputs.append(BeamSearchOutput(sequences=best_beams)) outputs.append(BeamSearchOutput(sequences=best_beams))
return outputs return outputs
...@@ -880,7 +825,7 @@ class LLM: ...@@ -880,7 +825,7 @@ class LLM:
self, self,
prompts: Sequence[PromptType], prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[DictPrompt | TokPrompt]: ) -> Sequence[ProcessorInputs]:
""" """
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`. a format that can be passed to `_add_request`.
...@@ -888,8 +833,7 @@ class LLM: ...@@ -888,8 +833,7 @@ class LLM:
Refer to [LLM.generate][] for a complete description of the arguments. Refer to [LLM.generate][] for a complete description of the arguments.
Returns: Returns:
A list of `TokPrompt` objects containing the tokenized prompt A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
after chat template interpolation, and the raw multi-modal inputs.
""" """
renderer = self.renderer renderer = self.renderer
model_config = self.model_config model_config = self.model_config
...@@ -903,6 +847,14 @@ class LLM: ...@@ -903,6 +847,14 @@ class LLM:
return renderer.render_cmpl(parsed_prompts, tok_params) return renderer.render_cmpl(parsed_prompts, tok_params)
def _preprocess_cmpl_one(
self,
prompt: PromptType,
tokenization_kwargs: dict[str, Any] | None = None,
) -> ProcessorInputs:
(engine_prompt,) = self._preprocess_cmpl([prompt], tokenization_kwargs)
return engine_prompt
def _preprocess_chat( def _preprocess_chat(
self, self,
conversations: Sequence[list[ChatCompletionMessageParam]], conversations: Sequence[list[ChatCompletionMessageParam]],
...@@ -914,7 +866,7 @@ class LLM: ...@@ -914,7 +866,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None,
) -> Sequence[TokPrompt]: ) -> Sequence[ProcessorInputs]:
""" """
Convert a list of conversations into prompts so that they can then Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs. be used as input for other LLM APIs.
...@@ -922,8 +874,7 @@ class LLM: ...@@ -922,8 +874,7 @@ class LLM:
Refer to [LLM.chat][] for a complete description of the arguments. Refer to [LLM.chat][] for a complete description of the arguments.
Returns: Returns:
A list of `TokPrompt` objects containing the tokenized prompt A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
after chat template interpolation, and the raw multi-modal inputs.
""" """
renderer = self.renderer renderer = self.renderer
...@@ -953,13 +904,39 @@ class LLM: ...@@ -953,13 +904,39 @@ class LLM:
return engine_prompts return engine_prompts
def _preprocess_chat_one(
self,
conversation: list[ChatCompletionMessageParam],
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
chat_template_kwargs: dict[str, Any] | None = None,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> ProcessorInputs:
(engine_prompt,) = self._preprocess_chat(
[conversation],
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
return engine_prompt
def chat( def chat(
self, self,
messages: list[ChatCompletionMessageParam] messages: list[ChatCompletionMessageParam]
| Sequence[list[ChatCompletionMessageParam]], | Sequence[list[ChatCompletionMessageParam]],
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None, sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None, lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
chat_template: str | None = None, chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto", chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
...@@ -1805,47 +1782,41 @@ class LLM: ...@@ -1805,47 +1782,41 @@ class LLM:
| Sequence[SamplingParams | PoolingParams], | Sequence[SamplingParams | PoolingParams],
*, *,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None, priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
): ):
seq_prompts = prompt_to_seq(prompts) seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts)) seq_params = self._params_to_seq(params, len(seq_prompts))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
if any(param.truncate_prompt_tokens is not None for param in seq_params): seq_tok_kwargs = [
# TODO: Remove this after deprecating `param.truncate_prompt_tokens` merge_kwargs(
# Then, move the code from the `else` block to the top and let tokenization_kwargs,
# `self._preprocess_cmpl` handle prompt normalization dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
engine_prompts: Sequence[DictPrompt | TokPrompt] = [ )
engine_prompt for param in seq_params
for prompt, param in zip(seq_prompts, seq_params) ]
for engine_prompt in self._preprocess_cmpl( seq_priority = self._priority_to_seq(priority, len(prompts))
[prompt],
tokenization_kwargs=merge_kwargs( return self._render_and_run_requests(
tokenization_kwargs, prompts=(
dict(truncate_prompt_tokens=param.truncate_prompt_tokens), self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
), ),
seq_tok_kwargs,
) )
] ),
else:
engine_prompts = self._preprocess_cmpl(
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
)
self._validate_and_add_requests(
prompts=engine_prompts,
params=seq_params, params=seq_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs( lora_requests=seq_lora_requests,
engine_prompts, lora_request
),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=priority, priorities=seq_priority,
) )
return self._run_engine(use_tqdm=use_tqdm)
def _run_chat( def _run_chat(
self, self,
messages: list[ChatCompletionMessageParam] messages: list[ChatCompletionMessageParam]
...@@ -1855,7 +1826,7 @@ class LLM: ...@@ -1855,7 +1826,7 @@ class LLM:
| Sequence[SamplingParams | PoolingParams], | Sequence[SamplingParams | PoolingParams],
*, *,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None, lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
chat_template: str | None = None, chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto", chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
...@@ -1865,68 +1836,94 @@ class LLM: ...@@ -1865,68 +1836,94 @@ class LLM:
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None,
): ):
engine_prompts = self._preprocess_chat( seq_convs = conversation_to_seq(messages)
conversation_to_seq(messages), seq_params = self._params_to_seq(params, len(seq_convs))
chat_template=chat_template, seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_convs))
chat_template_content_format=chat_template_content_format, seq_tok_kwargs = [
chat_template_kwargs=chat_template_kwargs, merge_kwargs(
add_generation_prompt=add_generation_prompt, tokenization_kwargs,
continue_final_message=continue_final_message, dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
tools=tools, )
for param in seq_params
]
return self._render_and_run_requests(
prompts=(
self._preprocess_chat_one(
conversation,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tok_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
for conversation, tok_kwargs in zip(
maybe_tqdm(
seq_convs,
use_tqdm=use_tqdm,
desc="Rendering conversations",
),
seq_tok_kwargs,
)
),
params=seq_params,
lora_requests=seq_lora_requests,
use_tqdm=use_tqdm,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
) )
self._validate_and_add_requests( def _render_and_run_requests(
prompts=engine_prompts, self,
prompts: Iterable[ProcessorInputs],
params: Sequence[SamplingParams | PoolingParams],
*,
lora_requests: Sequence[LoRARequest | None] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
priorities: Sequence[int] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
):
if isinstance(prompts, (list, tuple)):
logger.warning_once(
"Rendering all prompts before adding them to the engine "
"is less efficient than performing both on the same prompt "
"before processing the next prompt. You should instead pass "
"a generator that renders one prompt per iteration, as that allows "
"engine execution to begin for the first prompt while processing "
"the next prompt."
)
self._render_and_add_requests(
prompts=prompts,
params=params, params=params,
use_tqdm=use_tqdm, lora_requests=lora_requests,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priorities=priorities,
) )
return self._run_engine(use_tqdm=use_tqdm) return self._run_engine(use_tqdm=use_tqdm)
def _validate_and_add_requests( def _render_and_add_requests(
self, self,
prompts: Sequence[DictPrompt | TokPrompt], prompts: Iterable[ProcessorInputs],
params: SamplingParams params: Sequence[SamplingParams | PoolingParams],
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
*, *,
use_tqdm: bool | Callable[..., tqdm] = True, lora_requests: Sequence[LoRARequest | None] | None = None,
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None, priorities: Sequence[int] | None = None,
) -> list[str]: ) -> list[str]:
num_requests = len(prompts)
seq_params = self._params_to_seq(params, num_requests)
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
seq_priority = self._priority_to_seq(priority, num_requests)
for sp in seq_params:
if isinstance(sp, SamplingParams):
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
it = prompts
if use_tqdm:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
added_request_ids: list[str] = [] added_request_ids: list[str] = []
try: try:
for i, prompt in enumerate(it): for i, prompt in enumerate(prompts):
request_id = self._add_request( request_id = self._add_request(
prompt, prompt,
seq_params[i], params[i],
lora_request=seq_lora_requests[i], lora_request=None if lora_requests is None else lora_requests[i],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=seq_priority[i], priority=0 if priorities is None else priorities[i],
) )
added_request_ids.append(request_id) added_request_ids.append(request_id)
except Exception as e: except Exception as e:
...@@ -1938,13 +1935,16 @@ class LLM: ...@@ -1938,13 +1935,16 @@ class LLM:
def _add_request( def _add_request(
self, self,
prompt: PromptType | DictPrompt | TokPrompt, prompt: ProcessorInputs,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
priority: int = 0, priority: int = 0,
) -> str: ) -> str:
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt) if isinstance(params, SamplingParams):
# We only care about the final output
params.output_kind = RequestOutputKind.FINAL_ONLY
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
if params.truncate_prompt_tokens is not None: if params.truncate_prompt_tokens is not None:
...@@ -1962,32 +1962,14 @@ class LLM: ...@@ -1962,32 +1962,14 @@ class LLM:
dict(truncate_prompt_tokens=params.truncate_prompt_tokens), dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
) )
renderer = self.renderer return self.llm_engine.add_request(
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id, request_id,
prompt, prompt,
params, params,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=priority, priority=priority,
supported_tasks=self.supported_tasks,
)
self.llm_engine.add_request(
request_id,
engine_request,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
prompt_text=prompt_text,
) )
return engine_request.request_id
def _run_engine( def _run_engine(
self, self,
......
...@@ -67,13 +67,12 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -67,13 +67,12 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
) )
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import ProcessorInputs, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import ( from vllm.tokenizers.mistral import (
...@@ -221,7 +220,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -221,7 +220,7 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request( async def render_chat_request(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse: ) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
""" """
render chat request by validating and preprocessing inputs. render chat request by validating and preprocessing inputs.
...@@ -380,7 +379,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -380,7 +379,9 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text = self._extract_prompt_text(engine_prompt) prompt_token_ids = self._extract_prompt_components(
engine_prompt
).token_ids
# If we are creating sub requests for multiple prompts, ensure that they # If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids. # have unique request ids.
...@@ -431,35 +432,21 @@ class OpenAIServingChat(OpenAIServing): ...@@ -431,35 +432,21 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers, trace_headers=trace_headers,
) )
else: else:
tok_params = request.build_tok_params(self.model_config) reasoning_ended = (
tokenization_kwargs = tok_params.get_encode_kwargs() reasoning_parser.is_reasoning_end(prompt_token_ids or [])
if reasoning_parser
engine_request = self.input_processor.process_inputs( else None
sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
) )
reasoning_ended = None
if reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end(
engine_request.prompt_token_ids or [] # type: ignore[attr-defined]
)
engine_request.reasoning_ended = reasoning_ended
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_request, engine_prompt,
sampling_params, sampling_params,
sub_request_id, sub_request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
reasoning_ended=reasoning_ended,
) )
generators.append(generator) generators.append(generator)
......
...@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import ( ...@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
...@@ -80,7 +80,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -80,7 +80,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def render_completion_request( async def render_completion_request(
self, self,
request: CompletionRequest, request: CompletionRequest,
) -> list[TokPrompt] | ErrorResponse: ) -> list[ProcessorInputs] | ErrorResponse:
""" """
render completion request by validating and preprocessing inputs. render completion request by validating and preprocessing inputs.
...@@ -163,8 +163,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -163,8 +163,6 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text = self._extract_prompt_text(engine_prompt)
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len, max_model_len,
request.max_tokens, request.max_tokens,
...@@ -208,29 +206,13 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -208,29 +206,13 @@ class OpenAIServingCompletion(OpenAIServing):
trace_headers=trace_headers, trace_headers=trace_headers,
) )
else: else:
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_request, engine_prompt,
sampling_params, sampling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
) )
...@@ -312,7 +294,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -312,7 +294,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator( async def completion_stream_generator(
self, self,
request: CompletionRequest, request: CompletionRequest,
engine_prompts: list[TokPrompt], engine_prompts: list[ProcessorInputs],
result_generator: AsyncIterator[tuple[int, RequestOutput]], result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str, request_id: str,
created_time: int, created_time: int,
......
...@@ -96,15 +96,19 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -96,15 +96,19 @@ from vllm.entrypoints.serve.tokenize.protocol import (
) )
from vllm.entrypoints.utils import get_max_tokens, sanitize_message from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt from vllm.inputs.data import (
ProcessorInputs,
PromptType,
SingletonPrompt,
TokensPrompt,
token_inputs,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import ( from vllm.renderers.inputs.preprocess import (
extract_prompt_components, extract_prompt_components,
extract_prompt_len, extract_prompt_len,
...@@ -206,7 +210,7 @@ class ServeContext(Generic[RequestT]): ...@@ -206,7 +210,7 @@ class ServeContext(Generic[RequestT]):
request_id: str request_id: str
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_prompts: list[TokPrompt] | None = None engine_prompts: list[ProcessorInputs] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None None
...@@ -249,7 +253,7 @@ class OpenAIServing: ...@@ -249,7 +253,7 @@ class OpenAIServing:
async def beam_search( async def beam_search(
self, self,
prompt: TokPrompt, prompt: ProcessorInputs,
request_id: str, request_id: str,
params: BeamSearchParams, params: BeamSearchParams,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
...@@ -262,86 +266,53 @@ class OpenAIServing: ...@@ -262,86 +266,53 @@ class OpenAIServing:
length_penalty = params.length_penalty length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output include_stop_str_in_output = params.include_stop_str_in_output
input_processor = self.input_processor tokenizer = self.renderer.get_tokenizer()
tokenizer = input_processor.tokenizer eos_token_id = tokenizer.eos_token_id
if tokenizer is None: sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
raise VLLMValidationError(
"You cannot use beam search when `skip_tokenizer_init=True`",
parameter="skip_tokenizer_init",
value=True,
)
eos_token_id: int = tokenizer.eos_token_id # type: ignore
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
raise NotImplementedError("Encoder-decoder prompt not supported")
prompt_text: str | None = prompt.get("prompt") # type: ignore
prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: dict[str, Any] | None = None
# This is a workaround to fix multimodal beam search; this is a if prompt["type"] == "embeds":
# bandaid fix for 2 small problems: raise NotImplementedError("Embedding prompt not supported for beam search")
# 1. Multi_modal_data on the processed_inputs currently resolves to if prompt["type"] == "enc_dec":
# `None`. raise NotImplementedError(
# 2. preprocessing above expands the multimodal placeholders. However, "Encoder-decoder prompt not supported for beam search"
# this happens again in generation, so the double expansion causes )
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
prompt_text = prompt.get("prompt")
prompt_token_ids = prompt["prompt_token_ids"]
tokenized_length = len(prompt_token_ids) tokenized_length = len(prompt_token_ids)
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
logprobs_num = 2 * beam_width logprobs_num = 2 * beam_width
beam_search_params = SamplingParams( sampling_params = SamplingParams(
logprobs=logprobs_num, logprobs=logprobs_num,
max_tokens=1, max_tokens=1,
temperature=temperature, temperature=temperature,
) )
all_beams = [ all_beams = [
BeamSearchSequence( BeamSearchSequence(
orig_prompt=prompt,
tokens=prompt_token_ids, tokens=prompt_token_ids,
cum_logprob=0, cum_logprob=0,
logprobs=[], logprobs=[],
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
lora_request=lora_request, lora_request=lora_request,
) )
] ]
completed = [] completed = []
for _ in range(max_tokens): for _ in range(max_tokens):
prompts_batch, lora_req_batch = zip(
*[
(
TokensPrompt(
prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs,
),
beam.lora_request,
)
for beam in all_beams
]
)
tasks = [] tasks = []
request_id_batch = f"{request_id}-{random_uuid()}" request_id_batch = f"{request_id}-{random_uuid()}"
for i, (individual_prompt, lora_req) in enumerate( for i, beam in enumerate(all_beams):
zip(prompts_batch, lora_req_batch) prompt_item = beam.get_prompt()
): lora_request_item = beam.lora_request
request_id_item = f"{request_id_batch}-beam-{i}" request_id_item = f"{request_id_batch}-beam-{i}"
task = asyncio.create_task( task = asyncio.create_task(
collect_from_async_generator( collect_from_async_generator(
self.engine_client.generate( self.engine_client.generate(
individual_prompt, prompt_item,
beam_search_params, sampling_params,
request_id_item, request_id_item,
lora_request=lora_req, lora_request=lora_request_item,
trace_headers=trace_headers, trace_headers=trace_headers,
) )
) )
...@@ -406,6 +377,7 @@ class OpenAIServing: ...@@ -406,6 +377,7 @@ class OpenAIServing:
logprobs_entry = result.outputs[0].logprobs[0] logprobs_entry = result.outputs[0].logprobs[0]
completed.append( completed.append(
BeamSearchSequence( BeamSearchSequence(
orig_prompt=prompt,
tokens=current_beam.tokens + [eos_token_id] tokens=current_beam.tokens + [eos_token_id]
if include_stop_str_in_output if include_stop_str_in_output
else current_beam.tokens, else current_beam.tokens,
...@@ -433,12 +405,11 @@ class OpenAIServing: ...@@ -433,12 +405,11 @@ class OpenAIServing:
logprobs_entry = result.outputs[0].logprobs[0] logprobs_entry = result.outputs[0].logprobs[0]
new_beams.append( new_beams.append(
BeamSearchSequence( BeamSearchSequence(
orig_prompt=prompt,
tokens=current_beam.tokens + [token_id], tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs_entry], logprobs=current_beam.logprobs + [logprobs_entry],
lora_request=current_beam.lora_request, lora_request=current_beam.lora_request,
cum_logprob=float(all_beams_logprob[idx]), cum_logprob=float(all_beams_logprob[idx]),
multi_modal_data=current_beam.multi_modal_data,
mm_processor_kwargs=current_beam.mm_processor_kwargs,
) )
) )
...@@ -958,7 +929,7 @@ class OpenAIServing: ...@@ -958,7 +929,7 @@ class OpenAIServing:
request: RendererRequest, request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None, prompt_embeds: bytes | list[bytes] | None,
) -> list[TokPrompt]: ) -> list[ProcessorInputs]:
prompts = list[SingletonPrompt | bytes]() prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds)) prompts.extend(prompt_to_seq(prompt_embeds))
...@@ -971,7 +942,7 @@ class OpenAIServing: ...@@ -971,7 +942,7 @@ class OpenAIServing:
self, self,
request: RendererRequest, request: RendererRequest,
prompts: Sequence[PromptType | bytes], prompts: Sequence[PromptType | bytes],
) -> list[TokPrompt]: ) -> list[ProcessorInputs]:
renderer = self.renderer renderer = self.renderer
model_config = self.model_config model_config = self.model_config
...@@ -1004,7 +975,7 @@ class OpenAIServing: ...@@ -1004,7 +975,7 @@ class OpenAIServing:
default_template_kwargs: dict[str, Any] | None, default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[TokPrompt]]: ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
renderer = self.renderer renderer = self.renderer
...@@ -1052,13 +1023,13 @@ class OpenAIServing: ...@@ -1052,13 +1023,13 @@ class OpenAIServing:
return conversation, [engine_prompt] return conversation, [engine_prompt]
def _extract_prompt_components(self, prompt: object): def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
return extract_prompt_components(self.model_config, prompt) return extract_prompt_components(self.model_config, prompt)
def _extract_prompt_text(self, prompt: object): def _extract_prompt_text(self, prompt: ProcessorInputs):
return self._extract_prompt_components(prompt).text return self._extract_prompt_components(prompt).text
def _extract_prompt_len(self, prompt: object): def _extract_prompt_len(self, prompt: ProcessorInputs):
return extract_prompt_len(self.model_config, prompt) return extract_prompt_len(self.model_config, prompt)
async def _render_next_turn( async def _render_next_turn(
...@@ -1088,16 +1059,14 @@ class OpenAIServing: ...@@ -1088,16 +1059,14 @@ class OpenAIServing:
async def _generate_with_builtin_tools( async def _generate_with_builtin_tools(
self, self,
request_id: str, request_id: str,
engine_prompt: TokPrompt, engine_prompt: ProcessorInputs,
sampling_params: SamplingParams, sampling_params: SamplingParams,
tok_params: TokenizeParams,
context: ConversationContext, context: ConversationContext,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
priority: int = 0, priority: int = 0,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
): ):
max_model_len = self.model_config.max_model_len max_model_len = self.model_config.max_model_len
prompt_text = self._extract_prompt_text(engine_prompt)
orig_priority = priority orig_priority = priority
sub_request = 0 sub_request = 0
...@@ -1112,26 +1081,13 @@ class OpenAIServing: ...@@ -1112,26 +1081,13 @@ class OpenAIServing:
lora_request=lora_request, lora_request=lora_request,
) )
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
)
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_request, engine_prompt,
sampling_params, sampling_params,
sub_request_id, sub_request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
) )
async for res in generator: async for res in generator:
...@@ -1154,11 +1110,11 @@ class OpenAIServing: ...@@ -1154,11 +1110,11 @@ class OpenAIServing:
# Render the next prompt token ids and update sampling_params. # Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)): if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion() token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=token_ids) engine_prompt = token_inputs(token_ids)
sampling_params.max_tokens = max_model_len - len(token_ids) sampling_params.max_tokens = max_model_len - len(token_ids)
elif isinstance(context, ParsableContext): elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn( (engine_prompt,) = await self._render_next_turn(
context.request, context.request,
context.parser.response_messages, context.parser.response_messages,
context.tool_dicts, context.tool_dicts,
...@@ -1166,8 +1122,6 @@ class OpenAIServing: ...@@ -1166,8 +1122,6 @@ class OpenAIServing:
context.chat_template, context.chat_template,
context.chat_template_content_format, context.chat_template_content_format,
) )
engine_prompt = engine_prompts[0]
prompt_text = self._extract_prompt_text(engine_prompt)
sampling_params.max_tokens = get_max_tokens( sampling_params.max_tokens = get_max_tokens(
max_model_len, max_model_len,
...@@ -1184,7 +1138,7 @@ class OpenAIServing: ...@@ -1184,7 +1138,7 @@ class OpenAIServing:
def _log_inputs( def _log_inputs(
self, self,
request_id: str, request_id: str,
inputs: PromptType | TokPrompt, inputs: PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams | BeamSearchParams | None, params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
) -> None: ) -> None:
......
...@@ -15,6 +15,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels ...@@ -15,6 +15,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime from vllm.model_executor.models.interfaces import SupportsRealtime
from vllm.renderers.inputs.preprocess import parse_model_prompt
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -70,15 +71,20 @@ class OpenAIServingRealtime(OpenAIServing): ...@@ -70,15 +71,20 @@ class OpenAIServingRealtime(OpenAIServing):
Yields: Yields:
StreamingInput objects containing audio prompts for the engine StreamingInput objects containing audio prompts for the engine
""" """
model_config = self.model_config
renderer = self.renderer
# mypy is being stupid # mypy is being stupid
# TODO(Patrick) - fix this # TODO(Patrick) - fix this
stream_input_iter = cast( stream_input_iter = cast(
AsyncGenerator[PromptType, None], AsyncGenerator[PromptType, None],
self.model_cls.buffer_realtime_audio( self.model_cls.buffer_realtime_audio(
audio_stream, input_stream, self.model_config audio_stream, input_stream, model_config
), ),
) )
async for prompt in stream_input_iter: async for prompt in stream_input_iter:
yield StreamingInput(prompt=prompt) parsed_prompt = parse_model_prompt(model_config, prompt)
(engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt])
yield StreamingInput(prompt=engine_prompt)
...@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod ...@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from dataclasses import replace from dataclasses import replace
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Final, Union
from openai.types.responses.response_function_tool_call_output_item import ( from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem, ResponseFunctionToolCallOutputItem,
...@@ -304,7 +304,7 @@ class ParsableContext(ConversationContext): ...@@ -304,7 +304,7 @@ class ParsableContext(ConversationContext):
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice) self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.input_messages: list[ResponseRawMessageAndToken] = [] self.input_messages: list[ResponseRawMessageAndToken] = []
self.output_messages: list[ResponseRawMessageAndToken] = [] self.output_messages: list[ResponseRawMessageAndToken] = []
......
...@@ -116,13 +116,12 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -116,13 +116,12 @@ from vllm.entrypoints.openai.responses.utils import (
) )
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import ProcessorInputs, token_inputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -298,7 +297,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -298,7 +297,7 @@ class OpenAIServingResponses(OpenAIServing):
def _validate_generator_input( def _validate_generator_input(
self, self,
engine_prompt: TokPrompt, engine_prompt: ProcessorInputs,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Add validations to the input to the generator here.""" """Add validations to the input to the generator here."""
prompt_len = self._extract_prompt_len(engine_prompt) prompt_len = self._extract_prompt_len(engine_prompt)
...@@ -458,7 +457,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -458,7 +457,6 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params default_max_tokens, self.default_sampling_params
) )
tok_params = request.build_tok_params(self.model_config)
trace_headers = ( trace_headers = (
None None
...@@ -512,7 +510,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -512,7 +510,6 @@ class OpenAIServingResponses(OpenAIServing):
request_id=request.request_id, request_id=request.request_id,
engine_prompt=engine_prompt, engine_prompt=engine_prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
tok_params=tok_params,
context=context, context=context,
lora_request=lora_request, lora_request=lora_request,
priority=request.priority, priority=request.priority,
...@@ -647,7 +644,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -647,7 +644,7 @@ class OpenAIServingResponses(OpenAIServing):
messages = self._construct_input_messages_with_harmony(request, prev_response) messages = self._construct_input_messages_with_harmony(request, prev_response)
prompt_token_ids = render_for_completion(messages) prompt_token_ids = render_for_completion(messages)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) engine_prompt = token_inputs(prompt_token_ids)
# Add cache_salt if provided in the request # Add cache_salt if provided in the request
if request.cache_salt is not None: if request.cache_salt is not None:
......
...@@ -36,14 +36,15 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( ...@@ -36,14 +36,15 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationSegment, TranslationSegment,
TranslationStreamResponse, TranslationStreamResponse,
) )
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.renderers.inputs import EncoderDecoderDictPrompt from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
...@@ -202,8 +203,6 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -202,8 +203,6 @@ class OpenAISpeechToText(OpenAIServing):
return return
try: try:
from vllm.sampling_params import SamplingParams
warmup_start = time.perf_counter() warmup_start = time.perf_counter()
logger.info("Warming up multimodal input processor...") logger.info("Warming up multimodal input processor...")
...@@ -221,21 +220,11 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -221,21 +220,11 @@ class OpenAISpeechToText(OpenAIServing):
request_prompt="", request_prompt="",
to_language=None, to_language=None,
) )
parsed_prompt = parse_model_prompt(self.model_config, dummy_prompt)
# Create minimal sampling params
dummy_params = SamplingParams(
max_tokens=1,
temperature=0.0,
skip_clone=True, # Internal warmup, safe to skip clone
)
# Process the dummy input through the input processor # Process the dummy input through the input processor
# This will trigger all the multimodal processing initialization # This will trigger all the multimodal processing initialization
_ = self.input_processor.process_inputs( _ = self.renderer.render_cmpl([parsed_prompt])
request_id="warmup",
prompt=dummy_prompt,
params=dummy_params,
)
warmup_elapsed = time.perf_counter() - warmup_start warmup_elapsed = time.perf_counter() - warmup_start
logger.info("Input processor warmup completed in %.2fs", warmup_elapsed) logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
...@@ -257,7 +246,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -257,7 +246,7 @@ class OpenAISpeechToText(OpenAIServing):
self, self,
request: SpeechToTextRequest, request: SpeechToTextRequest,
audio_data: bytes, audio_data: bytes,
) -> tuple[list[PromptType], float]: ) -> tuple[list[ProcessorInputs], float]:
# Validate request # Validate request
language = self.model_cls.validate_language(request.language) language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper. # Skip to_language validation to avoid extra logging for Whisper.
...@@ -285,7 +274,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -285,7 +274,7 @@ class OpenAISpeechToText(OpenAIServing):
and duration > self.asr_config.max_audio_clip_s and duration > self.asr_config.max_audio_clip_s
) )
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr)) chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
prompts = [] parsed_prompts: list[DictPrompt] = []
for chunk in chunks: for chunk in chunks:
# The model has control over the construction, as long as it # The model has control over the construction, as long as it
# returns a valid PromptType. # returns a valid PromptType.
...@@ -298,12 +287,19 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -298,12 +287,19 @@ class OpenAISpeechToText(OpenAIServing):
request_prompt=request.prompt, request_prompt=request.prompt,
to_language=to_language, to_language=to_language,
) )
parsed_prompt: DictPrompt
if request.response_format == "verbose_json": if request.response_format == "verbose_json":
prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt)) parsed_prompt = parse_enc_dec_prompt(prompt)
parsed_prompt = self._preprocess_verbose_prompt(parsed_prompt)
else:
parsed_prompt = parse_model_prompt(self.model_config, prompt)
parsed_prompts.append(parsed_prompt)
prompts.append(prompt) engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts)
return prompts, duration return engine_prompts, duration
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt): def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"] dec_prompt = prompt["decoder_prompt"]
...@@ -436,7 +432,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -436,7 +432,7 @@ class OpenAISpeechToText(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
prompts, duration_s = await self._preprocess_speech_to_text( engine_prompts, duration_s = await self._preprocess_speech_to_text(
request=request, request=request,
audio_data=audio_data, audio_data=audio_data,
) )
...@@ -445,57 +441,54 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -445,57 +441,54 @@ class OpenAISpeechToText(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e) return self.create_error_response(e)
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
try: try:
# Unlike most decoder-only models, whisper generation length is not # Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a # constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be # fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg. # generated by respecting the extra completion tokens arg.
if request.max_completion_tokens is None: max_tokens = get_max_tokens(
default_max_tokens = self.model_config.max_model_len max_model_len,
else: request.max_completion_tokens,
default_max_tokens = min( 0,
self.model_config.max_model_len, request.max_completion_tokens self.default_sampling_params,
) )
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params max_tokens,
self.default_sampling_params,
) )
if request.response_format == "verbose_json": if request.response_format == "verbose_json":
sampling_params.logprobs = 1 sampling_params.logprobs = 1
self._log_inputs(
request_id,
# It will not display special tokens like <|startoftranscript|>
request.prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
list_result_generator = [] list_result_generator = []
for i, prompt in enumerate(prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}_{i}" request_id_item = f"{request_id}_{i}"
engine_request = self.input_processor.process_inputs(
self._log_inputs(
request_id_item, request_id_item,
prompt, engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
generator = self.engine_client.generate(
engine_prompt,
sampling_params, sampling_params,
request_id_item,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=0,
)
list_result_generator.append(
self.engine_client.generate(
engine_request,
sampling_params,
request_id_item,
lora_request=lora_request,
)
) )
list_result_generator.append(generator)
except ValueError as e: except ValueError as e:
return self.create_error_response(e) return self.create_error_response(e)
......
...@@ -28,11 +28,10 @@ from vllm.entrypoints.pooling.utils import ( ...@@ -28,11 +28,10 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64, encode_pooling_output_base64,
encode_pooling_output_float, encode_pooling_output_float,
) )
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import TokPrompt
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import EmbedDType, Endianness from vllm.utils.serial_utils import EmbedDType, Endianness
...@@ -256,7 +255,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -256,7 +255,7 @@ class OpenAIServingEmbedding(OpenAIServing):
chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}" chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
# Create engine prompt for this chunk # Create engine prompt for this chunk
chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens) chunk_engine_prompt = token_inputs(chunk_tokens)
# Log the chunk # Log the chunk
self._log_inputs( self._log_inputs(
...@@ -266,16 +265,12 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -266,16 +265,12 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Create generator for this chunk and wrap it to return indices # Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode( original_generator = self.engine_client.encode(
chunk_engine_prompt, chunk_engine_prompt,
pooling_params, pooling_params,
chunk_request_id, chunk_request_id,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=ctx.request.priority, priority=ctx.request.priority,
) )
...@@ -362,7 +357,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -362,7 +357,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator( async def _create_single_prompt_generator(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
engine_prompt: TokPrompt, engine_prompt: ProcessorInputs,
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
...@@ -377,16 +372,12 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -377,16 +372,12 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Return the original generator without wrapping # Return the original generator without wrapping
return self.engine_client.encode( return self.engine_client.encode(
engine_prompt, engine_prompt,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=ctx.request.priority, priority=ctx.request.priority,
) )
......
...@@ -33,10 +33,9 @@ from vllm.entrypoints.pooling.utils import ( ...@@ -33,10 +33,9 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64, encode_pooling_output_base64,
encode_pooling_output_float, encode_pooling_output_float,
) )
from vllm.inputs import PromptType from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import prompt_to_seq from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
...@@ -93,7 +92,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -93,7 +92,7 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported" "dimensions is currently not supported"
) )
engine_prompts: Sequence[PromptType | TokPrompt] engine_prompts: Sequence[ProcessorInputs]
if use_io_processor := isinstance(request, IOProcessorRequest): if use_io_processor := isinstance(request, IOProcessorRequest):
if self.io_processor is None: if self.io_processor is None:
raise ValueError( raise ValueError(
...@@ -152,9 +151,6 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -152,9 +151,6 @@ class OpenAIServingPooling(OpenAIServing):
else: else:
pooling_params = request.to_pooling_params() # type: ignore pooling_params = request.to_pooling_params() # type: ignore
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
...@@ -176,7 +172,6 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -176,7 +172,6 @@ class OpenAIServingPooling(OpenAIServing):
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
) )
......
...@@ -35,7 +35,7 @@ from vllm.entrypoints.pooling.score.utils import ( ...@@ -35,7 +35,7 @@ from vllm.entrypoints.pooling.score.utils import (
get_score_prompt, get_score_prompt,
validate_score_input, validate_score_input,
) )
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
...@@ -108,12 +108,15 @@ class ServingScores(OpenAIServing): ...@@ -108,12 +108,15 @@ class ServingScores(OpenAIServing):
*(encode_async(t, **tokenization_kwargs) for t in input_texts) *(encode_async(t, **tokenization_kwargs) for t in input_texts)
) )
engine_prompts: list[TokensPrompt] = [] engine_prompts: list[ProcessorInputs] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts): for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(request, tok_result, input_text) text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append( engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"]) token_inputs(
text_token_prompt["prompt_token_ids"],
prompt=input_text,
)
) )
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
...@@ -125,7 +128,7 @@ class ServingScores(OpenAIServing): ...@@ -125,7 +128,7 @@ class ServingScores(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
input_texts[i], engine_prompt,
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -207,12 +210,15 @@ class ServingScores(OpenAIServing): ...@@ -207,12 +210,15 @@ class ServingScores(OpenAIServing):
*(encode_async(t, **tokenization_kwargs) for t in input_texts) *(encode_async(t, **tokenization_kwargs) for t in input_texts)
) )
engine_prompts: list[TokensPrompt] = [] engine_prompts: list[ProcessorInputs] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts): for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(request, tok_result, input_text) text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append( engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"]) token_inputs(
text_token_prompt["prompt_token_ids"],
prompt=input_text,
)
) )
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
...@@ -225,7 +231,7 @@ class ServingScores(OpenAIServing): ...@@ -225,7 +231,7 @@ class ServingScores(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
input_texts[i], engine_prompt,
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
) )
......
...@@ -29,7 +29,6 @@ from vllm.entrypoints.serve.disagg.protocol import ( ...@@ -29,7 +29,6 @@ from vllm.entrypoints.serve.disagg.protocol import (
GenerateResponse, GenerateResponse,
GenerateResponseChoice, GenerateResponseChoice,
) )
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -116,7 +115,7 @@ class ServingTokens(OpenAIServing): ...@@ -116,7 +115,7 @@ class ServingTokens(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id, request_id,
TokensPrompt(prompt_token_ids=request.token_ids), engine_prompt,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -127,27 +126,13 @@ class ServingTokens(OpenAIServing): ...@@ -127,27 +126,13 @@ class ServingTokens(OpenAIServing):
else await self._get_trace_headers(raw_request.headers) else await self._get_trace_headers(raw_request.headers)
) )
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
)
result_generator = self.engine_client.generate( result_generator = self.engine_client.generate(
engine_request, engine_prompt,
sampling_params, sampling_params,
request_id, request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
tokenization_kwargs=tokenization_kwargs,
) )
except ValueError as e: except ValueError as e:
......
...@@ -20,7 +20,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -20,7 +20,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeResponse, TokenizeResponse,
TokenizerInfoResponse, TokenizerInfoResponse,
) )
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt, token_inputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -135,7 +135,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -135,7 +135,7 @@ class OpenAIServingTokenization(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id, request_id,
TokensPrompt(prompt_token_ids=request.tokens), token_inputs(request.tokens),
params=None, params=None,
lora_request=lora_request, lora_request=lora_request,
) )
......
...@@ -187,6 +187,9 @@ class _InputOptions(TypedDict): ...@@ -187,6 +187,9 @@ class _InputOptions(TypedDict):
Additional options available to all input types. Additional options available to all input types.
""" """
arrival_time: NotRequired[float]
"""The time when the input was received (before rendering)."""
cache_salt: NotRequired[str] cache_salt: NotRequired[str]
"""Optional cache salt to be used for prefix caching.""" """Optional cache salt to be used for prefix caching."""
...@@ -300,6 +303,9 @@ class EncoderDecoderInputs(TypedDict): ...@@ -300,6 +303,9 @@ class EncoderDecoderInputs(TypedDict):
decoder_prompt: DecoderInputs decoder_prompt: DecoderInputs
"""The inputs for the decoder portion.""" """The inputs for the decoder portion."""
arrival_time: NotRequired[float]
"""The time when the input was received (before rendering)."""
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment