Unverified Commit 574fe752 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Renderer] Move InputPreprocessor into Renderer (2/2) (#34560)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent c61a98f5
......@@ -19,11 +19,9 @@ from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import (
DecoderDictPrompt,
DecoderOnlyDictPrompt,
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDictPrompt,
SingletonDictPrompt,
TokPrompt,
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
from vllm.tokenizers import TokenizerLike
......@@ -41,7 +39,6 @@ from .data import (
TextPrompt,
TokenInputs,
TokensPrompt,
embeds_inputs,
token_inputs,
)
......@@ -83,7 +80,7 @@ class InputPreprocessor:
**(tokenization_kwargs or {})
)
tok_prompt = renderer.tokenize_prompt(
tok_prompt = renderer._tokenize_singleton_prompt(
TextPrompt(prompt=prompt),
tok_params,
)
......@@ -103,17 +100,10 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
mm_processor = self.renderer.get_mm_processor()
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
mm_items = mm_processor.info.parse_mm_data(mm_data)
return mm_processor.apply(
return self.renderer._process_multimodal(
prompt,
mm_items,
hf_processor_mm_kwargs=mm_processor_kwargs,
mm_data,
mm_processor_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
......@@ -122,31 +112,7 @@ class InputPreprocessor:
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
prompt_embeds = parsed_content["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
# Tensors must be on CPU for serialization between processes
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
# hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu()
return embeds_inputs(
prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt")
)
return self.renderer._process_embeds(parsed_content)
def _truncate_inputs(
self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
......@@ -157,7 +123,7 @@ class InputPreprocessor:
**(tokenization_kwargs or {})
)
tok_prompt = renderer.tokenize_prompt(
tok_prompt = renderer._tokenize_singleton_prompt(
TokensPrompt(prompt_token_ids=inputs),
tok_params,
)
......@@ -168,8 +134,6 @@ class InputPreprocessor:
self,
parsed_content: TokensPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> TokenInputs | MultiModalInputs:
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs
......@@ -182,11 +146,13 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
mm_uuids=parsed_content.get("multi_modal_uuids"),
)
else:
inputs = token_inputs(prompt_token_ids)
if prompt_text := parsed_content.get("prompt"):
inputs["prompt"] = prompt_text
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
......@@ -196,8 +162,6 @@ class InputPreprocessor:
self,
parsed_content: TextPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> TokenInputs | MultiModalInputs:
prompt_text = parsed_content["prompt"]
......@@ -208,7 +172,6 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = self._tokenize_prompt(
......@@ -217,6 +180,8 @@ class InputPreprocessor:
)
inputs = token_inputs(prompt_token_ids)
inputs["prompt"] = prompt_text
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
......@@ -227,8 +192,6 @@ class InputPreprocessor:
self,
prompt: EncoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderInputs: ...
@overload
......@@ -236,8 +199,6 @@ class InputPreprocessor:
self,
prompt: DecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderInputs: ...
@overload
......@@ -245,16 +206,12 @@ class InputPreprocessor:
self,
prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs: ...
def _prompt_to_llm_inputs(
self,
prompt: SingletonDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> SingletonInputs:
"""
Extract the singleton inputs from a prompt.
......@@ -271,16 +228,12 @@ class InputPreprocessor:
return self._process_embeds(prompt) # type: ignore[arg-type]
if "prompt_token_ids" in prompt:
return self._process_tokens(
prompt, # type: ignore[arg-type]
mm_uuids=mm_uuids,
)
return self._process_tokens(prompt) # type: ignore[arg-type]
if "prompt" in prompt:
return self._process_text(
prompt, # type: ignore[arg-type]
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
assert_never(prompt) # type: ignore[arg-type]
......@@ -289,8 +242,6 @@ class InputPreprocessor:
self,
prompt: EncoderDecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
......@@ -314,7 +265,6 @@ class InputPreprocessor:
encoder_inputs=self._prompt_to_llm_inputs(
encoder_prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
),
decoder_inputs=(
None
......@@ -331,8 +281,6 @@ class InputPreprocessor:
self,
prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs:
"""
For decoder-only models:
......@@ -350,41 +298,23 @@ class InputPreprocessor:
return self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
def _preprocess(
def preprocess(
self,
prompt: PromptType | DictPrompt | TokPrompt,
prompt: PromptType,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder.
return self._process_encoder_decoder_prompt(
parse_enc_dec_prompt(prompt),
tokenization_kwargs,
mm_uuids=mm_uuids,
)
return self._process_decoder_only_prompt(
parse_dec_only_prompt(prompt),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
def preprocess(
self,
prompt: PromptType | DictPrompt | TokPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)
self.renderer.update_mm_cache_stats()
return res
......@@ -48,7 +48,6 @@ from vllm.multimodal.processing import (
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor
......@@ -810,13 +809,7 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
# Use getattr with default to be compatible with transformers<4.48
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
audio_token_id = vocab[audio_token]
audio_token_id = processor.audio_token_id
out_mm_data = out_mm_kwargs.get_data()
......@@ -836,17 +829,12 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
num_features = audio_embeds.shape[0]
audio_tokens = [audio_token_id] * num_features
return PromptUpdateDetails.select_token_id(
audio_tokens,
embed_token_id=audio_token_id,
)
return [audio_token_id] * num_features
return [
PromptReplacement(
modality="audio",
target=audio_token,
target=[audio_token_id],
replacement=get_replacement_qwen2_audio,
)
]
......
......@@ -59,7 +59,6 @@ from vllm.multimodal.processing import (
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......@@ -187,8 +186,10 @@ class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingIn
hf_processor = self.info.get_hf_processor()
audio_token = hf_processor.audio_token
audio_bos_token = hf_processor.audio_bos_token
audio_eos_token = hf_processor.audio_eos_token
return audio_token * num_audios
return (audio_bos_token + audio_token + audio_eos_token) * num_audios
def get_dummy_mm_data(
self,
......@@ -262,17 +263,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
# Use getattr with default to be compatible with transformers<4.48
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
audio_bos_token = getattr(processor, "audio_bos_token", "<|audio_bos|>")
audio_eos_token = getattr(processor, "audio_eos_token", "<|audio_eos|>")
audio_token_id = vocab[audio_token]
audio_bos_id = vocab[audio_bos_token]
audio_eos_id = vocab[audio_eos_token]
audio_token_id = processor.audio_token_id
out_mm_data = out_mm_kwargs.get_data()
feature_attention_mask = out_mm_data.get("feature_attention_mask")
......@@ -303,17 +294,12 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing
"to be represented inside the model"
)
audio_tokens = [audio_token_id] * num_features
return PromptUpdateDetails.select_token_id(
[audio_bos_id] + audio_tokens + [audio_eos_id],
embed_token_id=audio_token_id,
)
return [audio_token_id] * num_features
return [
PromptReplacement(
modality="audio",
target=audio_token,
target=[audio_token_id],
replacement=get_replacement_qwen2_audio,
)
]
......
......@@ -1843,15 +1843,18 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
tokenizer = self.info.get_tokenizer()
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items)
if isinstance(decoder_prompt_raw, str):
decoder_prompt_text = decoder_prompt_raw
decoder_prompt_ids = tokenizer.encode(
decoder_prompt_raw, add_special_tokens=False
)
else:
decoder_prompt_text = None
decoder_prompt_ids = decoder_prompt_raw
return mm_enc_dec_inputs(
encoder_inputs,
decoder_prompt_ids,
decoder_prompt=decoder_prompt_text,
)
def apply(
......
......@@ -19,7 +19,6 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.selector import AttentionSelectorConfig
......@@ -569,7 +568,7 @@ class Platform:
@classmethod
def validate_request(
cls,
prompt: "PromptType | DictPrompt | TokPrompt",
prompt: "PromptType | ProcessorInputs",
params: "SamplingParams | PoolingParams",
processed_inputs: "ProcessorInputs",
) -> None:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload
from typing_extensions import TypeVar
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.inputs import (
EmbedsInputs,
EmbedsPrompt,
EncoderDecoderInputs,
ProcessorInputs,
SingletonInputs,
TextPrompt,
TokenInputs,
TokensPrompt,
)
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.counter import AtomicCounter
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats
......@@ -20,6 +32,8 @@ from .inputs import (
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDecoderTokPrompt,
SingletonDictPrompt,
SingletonTokPrompt,
TokPrompt,
)
from .inputs.preprocess import extract_target_prompt
......@@ -32,6 +46,12 @@ if TYPE_CHECKING:
ConversationMessage,
)
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalInputs,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import BaseMultiModalProcessor
logger = init_logger(__name__)
......@@ -79,6 +99,10 @@ class BaseRenderer(ABC, Generic[_T]):
if mm_processor_cache:
self._mm_cache_stats = MultiModalCacheStats()
# This is used to generate internal request ID for MM processing
# It has no relation to the request ID for engine core
self._mm_req_counter = AtomicCounter()
def get_tokenizer(self) -> _T:
tokenizer = self.tokenizer
if tokenizer is None:
......@@ -284,17 +308,79 @@ class BaseRenderer(ABC, Generic[_T]):
return prompt
@overload
def _tokenize_singleton_prompt(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
def _tokenize_singleton_prompt( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
def _tokenize_singleton_prompt(
self,
prompt: SingletonDictPrompt,
params: TokenizeParams,
) -> SingletonTokPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
prompt = self._tokenize_prompt(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
@overload
async def _tokenize_singleton_prompt_async(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
async def _tokenize_singleton_prompt_async( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
async def _tokenize_singleton_prompt_async(
self,
prompt: SingletonDictPrompt,
params: TokenizeParams,
) -> SingletonTokPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
prompt = await self._tokenize_prompt_async(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
def _tokenize_enc_dec_prompt(
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = (
self.tokenize_prompt(prompt["encoder_prompt"], params),
self._tokenize_singleton_prompt(prompt["encoder_prompt"], params),
(
None
if prompt["decoder_prompt"] is None
else self.tokenize_prompt(prompt["decoder_prompt"], params)
else self._tokenize_singleton_prompt(prompt["decoder_prompt"], params)
),
)
......@@ -309,11 +395,13 @@ class BaseRenderer(ABC, Generic[_T]):
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = await asyncio.gather(
self.tokenize_prompt_async(prompt["encoder_prompt"], params),
self._tokenize_singleton_prompt_async(prompt["encoder_prompt"], params),
(
asyncio.sleep(0)
if prompt["decoder_prompt"] is None
else self.tokenize_prompt_async(prompt["decoder_prompt"], params)
else self._tokenize_singleton_prompt_async(
prompt["decoder_prompt"], params
)
),
)
......@@ -322,27 +410,6 @@ class BaseRenderer(ABC, Generic[_T]):
decoder_prompt=dec_prompt,
)
@overload
def tokenize_prompt(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
@overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
def tokenize_prompt(
self,
prompt: DictPrompt,
......@@ -351,17 +418,7 @@ class BaseRenderer(ABC, Generic[_T]):
if "encoder_prompt" in prompt:
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
prompt = self._tokenize_prompt(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
return self._tokenize_singleton_prompt(prompt, params)
def tokenize_prompts(
self,
......@@ -370,27 +427,6 @@ class BaseRenderer(ABC, Generic[_T]):
) -> list[TokPrompt]:
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
@overload
async def tokenize_prompt_async(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
async def tokenize_prompt_async(
self,
prompt: DictPrompt,
......@@ -399,17 +435,7 @@ class BaseRenderer(ABC, Generic[_T]):
if "encoder_prompt" in prompt:
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
prompt = await self._tokenize_prompt_async(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
return await self._tokenize_singleton_prompt_async(prompt, params)
async def tokenize_prompts_async(
self,
......@@ -423,7 +449,7 @@ class BaseRenderer(ABC, Generic[_T]):
# Step 3: Add extra keys to the prompts
def _apply_prompt_extras(
self,
prompts: Sequence[DictPrompt | TokPrompt],
prompts: Sequence[TokPrompt],
prompt_extras: dict[str, Any] | None,
):
if not prompt_extras:
......@@ -433,6 +459,200 @@ class BaseRenderer(ABC, Generic[_T]):
target_prompt = extract_target_prompt(self.model_config, prompt)
target_prompt.update(prompt_extras) # type: ignore[arg-type]
# Step 4: Convert to engine inputs
def _validate_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_items: "MultiModalDataItems",
mm_uuids: "MultiModalUUIDDict | None",
) -> None:
if mm_uuids is None:
mm_uuids = {}
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in `mm_items`
modalities = mm_data.keys() | mm_uuids.keys()
for modality in modalities:
data_items = mm_items.get(modality) or list[Any]()
uuid_items = mm_uuids.get(modality) or list[str | None]()
if isinstance(uuid_items, str):
uuid_items = [uuid_items]
if len(data_items) > 0:
if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have "
f"same length as multi_modal_data[{modality!r}], but "
f"got {len(uuid_items)} vs {len(data_items)}."
)
for i, item in enumerate(data_items):
if item is None:
if not uuid_items:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
if uuid_items[i] is None:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}][{i}] is missing."
)
def _process_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_items: "MultiModalDataItems",
mm_uuids: "MultiModalUUIDDict | None",
mm_req_id: str,
):
model_config = self.model_config
# NOTE: When users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore identifying multimodal data items
# by their content is no longer necessary, and we create uuids with
# `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
if (
model_config.multimodal_config
and model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.config.cache_config.enable_prefix_caching
):
mm_uuids = {
modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
for modality, data_count in mm_items.get_all_counts().items()
}
self._validate_mm_uuids(mm_data, mm_items, mm_uuids)
return mm_uuids
# TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
def _process_multimodal(
self,
prompt: list[int] | str,
mm_data: "MultiModalDataDict",
mm_processor_kwargs: Mapping[str, object] | None,
tokenization_kwargs: dict[str, Any] | None,
mm_uuids: "MultiModalUUIDDict | None",
) -> "MultiModalInputs":
from vllm.multimodal.processing.context import set_request_id
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
mm_processor = self.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuids = self._process_mm_uuids(mm_data, mm_items, mm_uuids, mm_req_id)
with set_request_id(mm_req_id), set_default_torch_num_threads():
mm_inputs = mm_processor.apply(
prompt,
mm_items,
hf_processor_mm_kwargs=mm_processor_kwargs or {},
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
self.update_mm_cache_stats()
return mm_inputs
def _process_tokens(
self,
prompt: TokensPrompt,
) -> "TokenInputs | MultiModalInputs":
prompt_token_ids = prompt["prompt_token_ids"]
inputs: TokenInputs | MultiModalInputs
if multi_modal_data := prompt.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
tokenization_kwargs=None, # Tokenization already done in Step 2
mm_uuids=prompt.get("multi_modal_uuids"),
)
else:
inputs = token_inputs(prompt_token_ids)
if prompt_text := prompt.get("prompt"):
inputs["prompt"] = prompt_text
if cache_salt := prompt.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _process_embeds(
self,
prompt: EmbedsPrompt,
) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
prompt_embeds = prompt["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
# Tensors must be on CPU for serialization between processes
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
# hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu()
return embeds_inputs(
prompt_embeds=prompt_embeds,
cache_salt=prompt.get("cache_salt"),
)
def _process_singleton(
self,
prompt: SingletonTokPrompt,
) -> SingletonInputs:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
return self._process_tokens(prompt) # type: ignore[arg-type]
def _process_enc_dec(
self,
prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInputs:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
return build_enc_dec_inputs(
encoder_inputs=self._process_singleton(enc_prompt),
decoder_inputs=(
None if dec_prompt is None else self._process_singleton(dec_prompt)
),
decoder_start_token_id=self.get_dec_start_token_id(),
)
def process_for_engine(
self, prompt: TokPrompt, arrival_time: float
) -> ProcessorInputs:
engine_prompt: ProcessorInputs
if "encoder_prompt" in prompt:
engine_prompt = self._process_enc_dec(prompt) # type: ignore[arg-type]
else:
engine_prompt = self._process_singleton(prompt)
engine_prompt["arrival_time"] = arrival_time
return engine_prompt
# Top-level methods
def render_cmpl(
self,
......@@ -441,6 +661,8 @@ class BaseRenderer(ABC, Generic[_T]):
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_cmpl_tok_params
......@@ -449,8 +671,7 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor
return tok_prompts
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
async def render_cmpl_async(
self,
......@@ -459,6 +680,8 @@ class BaseRenderer(ABC, Generic[_T]):
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_cmpl_tok_params
......@@ -467,8 +690,7 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor
return tok_prompts
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
def render_chat(
self,
......@@ -478,6 +700,8 @@ class BaseRenderer(ABC, Generic[_T]):
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_chat_tok_params
......@@ -496,8 +720,11 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor
return out_conversations, tok_prompts
eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
return out_conversations, eng_prompts
async def render_chat_async(
self,
......@@ -507,6 +734,8 @@ class BaseRenderer(ABC, Generic[_T]):
*,
prompt_extras: dict[str, Any] | None = None,
):
arrival_time = time.time()
if tok_params is None:
tok_params = self.default_chat_tok_params
......@@ -525,5 +754,8 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor
return out_conversations, tok_prompts
eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
return out_conversations, eng_prompts
......@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
from vllm.inputs import (
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
ProcessorInputs,
PromptType,
SingletonPrompt,
TextPrompt,
......@@ -115,7 +116,7 @@ that has been standardized into a dictionary.
"""
def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt:
"""
Parse a prompt for a decoder-only model and normalize it to a dictionary.
"""
......@@ -144,7 +145,7 @@ def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
def _parse_enc_prompt(prompt: PromptType | object) -> EncoderDictPrompt:
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
......@@ -166,7 +167,7 @@ def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
def _parse_dec_prompt(prompt: PromptType | object) -> DecoderDictPrompt:
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
......@@ -195,13 +196,13 @@ def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def parse_enc_dec_prompt(prompt: object) -> EncoderDecoderDictPrompt:
def parse_enc_dec_prompt(prompt: PromptType | object) -> EncoderDecoderDictPrompt:
"""
Parse a prompt for an encoder-decoder model and normalize it to a dictionary.
"""
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
enc_prompt: object = prompt["encoder_prompt"] # type: ignore[typeddict-item]
dec_prompt: object | None = prompt["decoder_prompt"] # type: ignore[typeddict-item]
enc_prompt = prompt["encoder_prompt"] # type: ignore[typeddict-item]
dec_prompt = prompt["decoder_prompt"] # type: ignore[typeddict-item]
else:
enc_prompt = prompt
dec_prompt = None
......@@ -235,21 +236,23 @@ def extract_target_prompt(model_config: "ModelConfig", prompt: object):
def extract_prompt_components(
model_config: "ModelConfig",
prompt: object,
prompt: PromptType | ProcessorInputs,
) -> PromptComponents:
target_prompt = extract_target_prompt(model_config, prompt)
return PromptComponents(
text=target_prompt.get("prompt"),
token_ids=target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
token_ids=target_prompt.get("prompt_token_ids"),
embeds=target_prompt.get("prompt_embeds"),
)
def extract_prompt_len(model_config: "ModelConfig", prompt: object):
def extract_prompt_len(
model_config: "ModelConfig", prompt: PromptType | ProcessorInputs
):
target_prompt = extract_target_prompt(model_config, prompt)
return length_from_prompt_token_ids_or_embeds(
target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
target_prompt.get("prompt_token_ids"),
target_prompt.get("prompt_embeds"),
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable, Sequence
from typing import Any, TypeVar, overload
from tqdm.auto import tqdm
_T = TypeVar("_T", bound=Iterable)
@overload
def maybe_tqdm(
it: Sequence[_T],
*,
use_tqdm: bool | Callable[..., tqdm],
**tqdm_kwargs: Any,
) -> Sequence[_T]: ...
@overload
def maybe_tqdm(
it: Iterable[_T],
*,
use_tqdm: bool | Callable[..., tqdm],
**tqdm_kwargs: Any,
) -> Iterable[_T]: ...
def maybe_tqdm(
it: Iterable[_T],
*,
use_tqdm: bool | Callable[..., tqdm],
**tqdm_kwargs: Any,
) -> Iterable[_T]:
if not use_tqdm:
return it
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
return tqdm_func(it, **tqdm_kwargs)
......@@ -20,7 +20,7 @@ from vllm.distributed.weight_transfer.base import (
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.inputs import PromptType
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
......@@ -28,7 +28,6 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import merge_kwargs, renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask
......@@ -290,8 +289,7 @@ class AsyncLLM(EngineClient):
request_id: str,
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| ProcessorInputs
| AsyncGenerator[StreamingInput, None],
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
......@@ -301,6 +299,7 @@ class AsyncLLM(EngineClient):
priority: int = 0,
data_parallel_rank: int | None = None,
prompt_text: str | None = None,
reasoning_ended: bool | None = None,
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""
......@@ -336,6 +335,9 @@ class AsyncLLM(EngineClient):
)
if isinstance(prompt, AsyncGenerator):
if reasoning_ended is not None:
raise NotImplementedError
# Streaming input case.
return await self._add_streaming_input_request(
request_id,
......@@ -359,10 +361,6 @@ class AsyncLLM(EngineClient):
"latter will be used, and the former will be ignored."
)
else:
if prompt_text is not None:
raise ValueError(
"should only provide prompt_text with EngineCoreRequest"
)
request = self.input_processor.process_inputs(
request_id,
prompt,
......@@ -377,6 +375,9 @@ class AsyncLLM(EngineClient):
)
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
if reasoning_ended is not None:
request.reasoning_ended = reasoning_ended
self.input_processor.assign_request_id(request)
# We start the output_handler on the first call to add_request() so
......@@ -536,8 +537,7 @@ class AsyncLLM(EngineClient):
self,
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| ProcessorInputs
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
......@@ -548,6 +548,7 @@ class AsyncLLM(EngineClient):
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
data_parallel_rank: int | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[RequestOutput, None]:
"""
Main function called by the API server to kick off a request
......@@ -576,6 +577,7 @@ class AsyncLLM(EngineClient):
priority=priority,
data_parallel_rank=data_parallel_rank,
prompt_text=prompt_text,
reasoning_ended=reasoning_ended,
)
# The output_handler task pushes items into the queue.
......@@ -770,13 +772,14 @@ class AsyncLLM(EngineClient):
async def encode(
self,
prompt: PromptType | DictPrompt | TokPrompt,
prompt: PromptType | ProcessorInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
tokenization_kwargs: dict[str, Any] | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""
Main function called by the API server to kick off a request
......@@ -802,6 +805,7 @@ class AsyncLLM(EngineClient):
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
reasoning_ended=reasoning_ended,
)
# The output_handler task pushes items into the queue.
......
......@@ -3,7 +3,7 @@
import time
from collections.abc import Mapping
from typing import Any, Literal, cast
from typing import Any, Literal
import vllm.envs as envs
from vllm.config import VllmConfig
......@@ -11,7 +11,6 @@ from vllm.inputs.data import (
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
)
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
......@@ -20,22 +19,16 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.encoder_budget import MultiModalBudget
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.jsontree import json_iter_leaves
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)
......@@ -133,81 +126,6 @@ class InputProcessor:
f"but got {type(params).__name__}"
)
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.renderer.get_mm_processor()
return mm_processor.info.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
if not isinstance(prompt, dict):
return
mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
if not mm_data and not mm_uuids:
return
mm_data_parsed = self._parse_mm_items(
{k: v for k, v in mm_data.items() if v is not None}
)
mm_uuids_parsed = {
k: [v] if isinstance(v, str) else v
for k, v in mm_uuids.items()
if v is not None
}
# NOTE: Include the keys corresponding to `None`
modalities = mm_data.keys() | mm_uuids.keys()
for modality in modalities:
data_items = cast(
ModalityDataItems | list[Any], mm_data_parsed.get(modality, [])
)
uuid_items = cast(list[str | None], mm_uuids_parsed.get(modality, []))
if len(data_items) > 0:
if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have "
f"same length as multi_modal_data[{modality!r}], but "
f"got {len(uuid_items)} vs {len(data_items)}."
)
for i, item in enumerate(data_items):
if item is None:
if not uuid_items:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
if uuid_items[i] is None:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}][{i}] is missing."
)
else:
if len(uuid_items) == 0:
raise ValueError(
f"multi_modal_data[{modality!r}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
def _validate_mm_uuids(self, prompt: PromptType | DictPrompt | TokPrompt) -> None:
"""
Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s).
Only checks lengths; `None` entries are allowed and will be
auto-hashed downstream.
"""
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
if (dec_prompt := prompt["decoder_prompt"]) is not None: # type: ignore[typeddict-item]
self._validate_singleton_mm_uuids(dec_prompt)
else:
self._validate_singleton_mm_uuids(prompt)
def _validate_lora(self, lora_request: LoRARequest | None) -> None:
if lora_request is None:
return
......@@ -227,47 +145,6 @@ class InputProcessor:
"[lora_path]` to use the LoRA tokenizer."
)
def _extract_singleton_mm_data(
self, prompt: SingletonPrompt
) -> MultiModalDataDict | None:
if not isinstance(prompt, dict):
return None
return prompt.get("multi_modal_data")
def _extract_mm_data(
self, prompt: PromptType | DictPrompt | TokPrompt
) -> MultiModalDataDict | None:
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
return self._extract_singleton_mm_data(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
else:
return self._extract_singleton_mm_data(prompt)
def _maybe_build_mm_uuids(
self,
request_id: str,
prompt: PromptType | DictPrompt | TokPrompt,
) -> MultiModalUUIDDict | None:
"""Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and
index rather than their content.
Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present.
"""
mm_data = self._extract_mm_data(prompt)
if not mm_data:
return None
mm_items = self._parse_mm_items(
{k: v for k, v in mm_data.items() if v is not None}
)
return {
modality: [f"{request_id}-{modality}-{i}" for i in range(data_count)]
for modality, data_count in mm_items.get_all_counts().items()
}
def _get_mm_identifier(
self,
mm_hash: str,
......@@ -309,7 +186,7 @@ class InputProcessor:
def process_inputs(
self,
request_id: str,
prompt: PromptType | DictPrompt | TokPrompt,
prompt: PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
......@@ -333,43 +210,18 @@ class InputProcessor:
f"is out of range [0, {num_ranks})."
)
if arrival_time is None:
arrival_time = time.time()
if isinstance(prompt, dict) and "type" in prompt:
if arrival_time is None:
arrival_time = prompt.get("arrival_time", time.time()) # type: ignore[assignment]
# Optionally generate multimodal hash overrides to avoid hashing
# multimodal data items by their content as their identifiers.
# NOTE: when users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore identifying multimodal data items
# by their content is no longer necessary, and we create uuids with
# request id-modality-index as multimodal hash overrides.
if (
self.model_config.multimodal_config
and self.model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.cache_config.enable_prefix_caching
):
mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
processed_inputs: ProcessorInputs = prompt # type: ignore[assignment]
else:
# Otherwise, use user-provided uuids as multimodal hash overrides
# if provided.
self._validate_mm_uuids(prompt)
if isinstance(prompt, dict):
mm_uuids = cast(
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
)
else:
mm_uuids = None
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
with set_request_id(request_id), set_default_torch_num_threads():
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
if arrival_time is None:
arrival_time = time.time()
processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
from vllm.platforms import current_platform
......
......@@ -14,7 +14,7 @@ from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.distributed.parallel_state import get_dp_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
......@@ -22,7 +22,6 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
......@@ -220,7 +219,7 @@ class LLMEngine:
def add_request(
self,
request_id: str,
prompt: EngineCoreRequest | PromptType | DictPrompt | TokPrompt,
prompt: EngineCoreRequest | PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
......@@ -228,7 +227,7 @@ class LLMEngine:
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
prompt_text: str | None = None,
) -> None:
) -> str:
# Validate the request_id type.
if not isinstance(request_id, str):
raise TypeError(f"request_id must be a string, got {type(request_id)}")
......@@ -243,7 +242,6 @@ class LLMEngine:
"latter will be used, and the former will be ignored."
)
else:
assert prompt_text is None
request = self.input_processor.process_inputs(
request_id,
prompt,
......@@ -259,6 +257,8 @@ class LLMEngine:
self.input_processor.assign_request_id(request)
req_id = request.request_id
# Use cloned params that may have been updated in process_inputs()
params = request.params
......@@ -269,7 +269,7 @@ class LLMEngine:
self.output_processor.add_request(request, prompt_text, None, 0)
# Add the request to EngineCore.
self.engine_core.add_request(request)
return
return req_id
# Fan out child requests (for n>1).
parent_req = ParentRequest(request)
......@@ -286,6 +286,8 @@ class LLMEngine:
# Add the request to EngineCore.
self.engine_core.add_request(child_request)
return req_id
def step(self) -> list[RequestOutput | PoolingRequestOutput]:
if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment