Unverified commit 574fe752, authored by Cyrus Leung, committed by GitHub
Browse files

[Renderer] Move InputPreprocessor into Renderer (2/2) (#34560)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent c61a98f5
......@@ -195,18 +195,15 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
valid_msg = [{"role": "user", "content": "Hello"}]
long_text = "This is a very long text to test the error " * 50
invalid_msg = [{"role": "user", "content": long_text}]
batch_1 = [
valid_msg,
valid_msg,
invalid_msg,
]
batch_2 = [
valid_msg,
valid_msg,
]
batch_1 = [valid_msg, valid_msg, invalid_msg]
batch_2 = [valid_msg, valid_msg]
sampling_params = SamplingParams(temperature=0, max_tokens=10)
with pytest.raises(ValueError, match="context length is only"):
llm.chat(batch_1, sampling_params=sampling_params)
assert llm.llm_engine.get_num_unfinished_requests() == 0
outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
assert len(outputs_2) == len(batch_2)
assert llm.llm_engine.get_num_unfinished_requests() == 0
......@@ -489,8 +489,9 @@ def _assert_inputs_equal(
if ignore_mm_keys is None:
ignore_mm_keys = set()
a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"}
b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"}
ignore_prompt_keys = ("prompt", "mm_kwargs")
a_rest = {k: v for k, v in a.items() if k not in ignore_prompt_keys}
b_rest = {k: v for k, v in b.items() if k not in ignore_prompt_keys}
assert a_rest == b_rest, msg
......
......@@ -6,18 +6,17 @@ import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.multimodal import MultiModalUUIDDict
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.input_processor import InputProcessor
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
stop_pil_image = ImageAsset("stop_sign").pil_image
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
def _build_input_processor(
def _build_renderer(
*, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
) -> InputProcessor:
) -> HfRenderer:
model_config = ModelConfig(
model="Qwen/Qwen2.5-VL-3B-Instruct",
max_model_len=128,
......@@ -29,47 +28,45 @@ def _build_input_processor(
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
)
return InputProcessor(vllm_config)
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer.from_config(
vllm_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
def test_multi_modal_uuids_length_mismatch_raises():
input_processor = _build_input_processor()
renderer = _build_renderer()
mm_data = {"image": [cherry_pil_image, stop_pil_image]}
prompt = {
"prompt": "USER: <image>\nDescribe\nASSISTANT:",
"multi_modal_data": {"image": [cherry_pil_image, stop_pil_image]},
# Mismatch: 2 items but only 1 uuid provided
"multi_modal_uuids": {"image": ["hash_cherry"]},
}
mm_uuids = {"image": ["hash_cherry"]}
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
with pytest.raises(ValueError, match="must have same length as"):
input_processor.process_inputs(
request_id="req-1",
prompt=prompt, # type: ignore[arg-type]
params=SamplingParams(),
)
renderer._process_mm_uuids(mm_data, mm_items, mm_uuids, "req-1")
def test_multi_modal_uuids_missing_modality_raises():
input_processor = _build_input_processor()
renderer = _build_renderer()
prompt = {
"prompt": "USER: <image><video>\nDescribe\nASSISTANT:",
# Two modalities provided in data
"multi_modal_data": {
mm_data = {
"image": [cherry_pil_image],
"video": None,
},
# Only image uuids provided; video missing should raise
"multi_modal_uuids": {"image": ["hash_cherry"]},
}
# Only image uuids provided; video missing should raise
mm_uuids = {"image": ["hash_cherry"]}
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
with pytest.raises(ValueError, match="is empty but .* is missing"):
input_processor.process_inputs(
request_id="req-2",
prompt=prompt, # type: ignore[arg-type]
params=SamplingParams(),
)
renderer._process_mm_uuids(mm_data, mm_items, mm_uuids, "req-2")
@pytest.mark.parametrize(
......@@ -83,92 +80,86 @@ def test_multi_modal_uuids_missing_modality_raises():
def test_multi_modal_uuids_accepts_none_and_passes_through(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
):
input_processor = _build_input_processor(
renderer = _build_renderer(
mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching,
)
# Capture the overrides passed to InputPreprocessor.preprocess
captured: dict[str, object] = {}
def fake_preprocess(
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
):
captured["mm_uuids"] = mm_uuids
# Minimal processed inputs for decoder-only flow
return {"type": "token", "prompt_token_ids": [1]}
# Monkeypatch only the bound preprocess method on this instance
monkeypatch.setattr(
input_processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
)
# Use a consistent two-image scenario across all configurations
mm_uuids = {"image": [None, "hash_stop"], "video": None}
prompt = {
"prompt": "USER: <image><image>\nTwo images\nASSISTANT:",
"multi_modal_data": {
mm_data = {
"image": [cherry_pil_image, stop_pil_image],
"video": baby_reading_np_ndarrays,
},
"multi_modal_uuids": mm_uuids,
}
input_processor.process_inputs(
request_id="req-3",
prompt=prompt, # type: ignore[arg-type]
params=SamplingParams(),
# Use a consistent two-image scenario across all configurations
mm_uuids = {"image": [None, "hash_stop"], "video": None}
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
processed_mm_uuids = renderer._process_mm_uuids(
mm_data, mm_items, mm_uuids, "req-3"
)
assert captured["mm_uuids"] == mm_uuids
assert processed_mm_uuids == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
# When both processor cache is 0 and prefix caching disabled, the
# processor builds overrides from request id instead of using user UUIDs.
input_processor = _build_input_processor(
mm_cache_gb=0.0, enable_prefix_caching=False
@pytest.mark.parametrize(
"mm_cache_gb, enable_prefix_caching",
[
(4.0, True), # default behavior
(4.0, False), # prefix caching disabled
(0.0, True), # processor cache disabled
],
)
def test_multi_modal_uuids_accepts_empty(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
):
renderer = _build_renderer(
mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching,
)
captured: dict[str, MultiModalUUIDDict] = {}
def fake_preprocess(
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
):
captured["mm_uuids"] = mm_uuids
return {"type": "token", "prompt_token_ids": [1]}
# While `None` means a cached multi-modal input that requires UUIDs,
# an empty list means there is no multi-modal input.
mm_data = {"image": [], "video": []} # type: ignore[var-annotated]
mm_uuids = {"image": [], "video": None} # type: ignore[var-annotated]
monkeypatch.setattr(
input_processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
processed_mm_uuids = renderer._process_mm_uuids(
mm_data, mm_items, mm_uuids, "req-4"
)
assert processed_mm_uuids == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
# When the processor cache size is 0 and prefix caching is disabled, the
# renderer builds overrides from the request ID instead of using user UUIDs.
renderer = _build_renderer(mm_cache_gb=0.0, enable_prefix_caching=False)
request_id = "req-42"
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}
prompt = {
"prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:",
"multi_modal_data": {
mm_data = {
"image": [cherry_pil_image, stop_pil_image],
"video": [baby_reading_np_ndarrays],
},
"multi_modal_uuids": mm_uuids,
"video": baby_reading_np_ndarrays,
}
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}
input_processor.process_inputs(
request_id=request_id,
prompt=prompt, # type: ignore[arg-type]
params=SamplingParams(),
mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
processed_mm_uuids = renderer._process_mm_uuids(
mm_data, mm_items, mm_uuids, request_id
)
# Expect that request-ID-based overrides are passed through
assert set(mm_uuids.keys()) == {"image", "video"}
assert len(mm_uuids["image"]) == 2
assert len(mm_uuids["video"]) == 1
assert captured["mm_uuids"]["image"][0].startswith(
assert processed_mm_uuids["image"][0].startswith(
f"{request_id}-image-"
) and captured["mm_uuids"]["image"][0].endswith("-0")
assert captured["mm_uuids"]["image"][1].startswith(
) and processed_mm_uuids["image"][0].endswith("-0")
assert processed_mm_uuids["image"][1].startswith(
f"{request_id}-image-"
) and captured["mm_uuids"]["image"][1].endswith("-1")
assert captured["mm_uuids"]["video"][0].startswith(
) and processed_mm_uuids["image"][1].endswith("-1")
assert processed_mm_uuids["video"][0].startswith(
f"{request_id}-video-"
) and captured["mm_uuids"]["video"][0].endswith("-0")
) and processed_mm_uuids["video"][0].endswith("-0")
......@@ -20,7 +20,6 @@ MM_BEAM_WIDTHS = [2]
MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
@pytest.mark.skip_v1 # V1 engine does not yet support beam search
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
......@@ -62,7 +61,6 @@ def test_beam_search_single_input(
)
@pytest.mark.skip_v1 # V1 engine does not yet support beam search
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
......
......@@ -2,13 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from vllm.inputs import TokenInputs, token_inputs
from vllm.logprobs import Logprob
from vllm.lora.request import LoRARequest
if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
@dataclass
......@@ -19,6 +17,8 @@ class BeamSearchSequence:
about to be returned to the user.
"""
orig_prompt: TokenInputs | MultiModalInputs
# The tokens include the prompt.
tokens: list[int]
logprobs: list[dict[int, Logprob]]
......@@ -27,8 +27,28 @@ class BeamSearchSequence:
text: str | None = None
finish_reason: str | None = None
stop_reason: int | str | None = None
multi_modal_data: "MultiModalDataDict | None" = None
mm_processor_kwargs: dict[str, Any] | None = None
def get_prompt(self):
    """Build the prompt for this beam from its original prompt.

    The beam's current ``tokens`` are used as the prompt token IDs, while
    the prompt text and cache salt are copied from ``orig_prompt``. For any
    prompt whose ``type`` is not ``"token"``, the ``mm_kwargs``,
    ``mm_hashes`` and ``mm_placeholders`` entries of ``orig_prompt`` are
    forwarded unchanged to ``mm_inputs``.
    """
    source = self.orig_prompt

    # Fields taken from the original prompt regardless of its type.
    # `.get` is used because both entries may be absent.
    carried_over = {
        "prompt": source.get("prompt"),
        "cache_salt": source.get("cache_salt"),
    }

    if source["type"] != "token":
        return mm_inputs(
            prompt_token_ids=self.tokens,
            mm_kwargs=source["mm_kwargs"],
            mm_hashes=source["mm_hashes"],
            mm_placeholders=source["mm_placeholders"],
            **carried_over,
        )

    return token_inputs(self.tokens, **carried_over)
@dataclass
......@@ -44,14 +64,15 @@ class BeamSearchOutput:
class BeamSearchInstance:
def __init__(
self,
prompt_tokens: list[int],
prompt: TokenInputs | MultiModalInputs,
lora_request: LoRARequest | None = None,
logprobs: list[dict[int, Logprob]] | None = None,
**kwargs,
):
self.beams: list[BeamSearchSequence] = [
BeamSearchSequence(
tokens=prompt_tokens,
orig_prompt=prompt,
tokens=prompt["prompt_token_ids"],
logprobs=[] if logprobs is None else list(logprobs),
lora_request=lora_request,
**kwargs,
......
......@@ -11,13 +11,12 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.inputs.data import PromptType
from vllm.inputs.data import ProcessorInputs, PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest
......@@ -35,7 +34,7 @@ class StreamingInput:
where inputs are provided via an async generator.
"""
prompt: PromptType
prompt: ProcessorInputs
sampling_params: SamplingParams | None = None
......@@ -69,8 +68,7 @@ class EngineClient(ABC):
self,
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| ProcessorInputs
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
......@@ -81,6 +79,7 @@ class EngineClient(ABC):
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
data_parallel_rank: int | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
...
......@@ -88,13 +87,14 @@ class EngineClient(ABC):
@abstractmethod
def encode(
self,
prompt: PromptType | DictPrompt | TokPrompt,
prompt: PromptType | ProcessorInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
tokenization_kwargs: dict[str, Any] | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model."""
...
......
This diff is collapsed.
......@@ -67,13 +67,12 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
)
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import ProcessorInputs, TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
......@@ -221,7 +220,7 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request(
self,
request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse:
) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
"""
render chat request by validating and preprocessing inputs.
......@@ -380,7 +379,9 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text = self._extract_prompt_text(engine_prompt)
prompt_token_ids = self._extract_prompt_components(
engine_prompt
).token_ids
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
......@@ -431,35 +432,21 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers,
)
else:
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
reasoning_ended = None
if reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end(
engine_request.prompt_token_ids or [] # type: ignore[attr-defined]
reasoning_ended = (
reasoning_parser.is_reasoning_end(prompt_token_ids or [])
if reasoning_parser
else None
)
engine_request.reasoning_ended = reasoning_ended
generator = self.engine_client.generate(
engine_request,
engine_prompt,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
reasoning_ended=reasoning_ended,
)
generators.append(generator)
......
......@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
......@@ -80,7 +80,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def render_completion_request(
self,
request: CompletionRequest,
) -> list[TokPrompt] | ErrorResponse:
) -> list[ProcessorInputs] | ErrorResponse:
"""
render completion request by validating and preprocessing inputs.
......@@ -163,8 +163,6 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text = self._extract_prompt_text(engine_prompt)
max_tokens = get_max_tokens(
max_model_len,
request.max_tokens,
......@@ -208,29 +206,13 @@ class OpenAIServingCompletion(OpenAIServing):
trace_headers=trace_headers,
)
else:
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generator = self.engine_client.generate(
engine_request,
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
......@@ -312,7 +294,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
engine_prompts: list[TokPrompt],
engine_prompts: list[ProcessorInputs],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
......
......@@ -96,15 +96,19 @@ from vllm.entrypoints.serve.tokenize.protocol import (
)
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt
from vllm.inputs.data import (
ProcessorInputs,
PromptType,
SingletonPrompt,
TokensPrompt,
token_inputs,
)
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import (
extract_prompt_components,
extract_prompt_len,
......@@ -206,7 +210,7 @@ class ServeContext(Generic[RequestT]):
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[TokPrompt] | None = None
engine_prompts: list[ProcessorInputs] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
......@@ -249,7 +253,7 @@ class OpenAIServing:
async def beam_search(
self,
prompt: TokPrompt,
prompt: ProcessorInputs,
request_id: str,
params: BeamSearchParams,
lora_request: LoRARequest | None = None,
......@@ -262,86 +266,53 @@ class OpenAIServing:
length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output
input_processor = self.input_processor
tokenizer = input_processor.tokenizer
if tokenizer is None:
raise VLLMValidationError(
"You cannot use beam search when `skip_tokenizer_init=True`",
parameter="skip_tokenizer_init",
value=True,
)
eos_token_id: int = tokenizer.eos_token_id # type: ignore
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
raise NotImplementedError("Encoder-decoder prompt not supported")
prompt_text: str | None = prompt.get("prompt") # type: ignore
prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: dict[str, Any] | None = None
tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
# This is a workaround to fix multimodal beam search; this is a
# bandaid fix for 2 small problems:
# 1. Multi_modal_data on the processed_inputs currently resolves to
# `None`.
# 2. preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
if prompt["type"] == "embeds":
raise NotImplementedError("Embedding prompt not supported for beam search")
if prompt["type"] == "enc_dec":
raise NotImplementedError(
"Encoder-decoder prompt not supported for beam search"
)
prompt_text = prompt.get("prompt")
prompt_token_ids = prompt["prompt_token_ids"]
tokenized_length = len(prompt_token_ids)
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
logprobs_num = 2 * beam_width
beam_search_params = SamplingParams(
sampling_params = SamplingParams(
logprobs=logprobs_num,
max_tokens=1,
temperature=temperature,
)
all_beams = [
BeamSearchSequence(
orig_prompt=prompt,
tokens=prompt_token_ids,
cum_logprob=0,
logprobs=[],
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
lora_request=lora_request,
)
]
completed = []
for _ in range(max_tokens):
prompts_batch, lora_req_batch = zip(
*[
(
TokensPrompt(
prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs,
),
beam.lora_request,
)
for beam in all_beams
]
)
tasks = []
request_id_batch = f"{request_id}-{random_uuid()}"
for i, (individual_prompt, lora_req) in enumerate(
zip(prompts_batch, lora_req_batch)
):
for i, beam in enumerate(all_beams):
prompt_item = beam.get_prompt()
lora_request_item = beam.lora_request
request_id_item = f"{request_id_batch}-beam-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.engine_client.generate(
individual_prompt,
beam_search_params,
prompt_item,
sampling_params,
request_id_item,
lora_request=lora_req,
lora_request=lora_request_item,
trace_headers=trace_headers,
)
)
......@@ -406,6 +377,7 @@ class OpenAIServing:
logprobs_entry = result.outputs[0].logprobs[0]
completed.append(
BeamSearchSequence(
orig_prompt=prompt,
tokens=current_beam.tokens + [eos_token_id]
if include_stop_str_in_output
else current_beam.tokens,
......@@ -433,12 +405,11 @@ class OpenAIServing:
logprobs_entry = result.outputs[0].logprobs[0]
new_beams.append(
BeamSearchSequence(
orig_prompt=prompt,
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs_entry],
lora_request=current_beam.lora_request,
cum_logprob=float(all_beams_logprob[idx]),
multi_modal_data=current_beam.multi_modal_data,
mm_processor_kwargs=current_beam.mm_processor_kwargs,
)
)
......@@ -958,7 +929,7 @@ class OpenAIServing:
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokPrompt]:
) -> list[ProcessorInputs]:
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
......@@ -971,7 +942,7 @@ class OpenAIServing:
self,
request: RendererRequest,
prompts: Sequence[PromptType | bytes],
) -> list[TokPrompt]:
) -> list[ProcessorInputs]:
renderer = self.renderer
model_config = self.model_config
......@@ -1004,7 +975,7 @@ class OpenAIServing:
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
from vllm.tokenizers.mistral import MistralTokenizer
renderer = self.renderer
......@@ -1052,13 +1023,13 @@ class OpenAIServing:
return conversation, [engine_prompt]
def _extract_prompt_components(self, prompt: object):
def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
return extract_prompt_components(self.model_config, prompt)
def _extract_prompt_text(self, prompt: object):
def _extract_prompt_text(self, prompt: ProcessorInputs):
return self._extract_prompt_components(prompt).text
def _extract_prompt_len(self, prompt: object):
def _extract_prompt_len(self, prompt: ProcessorInputs):
return extract_prompt_len(self.model_config, prompt)
async def _render_next_turn(
......@@ -1088,16 +1059,14 @@ class OpenAIServing:
async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: TokPrompt,
engine_prompt: ProcessorInputs,
sampling_params: SamplingParams,
tok_params: TokenizeParams,
context: ConversationContext,
lora_request: LoRARequest | None = None,
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
):
max_model_len = self.model_config.max_model_len
prompt_text = self._extract_prompt_text(engine_prompt)
orig_priority = priority
sub_request = 0
......@@ -1112,26 +1081,13 @@ class OpenAIServing:
lora_request=lora_request,
)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
)
generator = self.engine_client.generate(
engine_request,
engine_prompt,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
)
async for res in generator:
......@@ -1154,11 +1110,11 @@ class OpenAIServing:
# Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
engine_prompt = token_inputs(token_ids)
sampling_params.max_tokens = max_model_len - len(token_ids)
elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn(
(engine_prompt,) = await self._render_next_turn(
context.request,
context.parser.response_messages,
context.tool_dicts,
......@@ -1166,8 +1122,6 @@ class OpenAIServing:
context.chat_template,
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
prompt_text = self._extract_prompt_text(engine_prompt)
sampling_params.max_tokens = get_max_tokens(
max_model_len,
......@@ -1184,7 +1138,7 @@ class OpenAIServing:
def _log_inputs(
self,
request_id: str,
inputs: PromptType | TokPrompt,
inputs: PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:
......
......@@ -15,6 +15,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime
from vllm.renderers.inputs.preprocess import parse_model_prompt
logger = init_logger(__name__)
......@@ -70,15 +71,20 @@ class OpenAIServingRealtime(OpenAIServing):
Yields:
StreamingInput objects containing audio prompts for the engine
"""
model_config = self.model_config
renderer = self.renderer
# mypy is being stupid
# TODO(Patrick) - fix this
stream_input_iter = cast(
AsyncGenerator[PromptType, None],
self.model_cls.buffer_realtime_audio(
audio_stream, input_stream, self.model_config
audio_stream, input_stream, model_config
),
)
async for prompt in stream_input_iter:
yield StreamingInput(prompt=prompt)
parsed_prompt = parse_model_prompt(model_config, prompt)
(engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt])
yield StreamingInput(prompt=engine_prompt)
......@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import AsyncExitStack
from dataclasses import replace
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Final, Union
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
......@@ -304,7 +304,7 @@ class ParsableContext(ConversationContext):
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.chat_template_content_format: Final = chat_template_content_format
self.input_messages: list[ResponseRawMessageAndToken] = []
self.output_messages: list[ResponseRawMessageAndToken] = []
......
......@@ -116,13 +116,12 @@ from vllm.entrypoints.openai.responses.utils import (
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import ProcessorInputs, token_inputs
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
......@@ -298,7 +297,7 @@ class OpenAIServingResponses(OpenAIServing):
def _validate_generator_input(
self,
engine_prompt: TokPrompt,
engine_prompt: ProcessorInputs,
) -> ErrorResponse | None:
"""Add validations to the input to the generator here."""
prompt_len = self._extract_prompt_len(engine_prompt)
......@@ -458,7 +457,6 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
tok_params = request.build_tok_params(self.model_config)
trace_headers = (
None
......@@ -512,7 +510,6 @@ class OpenAIServingResponses(OpenAIServing):
request_id=request.request_id,
engine_prompt=engine_prompt,
sampling_params=sampling_params,
tok_params=tok_params,
context=context,
lora_request=lora_request,
priority=request.priority,
......@@ -647,7 +644,7 @@ class OpenAIServingResponses(OpenAIServing):
messages = self._construct_input_messages_with_harmony(request, prev_response)
prompt_token_ids = render_for_completion(messages)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
engine_prompt = token_inputs(prompt_token_ids)
# Add cache_salt if provided in the request
if request.cache_salt is not None:
......
......@@ -36,14 +36,15 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationSegment,
TranslationStreamResponse,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType
from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
......@@ -202,8 +203,6 @@ class OpenAISpeechToText(OpenAIServing):
return
try:
from vllm.sampling_params import SamplingParams
warmup_start = time.perf_counter()
logger.info("Warming up multimodal input processor...")
......@@ -221,21 +220,11 @@ class OpenAISpeechToText(OpenAIServing):
request_prompt="",
to_language=None,
)
# Create minimal sampling params
dummy_params = SamplingParams(
max_tokens=1,
temperature=0.0,
skip_clone=True, # Internal warmup, safe to skip clone
)
parsed_prompt = parse_model_prompt(self.model_config, dummy_prompt)
# Process the dummy input through the input processor
# This will trigger all the multimodal processing initialization
_ = self.input_processor.process_inputs(
request_id="warmup",
prompt=dummy_prompt,
params=dummy_params,
)
_ = self.renderer.render_cmpl([parsed_prompt])
warmup_elapsed = time.perf_counter() - warmup_start
logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
......@@ -257,7 +246,7 @@ class OpenAISpeechToText(OpenAIServing):
self,
request: SpeechToTextRequest,
audio_data: bytes,
) -> tuple[list[PromptType], float]:
) -> tuple[list[ProcessorInputs], float]:
# Validate request
language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper.
......@@ -285,7 +274,7 @@ class OpenAISpeechToText(OpenAIServing):
and duration > self.asr_config.max_audio_clip_s
)
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
prompts = []
parsed_prompts: list[DictPrompt] = []
for chunk in chunks:
# The model has control over the construction, as long as it
# returns a valid PromptType.
......@@ -298,12 +287,19 @@ class OpenAISpeechToText(OpenAIServing):
request_prompt=request.prompt,
to_language=to_language,
)
parsed_prompt: DictPrompt
if request.response_format == "verbose_json":
prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt))
parsed_prompt = parse_enc_dec_prompt(prompt)
parsed_prompt = self._preprocess_verbose_prompt(parsed_prompt)
else:
parsed_prompt = parse_model_prompt(self.model_config, prompt)
parsed_prompts.append(parsed_prompt)
prompts.append(prompt)
engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts)
return prompts, duration
return engine_prompts, duration
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"]
......@@ -436,7 +432,7 @@ class OpenAISpeechToText(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
prompts, duration_s = await self._preprocess_speech_to_text(
engine_prompts, duration_s = await self._preprocess_speech_to_text(
request=request,
audio_data=audio_data,
)
......@@ -445,28 +441,35 @@ class OpenAISpeechToText(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
try:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectrogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg.
if request.max_completion_tokens is None:
default_max_tokens = self.model_config.max_model_len
else:
default_max_tokens = min(
self.model_config.max_model_len, request.max_completion_tokens
max_tokens = get_max_tokens(
max_model_len,
request.max_completion_tokens,
0,
self.default_sampling_params,
)
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
max_tokens,
self.default_sampling_params,
)
if request.response_format == "verbose_json":
sampling_params.logprobs = 1
list_result_generator = []
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}_{i}"
self._log_inputs(
request_id,
# It will not display special tokens like <|startoftranscript|>
request.prompt,
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
......@@ -477,25 +480,15 @@ class OpenAISpeechToText(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
list_result_generator = []
for i, prompt in enumerate(prompts):
request_id_item = f"{request_id}_{i}"
engine_request = self.input_processor.process_inputs(
request_id_item,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=0,
)
list_result_generator.append(
self.engine_client.generate(
engine_request,
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
)
)
list_result_generator.append(generator)
except ValueError as e:
return self.create_error_response(e)
......
......@@ -28,11 +28,10 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import TokPrompt
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import EmbedDType, Endianness
......@@ -256,7 +255,7 @@ class OpenAIServingEmbedding(OpenAIServing):
chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
# Create engine prompt for this chunk
chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens)
chunk_engine_prompt = token_inputs(chunk_tokens)
# Log the chunk
self._log_inputs(
......@@ -266,16 +265,12 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request,
)
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode(
chunk_engine_prompt,
pooling_params,
chunk_request_id,
lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=ctx.request.priority,
)
......@@ -362,7 +357,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: TokPrompt,
engine_prompt: ProcessorInputs,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
......@@ -377,16 +372,12 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request,
)
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Return the original generator without wrapping
return self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=ctx.request.priority,
)
......
......@@ -33,10 +33,9 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs import PromptType
from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
......@@ -93,7 +92,7 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported"
)
engine_prompts: Sequence[PromptType | TokPrompt]
engine_prompts: Sequence[ProcessorInputs]
if use_io_processor := isinstance(request, IOProcessorRequest):
if self.io_processor is None:
raise ValueError(
......@@ -152,9 +151,6 @@ class OpenAIServingPooling(OpenAIServing):
else:
pooling_params = request.to_pooling_params() # type: ignore
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
......@@ -176,7 +172,6 @@ class OpenAIServingPooling(OpenAIServing):
pooling_params,
request_id_item,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
)
......
......@@ -35,7 +35,7 @@ from vllm.entrypoints.pooling.score.utils import (
get_score_prompt,
validate_score_input,
)
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
......@@ -108,12 +108,15 @@ class ServingScores(OpenAIServing):
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
)
engine_prompts: list[TokensPrompt] = []
engine_prompts: list[ProcessorInputs] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
token_inputs(
text_token_prompt["prompt_token_ids"],
prompt=input_text,
)
)
# Schedule the request and get the result generator.
......@@ -125,7 +128,7 @@ class ServingScores(OpenAIServing):
self._log_inputs(
request_id_item,
input_texts[i],
engine_prompt,
params=pooling_params,
lora_request=lora_request,
)
......@@ -207,12 +210,15 @@ class ServingScores(OpenAIServing):
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
)
engine_prompts: list[TokensPrompt] = []
engine_prompts: list[ProcessorInputs] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
token_inputs(
text_token_prompt["prompt_token_ids"],
prompt=input_text,
)
)
# Schedule the request and get the result generator.
......@@ -225,7 +231,7 @@ class ServingScores(OpenAIServing):
self._log_inputs(
request_id_item,
input_texts[i],
engine_prompt,
params=pooling_params,
lora_request=lora_request,
)
......
......@@ -29,7 +29,6 @@ from vllm.entrypoints.serve.disagg.protocol import (
GenerateResponse,
GenerateResponseChoice,
)
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
......@@ -116,7 +115,7 @@ class ServingTokens(OpenAIServing):
self._log_inputs(
request_id,
TokensPrompt(prompt_token_ids=request.token_ids),
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
......@@ -127,27 +126,13 @@ class ServingTokens(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
)
result_generator = self.engine_client.generate(
engine_request,
engine_prompt,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
tokenization_kwargs=tokenization_kwargs,
)
except ValueError as e:
......
......@@ -20,7 +20,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeResponse,
TokenizerInfoResponse,
)
from vllm.inputs import TokensPrompt
from vllm.inputs import TokensPrompt, token_inputs
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
......@@ -135,7 +135,7 @@ class OpenAIServingTokenization(OpenAIServing):
self._log_inputs(
request_id,
TokensPrompt(prompt_token_ids=request.tokens),
token_inputs(request.tokens),
params=None,
lora_request=lora_request,
)
......
......@@ -187,6 +187,9 @@ class _InputOptions(TypedDict):
Additional options available to all input types.
"""
arrival_time: NotRequired[float]
"""The time when the input was received (before rendering)."""
cache_salt: NotRequired[str]
"""Optional cache salt to be used for prefix caching."""
......@@ -300,6 +303,9 @@ class EncoderDecoderInputs(TypedDict):
decoder_prompt: DecoderInputs
"""The inputs for the decoder portion."""
arrival_time: NotRequired[float]
"""The time when the input was received (before rendering)."""
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
"""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment