Unverified commit 574fe752, authored by Cyrus Leung and committed by GitHub
Browse files

[Renderer] Move InputPreprocessor into Renderer (2/2) (#34560)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent c61a98f5
...@@ -19,11 +19,9 @@ from vllm.renderers import BaseRenderer, renderer_from_config ...@@ -19,11 +19,9 @@ from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import ( from vllm.renderers.inputs import (
DecoderDictPrompt, DecoderDictPrompt,
DecoderOnlyDictPrompt, DecoderOnlyDictPrompt,
DictPrompt,
EncoderDecoderDictPrompt, EncoderDecoderDictPrompt,
EncoderDictPrompt, EncoderDictPrompt,
SingletonDictPrompt, SingletonDictPrompt,
TokPrompt,
) )
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -41,7 +39,6 @@ from .data import ( ...@@ -41,7 +39,6 @@ from .data import (
TextPrompt, TextPrompt,
TokenInputs, TokenInputs,
TokensPrompt, TokensPrompt,
embeds_inputs,
token_inputs, token_inputs,
) )
...@@ -83,7 +80,7 @@ class InputPreprocessor: ...@@ -83,7 +80,7 @@ class InputPreprocessor:
**(tokenization_kwargs or {}) **(tokenization_kwargs or {})
) )
tok_prompt = renderer.tokenize_prompt( tok_prompt = renderer._tokenize_singleton_prompt(
TextPrompt(prompt=prompt), TextPrompt(prompt=prompt),
tok_params, tok_params,
) )
...@@ -103,17 +100,10 @@ class InputPreprocessor: ...@@ -103,17 +100,10 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata. returning the corresponding token IDs and metadata.
""" """
mm_processor = self.renderer.get_mm_processor() return self.renderer._process_multimodal(
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
mm_items = mm_processor.info.parse_mm_data(mm_data)
return mm_processor.apply(
prompt, prompt,
mm_items, mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -122,31 +112,7 @@ class InputPreprocessor: ...@@ -122,31 +112,7 @@ class InputPreprocessor:
self, self,
parsed_content: EmbedsPrompt, parsed_content: EmbedsPrompt,
) -> EmbedsInputs: ) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds: return self.renderer._process_embeds(parsed_content)
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
prompt_embeds = parsed_content["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
# Tensors must be on CPU for serialization between processes
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
# hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu()
return embeds_inputs(
prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt")
)
def _truncate_inputs( def _truncate_inputs(
self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
...@@ -157,7 +123,7 @@ class InputPreprocessor: ...@@ -157,7 +123,7 @@ class InputPreprocessor:
**(tokenization_kwargs or {}) **(tokenization_kwargs or {})
) )
tok_prompt = renderer.tokenize_prompt( tok_prompt = renderer._tokenize_singleton_prompt(
TokensPrompt(prompt_token_ids=inputs), TokensPrompt(prompt_token_ids=inputs),
tok_params, tok_params,
) )
...@@ -168,8 +134,6 @@ class InputPreprocessor: ...@@ -168,8 +134,6 @@ class InputPreprocessor:
self, self,
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> TokenInputs | MultiModalInputs: ) -> TokenInputs | MultiModalInputs:
prompt_token_ids = self._truncate_inputs( prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs parsed_content["prompt_token_ids"], tokenization_kwargs
...@@ -182,11 +146,13 @@ class InputPreprocessor: ...@@ -182,11 +146,13 @@ class InputPreprocessor:
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs") or {}, parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=parsed_content.get("multi_modal_uuids"),
) )
else: else:
inputs = token_inputs(prompt_token_ids) inputs = token_inputs(prompt_token_ids)
if prompt_text := parsed_content.get("prompt"):
inputs["prompt"] = prompt_text
if cache_salt := parsed_content.get("cache_salt"): if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
...@@ -196,8 +162,6 @@ class InputPreprocessor: ...@@ -196,8 +162,6 @@ class InputPreprocessor:
self, self,
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> TokenInputs | MultiModalInputs: ) -> TokenInputs | MultiModalInputs:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
...@@ -208,7 +172,6 @@ class InputPreprocessor: ...@@ -208,7 +172,6 @@ class InputPreprocessor:
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs") or {}, parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
else: else:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
...@@ -217,6 +180,8 @@ class InputPreprocessor: ...@@ -217,6 +180,8 @@ class InputPreprocessor:
) )
inputs = token_inputs(prompt_token_ids) inputs = token_inputs(prompt_token_ids)
inputs["prompt"] = prompt_text
if cache_salt := parsed_content.get("cache_salt"): if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
...@@ -227,8 +192,6 @@ class InputPreprocessor: ...@@ -227,8 +192,6 @@ class InputPreprocessor:
self, self,
prompt: EncoderDictPrompt, prompt: EncoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderInputs: ... ) -> EncoderInputs: ...
@overload @overload
...@@ -236,8 +199,6 @@ class InputPreprocessor: ...@@ -236,8 +199,6 @@ class InputPreprocessor:
self, self,
prompt: DecoderDictPrompt, prompt: DecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderInputs: ... ) -> DecoderInputs: ...
@overload @overload
...@@ -245,16 +206,12 @@ class InputPreprocessor: ...@@ -245,16 +206,12 @@ class InputPreprocessor:
self, self,
prompt: DecoderOnlyDictPrompt, prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs: ... ) -> DecoderOnlyInputs: ...
def _prompt_to_llm_inputs( def _prompt_to_llm_inputs(
self, self,
prompt: SingletonDictPrompt, prompt: SingletonDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Extract the singleton inputs from a prompt. Extract the singleton inputs from a prompt.
...@@ -271,16 +228,12 @@ class InputPreprocessor: ...@@ -271,16 +228,12 @@ class InputPreprocessor:
return self._process_embeds(prompt) # type: ignore[arg-type] return self._process_embeds(prompt) # type: ignore[arg-type]
if "prompt_token_ids" in prompt: if "prompt_token_ids" in prompt:
return self._process_tokens( return self._process_tokens(prompt) # type: ignore[arg-type]
prompt, # type: ignore[arg-type]
mm_uuids=mm_uuids,
)
if "prompt" in prompt: if "prompt" in prompt:
return self._process_text( return self._process_text(
prompt, # type: ignore[arg-type] prompt, # type: ignore[arg-type]
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
assert_never(prompt) # type: ignore[arg-type] assert_never(prompt) # type: ignore[arg-type]
...@@ -289,8 +242,6 @@ class InputPreprocessor: ...@@ -289,8 +242,6 @@ class InputPreprocessor:
self, self,
prompt: EncoderDecoderDictPrompt, prompt: EncoderDecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
For encoder/decoder models only: For encoder/decoder models only:
...@@ -314,7 +265,6 @@ class InputPreprocessor: ...@@ -314,7 +265,6 @@ class InputPreprocessor:
encoder_inputs=self._prompt_to_llm_inputs( encoder_inputs=self._prompt_to_llm_inputs(
encoder_prompt, encoder_prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
), ),
decoder_inputs=( decoder_inputs=(
None None
...@@ -331,8 +281,6 @@ class InputPreprocessor: ...@@ -331,8 +281,6 @@ class InputPreprocessor:
self, self,
prompt: DecoderOnlyDictPrompt, prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
For decoder-only models: For decoder-only models:
...@@ -350,41 +298,23 @@ class InputPreprocessor: ...@@ -350,41 +298,23 @@ class InputPreprocessor:
return self._prompt_to_llm_inputs( return self._prompt_to_llm_inputs(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
def _preprocess( def preprocess(
self, self,
prompt: PromptType | DictPrompt | TokPrompt, prompt: PromptType,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder: if self.model_config.is_encoder_decoder:
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder. # input prompts to encoder & decoder.
return self._process_encoder_decoder_prompt( return self._process_encoder_decoder_prompt(
parse_enc_dec_prompt(prompt), parse_enc_dec_prompt(prompt),
tokenization_kwargs, tokenization_kwargs,
mm_uuids=mm_uuids,
) )
return self._process_decoder_only_prompt( return self._process_decoder_only_prompt(
parse_dec_only_prompt(prompt), parse_dec_only_prompt(prompt),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
def preprocess(
self,
prompt: PromptType | DictPrompt | TokPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)
self.renderer.update_mm_cache_stats()
return res
...@@ -48,7 +48,6 @@ from vllm.multimodal.processing import ( ...@@ -48,7 +48,6 @@ from vllm.multimodal.processing import (
BaseProcessingInfo, BaseProcessingInfo,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
PromptUpdateDetails,
) )
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor
...@@ -810,13 +809,7 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]): ...@@ -810,13 +809,7 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
out_mm_kwargs: MultiModalKwargsItems, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer() audio_token_id = processor.audio_token_id
vocab = tokenizer.get_vocab()
# Use getattr with default to be compatible with transformers<4.48
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
audio_token_id = vocab[audio_token]
out_mm_data = out_mm_kwargs.get_data() out_mm_data = out_mm_kwargs.get_data()
...@@ -836,17 +829,12 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]): ...@@ -836,17 +829,12 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor" assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
num_features = audio_embeds.shape[0] num_features = audio_embeds.shape[0]
audio_tokens = [audio_token_id] * num_features return [audio_token_id] * num_features
return PromptUpdateDetails.select_token_id(
audio_tokens,
embed_token_id=audio_token_id,
)
return [ return [
PromptReplacement( PromptReplacement(
modality="audio", modality="audio",
target=audio_token, target=[audio_token_id],
replacement=get_replacement_qwen2_audio, replacement=get_replacement_qwen2_audio,
) )
] ]
......
...@@ -59,7 +59,6 @@ from vllm.multimodal.processing import ( ...@@ -59,7 +59,6 @@ from vllm.multimodal.processing import (
BaseProcessingInfo, BaseProcessingInfo,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
PromptUpdateDetails,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -187,8 +186,10 @@ class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingIn ...@@ -187,8 +186,10 @@ class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingIn
hf_processor = self.info.get_hf_processor() hf_processor = self.info.get_hf_processor()
audio_token = hf_processor.audio_token audio_token = hf_processor.audio_token
audio_bos_token = hf_processor.audio_bos_token
audio_eos_token = hf_processor.audio_eos_token
return audio_token * num_audios return (audio_bos_token + audio_token + audio_eos_token) * num_audios
def get_dummy_mm_data( def get_dummy_mm_data(
self, self,
...@@ -262,17 +263,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing ...@@ -262,17 +263,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing
out_mm_kwargs: MultiModalKwargsItems, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer() audio_token_id = processor.audio_token_id
vocab = tokenizer.get_vocab()
# Use getattr with default to be compatible with transformers<4.48
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
audio_bos_token = getattr(processor, "audio_bos_token", "<|audio_bos|>")
audio_eos_token = getattr(processor, "audio_eos_token", "<|audio_eos|>")
audio_token_id = vocab[audio_token]
audio_bos_id = vocab[audio_bos_token]
audio_eos_id = vocab[audio_eos_token]
out_mm_data = out_mm_kwargs.get_data() out_mm_data = out_mm_kwargs.get_data()
feature_attention_mask = out_mm_data.get("feature_attention_mask") feature_attention_mask = out_mm_data.get("feature_attention_mask")
...@@ -303,17 +294,12 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing ...@@ -303,17 +294,12 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing
"to be represented inside the model" "to be represented inside the model"
) )
audio_tokens = [audio_token_id] * num_features return [audio_token_id] * num_features
return PromptUpdateDetails.select_token_id(
[audio_bos_id] + audio_tokens + [audio_eos_id],
embed_token_id=audio_token_id,
)
return [ return [
PromptReplacement( PromptReplacement(
modality="audio", modality="audio",
target=audio_token, target=[audio_token_id],
replacement=get_replacement_qwen2_audio, replacement=get_replacement_qwen2_audio,
) )
] ]
......
...@@ -1843,15 +1843,18 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1843,15 +1843,18 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items) decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items)
if isinstance(decoder_prompt_raw, str): if isinstance(decoder_prompt_raw, str):
decoder_prompt_text = decoder_prompt_raw
decoder_prompt_ids = tokenizer.encode( decoder_prompt_ids = tokenizer.encode(
decoder_prompt_raw, add_special_tokens=False decoder_prompt_raw, add_special_tokens=False
) )
else: else:
decoder_prompt_text = None
decoder_prompt_ids = decoder_prompt_raw decoder_prompt_ids = decoder_prompt_raw
return mm_enc_dec_inputs( return mm_enc_dec_inputs(
encoder_inputs, encoder_inputs,
decoder_prompt_ids, decoder_prompt_ids,
decoder_prompt=decoder_prompt_text,
) )
def apply( def apply(
......
...@@ -19,7 +19,6 @@ if TYPE_CHECKING: ...@@ -19,7 +19,6 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.v1.attention.selector import AttentionSelectorConfig
...@@ -569,7 +568,7 @@ class Platform: ...@@ -569,7 +568,7 @@ class Platform:
@classmethod @classmethod
def validate_request( def validate_request(
cls, cls,
prompt: "PromptType | DictPrompt | TokPrompt", prompt: "PromptType | ProcessorInputs",
params: "SamplingParams | PoolingParams", params: "SamplingParams | PoolingParams",
processed_inputs: "ProcessorInputs", processed_inputs: "ProcessorInputs",
) -> None: ) -> None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload from typing import TYPE_CHECKING, Any, Generic, overload
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs import (
EmbedsInputs,
EmbedsPrompt,
EncoderDecoderInputs,
ProcessorInputs,
SingletonInputs,
TextPrompt,
TokenInputs,
TokensPrompt,
)
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.counter import AtomicCounter
from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.metrics.stats import MultiModalCacheStats
...@@ -20,6 +32,8 @@ from .inputs import ( ...@@ -20,6 +32,8 @@ from .inputs import (
DictPrompt, DictPrompt,
EncoderDecoderDictPrompt, EncoderDecoderDictPrompt,
EncoderDecoderTokPrompt, EncoderDecoderTokPrompt,
SingletonDictPrompt,
SingletonTokPrompt,
TokPrompt, TokPrompt,
) )
from .inputs.preprocess import extract_target_prompt from .inputs.preprocess import extract_target_prompt
...@@ -32,6 +46,12 @@ if TYPE_CHECKING: ...@@ -32,6 +46,12 @@ if TYPE_CHECKING:
ConversationMessage, ConversationMessage,
) )
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalInputs,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -79,6 +99,10 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -79,6 +99,10 @@ class BaseRenderer(ABC, Generic[_T]):
if mm_processor_cache: if mm_processor_cache:
self._mm_cache_stats = MultiModalCacheStats() self._mm_cache_stats = MultiModalCacheStats()
# This is used to generate internal request ID for MM processing
# It has no relation to the request ID for engine core
self._mm_req_counter = AtomicCounter()
def get_tokenizer(self) -> _T: def get_tokenizer(self) -> _T:
tokenizer = self.tokenizer tokenizer = self.tokenizer
if tokenizer is None: if tokenizer is None:
...@@ -284,17 +308,79 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -284,17 +308,79 @@ class BaseRenderer(ABC, Generic[_T]):
return prompt return prompt
@overload
def _tokenize_singleton_prompt(
    self,
    prompt: TextPrompt | TokensPrompt,
    params: TokenizeParams,
) -> TokensPrompt: ...

@overload
def _tokenize_singleton_prompt(  # type: ignore[misc]
    self,
    prompt: EmbedsPrompt,
    params: TokenizeParams,
) -> EmbedsPrompt: ...

def _tokenize_singleton_prompt(
    self,
    prompt: SingletonDictPrompt,
    params: TokenizeParams,
) -> SingletonTokPrompt:
    """
    Tokenize one singleton prompt (text, token IDs, or embeddings).

    Args:
        prompt: The prompt dictionary to tokenize.
        params: Tokenization parameters; also supplies the pre- and
            post-tokenization hooks applied here.

    Returns:
        The result of `params.apply_post_tokenization` on the prompt.

    Raises:
        RuntimeError: If `params.needs_detokenization` is set but the
            prompt has neither `prompt` nor `prompt_token_ids`.
    """
    # Only prompts without token IDs or embeddings need to be tokenized.
    if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
        prompt = params.apply_pre_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
        prompt = self._tokenize_prompt(prompt, params)

    # Recover the text from the token IDs when the caller asks for it
    # and the prompt does not already carry it.
    if params.needs_detokenization and "prompt" not in prompt:
        if "prompt_token_ids" not in prompt:
            raise RuntimeError("Cannot run detokenization on embeddings")

        prompt = self._detokenize_prompt(prompt)  # type: ignore[arg-type]

    return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
@overload
async def _tokenize_singleton_prompt_async(
    self,
    prompt: TextPrompt | TokensPrompt,
    params: TokenizeParams,
) -> TokensPrompt: ...

@overload
async def _tokenize_singleton_prompt_async(  # type: ignore[misc]
    self,
    prompt: EmbedsPrompt,
    params: TokenizeParams,
) -> EmbedsPrompt: ...

async def _tokenize_singleton_prompt_async(
    self,
    prompt: SingletonDictPrompt,
    params: TokenizeParams,
) -> SingletonTokPrompt:
    """
    Asynchronously tokenize one singleton prompt.

    Follows the same steps as `_tokenize_singleton_prompt`, but awaits
    `_tokenize_prompt_async` and `_detokenize_prompt_async`.

    Args:
        prompt: The prompt dictionary to tokenize.
        params: Tokenization parameters; also supplies the pre- and
            post-tokenization hooks applied here.

    Returns:
        The result of `params.apply_post_tokenization` on the prompt.

    Raises:
        RuntimeError: If `params.needs_detokenization` is set but the
            prompt has neither `prompt` nor `prompt_token_ids`.
    """
    # Only prompts without token IDs or embeddings need to be tokenized.
    if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
        prompt = params.apply_pre_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
        prompt = await self._tokenize_prompt_async(prompt, params)

    # Recover the text from the token IDs when the caller asks for it
    # and the prompt does not already carry it.
    if params.needs_detokenization and "prompt" not in prompt:
        if "prompt_token_ids" not in prompt:
            raise RuntimeError("Cannot run detokenization on embeddings")

        prompt = await self._detokenize_prompt_async(prompt)  # type: ignore[arg-type]

    return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
def _tokenize_enc_dec_prompt( def _tokenize_enc_dec_prompt(
self, self,
prompt: EncoderDecoderDictPrompt, prompt: EncoderDecoderDictPrompt,
params: TokenizeParams, params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = ( enc_prompt, dec_prompt = (
self.tokenize_prompt(prompt["encoder_prompt"], params), self._tokenize_singleton_prompt(prompt["encoder_prompt"], params),
( (
None None
if prompt["decoder_prompt"] is None if prompt["decoder_prompt"] is None
else self.tokenize_prompt(prompt["decoder_prompt"], params) else self._tokenize_singleton_prompt(prompt["decoder_prompt"], params)
), ),
) )
...@@ -309,11 +395,13 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -309,11 +395,13 @@ class BaseRenderer(ABC, Generic[_T]):
params: TokenizeParams, params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = await asyncio.gather( enc_prompt, dec_prompt = await asyncio.gather(
self.tokenize_prompt_async(prompt["encoder_prompt"], params), self._tokenize_singleton_prompt_async(prompt["encoder_prompt"], params),
( (
asyncio.sleep(0) asyncio.sleep(0)
if prompt["decoder_prompt"] is None if prompt["decoder_prompt"] is None
else self.tokenize_prompt_async(prompt["decoder_prompt"], params) else self._tokenize_singleton_prompt_async(
prompt["decoder_prompt"], params
)
), ),
) )
...@@ -322,27 +410,6 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -322,27 +410,6 @@ class BaseRenderer(ABC, Generic[_T]):
decoder_prompt=dec_prompt, decoder_prompt=dec_prompt,
) )
@overload
def tokenize_prompt(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
@overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
def tokenize_prompt( def tokenize_prompt(
self, self,
prompt: DictPrompt, prompt: DictPrompt,
...@@ -351,17 +418,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -351,17 +418,7 @@ class BaseRenderer(ABC, Generic[_T]):
if "encoder_prompt" in prompt: if "encoder_prompt" in prompt:
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type] return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt: return self._tokenize_singleton_prompt(prompt, params)
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
prompt = self._tokenize_prompt(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
def tokenize_prompts( def tokenize_prompts(
self, self,
...@@ -370,27 +427,6 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -370,27 +427,6 @@ class BaseRenderer(ABC, Generic[_T]):
) -> list[TokPrompt]: ) -> list[TokPrompt]:
return [self.tokenize_prompt(prompt, params) for prompt in prompts] return [self.tokenize_prompt(prompt, params) for prompt in prompts]
@overload
async def tokenize_prompt_async(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
async def tokenize_prompt_async( async def tokenize_prompt_async(
self, self,
prompt: DictPrompt, prompt: DictPrompt,
...@@ -399,17 +435,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -399,17 +435,7 @@ class BaseRenderer(ABC, Generic[_T]):
if "encoder_prompt" in prompt: if "encoder_prompt" in prompt:
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type] return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt: return await self._tokenize_singleton_prompt_async(prompt, params)
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
prompt = await self._tokenize_prompt_async(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
async def tokenize_prompts_async( async def tokenize_prompts_async(
self, self,
...@@ -423,7 +449,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -423,7 +449,7 @@ class BaseRenderer(ABC, Generic[_T]):
# Step 3: Add extra keys to the prompts # Step 3: Add extra keys to the prompts
def _apply_prompt_extras( def _apply_prompt_extras(
self, self,
prompts: Sequence[DictPrompt | TokPrompt], prompts: Sequence[TokPrompt],
prompt_extras: dict[str, Any] | None, prompt_extras: dict[str, Any] | None,
): ):
if not prompt_extras: if not prompt_extras:
...@@ -433,6 +459,200 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -433,6 +459,200 @@ class BaseRenderer(ABC, Generic[_T]):
target_prompt = extract_target_prompt(self.model_config, prompt) target_prompt = extract_target_prompt(self.model_config, prompt)
target_prompt.update(prompt_extras) # type: ignore[arg-type] target_prompt.update(prompt_extras) # type: ignore[arg-type]
# Step 4: Convert to engine inputs
def _validate_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_items: "MultiModalDataItems",
mm_uuids: "MultiModalUUIDDict | None",
) -> None:
if mm_uuids is None:
mm_uuids = {}
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in `mm_items`
modalities = mm_data.keys() | mm_uuids.keys()
for modality in modalities:
data_items = mm_items.get(modality) or list[Any]()
uuid_items = mm_uuids.get(modality) or list[str | None]()
if isinstance(uuid_items, str):
uuid_items = [uuid_items]
if len(data_items) > 0:
if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have "
f"same length as multi_modal_data[{modality!r}], but "
f"got {len(uuid_items)} vs {len(data_items)}."
)
for i, item in enumerate(data_items):
if item is None:
if not uuid_items:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
if uuid_items[i] is None:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}][{i}] is missing."
)
def _process_mm_uuids(
    self,
    mm_data: "MultiModalDataDict",
    mm_items: "MultiModalDataItems",
    mm_uuids: "MultiModalUUIDDict | None",
    mm_req_id: str,
) -> "MultiModalUUIDDict | None":
    """
    Resolve and validate the UUIDs for the multimodal items of a request.

    Args:
        mm_data: The raw multimodal data, keyed by modality.
        mm_items: The parsed multimodal items.
        mm_uuids: The UUIDs supplied with the prompt, if any.
        mm_req_id: The internal request ID used to build generated UUIDs.

    Returns:
        The UUIDs to use: generated ones when both caches are disabled,
        otherwise `mm_uuids` unchanged.

    Raises:
        ValueError: If `_validate_mm_uuids` rejects the resulting UUIDs.
    """
    model_config = self.model_config

    # NOTE: When users explicitly turn off BOTH prefix caching and input
    # processing caching, no multimodal features or embeddings will be
    # reused across requests, therefore identifying multimodal data items
    # by their content is no longer necessary, and we create uuids with
    # `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
    if (
        model_config.multimodal_config
        and model_config.multimodal_config.mm_processor_cache_gb == 0
        and not self.config.cache_config.enable_prefix_caching
    ):
        mm_uuids = {
            modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
            for modality, data_count in mm_items.get_all_counts().items()
        }

    self._validate_mm_uuids(mm_data, mm_items, mm_uuids)

    return mm_uuids
# TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
def _process_multimodal(
    self,
    prompt: list[int] | str,
    mm_data: "MultiModalDataDict",
    mm_processor_kwargs: Mapping[str, object] | None,
    tokenization_kwargs: dict[str, Any] | None,
    mm_uuids: "MultiModalUUIDDict | None",
) -> "MultiModalInputs":
    """
    Apply the multi-modal processor to a prompt and its multi-modal data.

    A fresh request ID (`renderer-mm-<counter>`) is drawn from
    `self._mm_req_counter`, `mm_data` is parsed into items by the
    processor, and the UUIDs are resolved via `self._process_mm_uuids`.
    The processor is then applied inside the `set_request_id` and
    `set_default_torch_num_threads` contexts. Afterwards
    `self.update_mm_cache_stats()` is called and the processor output
    is returned.

    A `None` value for `mm_processor_kwargs` is passed to the processor
    as an empty dict.
    """
    # Imported inside the function, as in the original implementation.
    from vllm.multimodal.processing.context import set_request_id

    request_tag = f"renderer-mm-{self._mm_req_counter.inc(1)}"

    processor = self.get_mm_processor()
    parsed_items = processor.info.parse_mm_data(mm_data)
    resolved_uuids = self._process_mm_uuids(
        mm_data,
        parsed_items,
        mm_uuids,
        request_tag,
    )

    with set_request_id(request_tag), set_default_torch_num_threads():
        outputs = processor.apply(
            prompt,
            parsed_items,
            hf_processor_mm_kwargs=mm_processor_kwargs or {},
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=resolved_uuids,
        )

    self.update_mm_cache_stats()

    return outputs
def _process_tokens(
    self,
    prompt: TokensPrompt,
) -> "TokenInputs | MultiModalInputs":
    """
    Convert a tokenized prompt into engine inputs.

    When the prompt carries a truthy `multi_modal_data` entry, the token
    IDs and that data go through `self._process_multimodal`; otherwise
    plain token inputs are built with `token_inputs`. Truthy `prompt` and
    `cache_salt` entries are then copied onto the result.
    """
    token_ids = prompt["prompt_token_ids"]
    mm_data = prompt.get("multi_modal_data")

    result: TokenInputs | MultiModalInputs
    if not mm_data:
        result = token_inputs(token_ids)
    else:
        result = self._process_multimodal(
            token_ids,
            mm_data,
            mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
            # Token IDs are already available, so no tokenization options.
            tokenization_kwargs=None,
            mm_uuids=prompt.get("multi_modal_uuids"),
        )

    # Optional fields are only set when they hold a truthy value.
    prompt_text = prompt.get("prompt")
    if prompt_text:
        result["prompt"] = prompt_text

    cache_salt = prompt.get("cache_salt")
    if cache_salt:
        result["cache_salt"] = cache_salt

    return result
def _process_embeds(
    self,
    prompt: EmbedsPrompt,
) -> EmbedsInputs:
    """
    Convert a prompt-embeddings prompt into engine inputs.

    Raises:
        ValueError: If `self.model_config.enable_prompt_embeds` is falsy,
            or if the embeddings are not 2-dimensional after the leading
            dimension of a 3-dimensional input has been squeezed.
    """
    if not self.model_config.enable_prompt_embeds:
        raise ValueError(
            "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
        )

    embeds = prompt["prompt_embeds"]

    # The expected layout is (seq_len, hidden_size). A 3-D input is treated
    # as a batch of size 1, i.e. (1, seq_len, hidden_size), and has its
    # leading dimension squeezed; anything that is still not 2-D is rejected.
    if embeds.ndim == 3:
        embeds = embeds.squeeze(dim=0)
    if embeds.ndim != 2:
        raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")

    # Tensors must live on the CPU to be serialized between processes by
    # the MsgpackEncoder; moving them here keeps the device transfer out
    # of the critical path of generation.
    return embeds_inputs(
        prompt_embeds=embeds.cpu(),
        cache_salt=prompt.get("cache_salt"),
    )
def _process_singleton(
    self,
    prompt: SingletonTokPrompt,
) -> SingletonInputs:
    """
    Process a single (non encoder-decoder) prompt.

    Prompts containing a `prompt_embeds` key are handled by
    `self._process_embeds`; all others by `self._process_tokens`.
    """
    handler = (
        self._process_embeds if "prompt_embeds" in prompt else self._process_tokens
    )
    return handler(prompt)  # type: ignore[arg-type]
def _process_enc_dec(
    self,
    prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInputs:
    """
    Process an encoder-decoder prompt.

    The encoder prompt is always processed; the decoder prompt is
    processed only when it is not `None`. Both results are combined by
    `build_enc_dec_inputs` together with the ID returned by
    `self.get_dec_start_token_id()`.
    """
    # Look up both parts first so a missing key fails before any processing.
    encoder_part = prompt["encoder_prompt"]
    decoder_part = prompt["decoder_prompt"]

    encoder_inputs = self._process_singleton(encoder_part)
    if decoder_part is None:
        decoder_inputs = None
    else:
        decoder_inputs = self._process_singleton(decoder_part)

    return build_enc_dec_inputs(
        encoder_inputs=encoder_inputs,
        decoder_inputs=decoder_inputs,
        decoder_start_token_id=self.get_dec_start_token_id(),
    )
def process_for_engine(
    self, prompt: TokPrompt, arrival_time: float
) -> ProcessorInputs:
    """
    Turn a tokenized prompt into inputs for the engine.

    Prompts containing an `encoder_prompt` key are processed as
    encoder-decoder prompts, all others as singleton prompts. The given
    `arrival_time` is stored on the result under `"arrival_time"`.
    """
    is_enc_dec = "encoder_prompt" in prompt

    processed: ProcessorInputs = (
        self._process_enc_dec(prompt)  # type: ignore[arg-type]
        if is_enc_dec
        else self._process_singleton(prompt)
    )
    processed["arrival_time"] = arrival_time

    return processed
# Top-level methods # Top-level methods
def render_cmpl( def render_cmpl(
self, self,
...@@ -441,6 +661,8 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -441,6 +661,8 @@ class BaseRenderer(ABC, Generic[_T]):
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
): ):
arrival_time = time.time()
if tok_params is None: if tok_params is None:
tok_params = self.default_cmpl_tok_params tok_params = self.default_cmpl_tok_params
...@@ -449,8 +671,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -449,8 +671,7 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
return tok_prompts
async def render_cmpl_async( async def render_cmpl_async(
self, self,
...@@ -459,6 +680,8 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -459,6 +680,8 @@ class BaseRenderer(ABC, Generic[_T]):
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
): ):
arrival_time = time.time()
if tok_params is None: if tok_params is None:
tok_params = self.default_cmpl_tok_params tok_params = self.default_cmpl_tok_params
...@@ -467,8 +690,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -467,8 +690,7 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
return tok_prompts
def render_chat( def render_chat(
self, self,
...@@ -478,6 +700,8 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -478,6 +700,8 @@ class BaseRenderer(ABC, Generic[_T]):
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
): ):
arrival_time = time.time()
if tok_params is None: if tok_params is None:
tok_params = self.default_chat_tok_params tok_params = self.default_chat_tok_params
...@@ -496,8 +720,11 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -496,8 +720,11 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor eng_prompts = [
return out_conversations, tok_prompts self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
return out_conversations, eng_prompts
async def render_chat_async( async def render_chat_async(
self, self,
...@@ -507,6 +734,8 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -507,6 +734,8 @@ class BaseRenderer(ABC, Generic[_T]):
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
): ):
arrival_time = time.time()
if tok_params is None: if tok_params is None:
tok_params = self.default_chat_tok_params tok_params = self.default_chat_tok_params
...@@ -525,5 +754,8 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -525,5 +754,8 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
# TODO: Apply multi-modal processor eng_prompts = [
return out_conversations, tok_prompts self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
return out_conversations, eng_prompts
...@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload ...@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
from vllm.inputs import ( from vllm.inputs import (
EmbedsPrompt, EmbedsPrompt,
ExplicitEncoderDecoderPrompt, ExplicitEncoderDecoderPrompt,
ProcessorInputs,
PromptType, PromptType,
SingletonPrompt, SingletonPrompt,
TextPrompt, TextPrompt,
...@@ -115,7 +116,7 @@ that has been standardized into a dictionary. ...@@ -115,7 +116,7 @@ that has been standardized into a dictionary.
""" """
def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt: def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt:
""" """
Parse a prompt for a decoder-only model and normalize it to a dictionary. Parse a prompt for a decoder-only model and normalize it to a dictionary.
""" """
...@@ -144,7 +145,7 @@ def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt: ...@@ -144,7 +145,7 @@ def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
raise TypeError("Prompt should be a string, list of tokens, or dictionary") raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt: def _parse_enc_prompt(prompt: PromptType | object) -> EncoderDictPrompt:
if isinstance(prompt, str): if isinstance(prompt, str):
return TextPrompt(prompt=prompt) return TextPrompt(prompt=prompt)
...@@ -166,7 +167,7 @@ def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt: ...@@ -166,7 +167,7 @@ def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
raise TypeError("Prompt should be a string, list of tokens, or dictionary") raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt: def _parse_dec_prompt(prompt: PromptType | object) -> DecoderDictPrompt:
if isinstance(prompt, str): if isinstance(prompt, str):
return TextPrompt(prompt=prompt) return TextPrompt(prompt=prompt)
...@@ -195,13 +196,13 @@ def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt: ...@@ -195,13 +196,13 @@ def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
raise TypeError("Prompt should be a string, list of tokens, or dictionary") raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def parse_enc_dec_prompt(prompt: object) -> EncoderDecoderDictPrompt: def parse_enc_dec_prompt(prompt: PromptType | object) -> EncoderDecoderDictPrompt:
""" """
Parse a prompt for an encoder-decoder model and normalize it to a dictionary. Parse a prompt for an encoder-decoder model and normalize it to a dictionary.
""" """
if isinstance(prompt, dict) and "encoder_prompt" in prompt: if isinstance(prompt, dict) and "encoder_prompt" in prompt:
enc_prompt: object = prompt["encoder_prompt"] # type: ignore[typeddict-item] enc_prompt = prompt["encoder_prompt"] # type: ignore[typeddict-item]
dec_prompt: object | None = prompt["decoder_prompt"] # type: ignore[typeddict-item] dec_prompt = prompt["decoder_prompt"] # type: ignore[typeddict-item]
else: else:
enc_prompt = prompt enc_prompt = prompt
dec_prompt = None dec_prompt = None
...@@ -235,21 +236,23 @@ def extract_target_prompt(model_config: "ModelConfig", prompt: object): ...@@ -235,21 +236,23 @@ def extract_target_prompt(model_config: "ModelConfig", prompt: object):
def extract_prompt_components( def extract_prompt_components(
model_config: "ModelConfig", model_config: "ModelConfig",
prompt: object, prompt: PromptType | ProcessorInputs,
) -> PromptComponents: ) -> PromptComponents:
target_prompt = extract_target_prompt(model_config, prompt) target_prompt = extract_target_prompt(model_config, prompt)
return PromptComponents( return PromptComponents(
text=target_prompt.get("prompt"), text=target_prompt.get("prompt"),
token_ids=target_prompt.get("prompt_token_ids"), # type: ignore[arg-type] token_ids=target_prompt.get("prompt_token_ids"),
embeds=target_prompt.get("prompt_embeds"), embeds=target_prompt.get("prompt_embeds"),
) )
def extract_prompt_len(model_config: "ModelConfig", prompt: object): def extract_prompt_len(
model_config: "ModelConfig", prompt: PromptType | ProcessorInputs
):
target_prompt = extract_target_prompt(model_config, prompt) target_prompt = extract_target_prompt(model_config, prompt)
return length_from_prompt_token_ids_or_embeds( return length_from_prompt_token_ids_or_embeds(
target_prompt.get("prompt_token_ids"), # type: ignore[arg-type] target_prompt.get("prompt_token_ids"),
target_prompt.get("prompt_embeds"), target_prompt.get("prompt_embeds"),
) )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable, Sequence
from typing import Any, TypeVar, overload
from tqdm.auto import tqdm
_T = TypeVar("_T", bound=Iterable)
@overload
def maybe_tqdm(
    it: Sequence[_T],
    *,
    use_tqdm: bool | Callable[..., tqdm],
    **tqdm_kwargs: Any,
) -> Sequence[_T]: ...


@overload
def maybe_tqdm(
    it: Iterable[_T],
    *,
    use_tqdm: bool | Callable[..., tqdm],
    **tqdm_kwargs: Any,
) -> Iterable[_T]: ...


def maybe_tqdm(
    it: Iterable[_T],
    *,
    use_tqdm: bool | Callable[..., tqdm],
    **tqdm_kwargs: Any,
) -> Iterable[_T]:
    """
    Optionally wrap an iterable in a progress bar.

    Args:
        it: The iterable to wrap.
        use_tqdm: If falsy, `it` is returned unchanged. If callable, it is
            used as the progress-bar factory; any other truthy value
            selects `tqdm`.
        **tqdm_kwargs: Forwarded to the progress-bar factory.

    Returns:
        `it` itself, or the result of calling the factory on it.
    """
    if use_tqdm:
        progress_factory = use_tqdm if callable(use_tqdm) else tqdm
        return progress_factory(it, **tqdm_kwargs)

    return it
...@@ -20,7 +20,7 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -20,7 +20,7 @@ from vllm.distributed.weight_transfer.base import (
) )
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.inputs import PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...@@ -28,7 +28,6 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput ...@@ -28,7 +28,6 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import merge_kwargs, renderer_from_config from vllm.renderers import merge_kwargs, renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
...@@ -290,8 +289,7 @@ class AsyncLLM(EngineClient): ...@@ -290,8 +289,7 @@ class AsyncLLM(EngineClient):
request_id: str, request_id: str,
prompt: EngineCoreRequest prompt: EngineCoreRequest
| PromptType | PromptType
| DictPrompt | ProcessorInputs
| TokPrompt
| AsyncGenerator[StreamingInput, None], | AsyncGenerator[StreamingInput, None],
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
arrival_time: float | None = None, arrival_time: float | None = None,
...@@ -301,6 +299,7 @@ class AsyncLLM(EngineClient): ...@@ -301,6 +299,7 @@ class AsyncLLM(EngineClient):
priority: int = 0, priority: int = 0,
data_parallel_rank: int | None = None, data_parallel_rank: int | None = None,
prompt_text: str | None = None, prompt_text: str | None = None,
reasoning_ended: bool | None = None,
) -> RequestOutputCollector: ) -> RequestOutputCollector:
"""Add new request to the AsyncLLM.""" """Add new request to the AsyncLLM."""
...@@ -336,6 +335,9 @@ class AsyncLLM(EngineClient): ...@@ -336,6 +335,9 @@ class AsyncLLM(EngineClient):
) )
if isinstance(prompt, AsyncGenerator): if isinstance(prompt, AsyncGenerator):
if reasoning_ended is not None:
raise NotImplementedError
# Streaming input case. # Streaming input case.
return await self._add_streaming_input_request( return await self._add_streaming_input_request(
request_id, request_id,
...@@ -359,10 +361,6 @@ class AsyncLLM(EngineClient): ...@@ -359,10 +361,6 @@ class AsyncLLM(EngineClient):
"latter will be used, and the former will be ignored." "latter will be used, and the former will be ignored."
) )
else: else:
if prompt_text is not None:
raise ValueError(
"should only provide prompt_text with EngineCoreRequest"
)
request = self.input_processor.process_inputs( request = self.input_processor.process_inputs(
request_id, request_id,
prompt, prompt,
...@@ -377,6 +375,9 @@ class AsyncLLM(EngineClient): ...@@ -377,6 +375,9 @@ class AsyncLLM(EngineClient):
) )
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt) prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
if reasoning_ended is not None:
request.reasoning_ended = reasoning_ended
self.input_processor.assign_request_id(request) self.input_processor.assign_request_id(request)
# We start the output_handler on the first call to add_request() so # We start the output_handler on the first call to add_request() so
...@@ -536,8 +537,7 @@ class AsyncLLM(EngineClient): ...@@ -536,8 +537,7 @@ class AsyncLLM(EngineClient):
self, self,
prompt: EngineCoreRequest prompt: EngineCoreRequest
| PromptType | PromptType
| DictPrompt | ProcessorInputs
| TokPrompt
| AsyncGenerator[StreamingInput, None], | AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
...@@ -548,6 +548,7 @@ class AsyncLLM(EngineClient): ...@@ -548,6 +548,7 @@ class AsyncLLM(EngineClient):
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: int | None = None, data_parallel_rank: int | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
""" """
Main function called by the API server to kick off a request Main function called by the API server to kick off a request
...@@ -576,6 +577,7 @@ class AsyncLLM(EngineClient): ...@@ -576,6 +577,7 @@ class AsyncLLM(EngineClient):
priority=priority, priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
prompt_text=prompt_text, prompt_text=prompt_text,
reasoning_ended=reasoning_ended,
) )
# The output_handler task pushes items into the queue. # The output_handler task pushes items into the queue.
...@@ -770,13 +772,14 @@ class AsyncLLM(EngineClient): ...@@ -770,13 +772,14 @@ class AsyncLLM(EngineClient):
async def encode( async def encode(
self, self,
prompt: PromptType | DictPrompt | TokPrompt, prompt: PromptType | ProcessorInputs,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
reasoning_ended: bool | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
""" """
Main function called by the API server to kick off a request Main function called by the API server to kick off a request
...@@ -802,6 +805,7 @@ class AsyncLLM(EngineClient): ...@@ -802,6 +805,7 @@ class AsyncLLM(EngineClient):
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
reasoning_ended=reasoning_ended,
) )
# The output_handler task pushes items into the queue. # The output_handler task pushes items into the queue.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import time import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Literal, cast from typing import Any, Literal
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -11,7 +11,6 @@ from vllm.inputs.data import ( ...@@ -11,7 +11,6 @@ from vllm.inputs.data import (
ProcessorInputs, ProcessorInputs,
PromptType, PromptType,
SingletonInputs, SingletonInputs,
SingletonPrompt,
) )
from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
...@@ -20,22 +19,16 @@ from vllm.lora.request import LoRARequest ...@@ -20,22 +19,16 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.encoder_budget import MultiModalBudget from vllm.multimodal.encoder_budget import MultiModalBudget
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec, MultiModalFeatureSpec,
MultiModalUUIDDict,
) )
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer, renderer_from_config from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.jsontree import json_iter_leaves from vllm.utils.jsontree import json_iter_leaves
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -133,81 +126,6 @@ class InputProcessor: ...@@ -133,81 +126,6 @@ class InputProcessor:
f"but got {type(params).__name__}" f"but got {type(params).__name__}"
) )
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.renderer.get_mm_processor()
return mm_processor.info.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
if not isinstance(prompt, dict):
return
mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
if not mm_data and not mm_uuids:
return
mm_data_parsed = self._parse_mm_items(
{k: v for k, v in mm_data.items() if v is not None}
)
mm_uuids_parsed = {
k: [v] if isinstance(v, str) else v
for k, v in mm_uuids.items()
if v is not None
}
# NOTE: Include the keys corresponding to `None`
modalities = mm_data.keys() | mm_uuids.keys()
for modality in modalities:
data_items = cast(
ModalityDataItems | list[Any], mm_data_parsed.get(modality, [])
)
uuid_items = cast(list[str | None], mm_uuids_parsed.get(modality, []))
if len(data_items) > 0:
if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have "
f"same length as multi_modal_data[{modality!r}], but "
f"got {len(uuid_items)} vs {len(data_items)}."
)
for i, item in enumerate(data_items):
if item is None:
if not uuid_items:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
if uuid_items[i] is None:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}][{i}] is missing."
)
else:
if len(uuid_items) == 0:
raise ValueError(
f"multi_modal_data[{modality!r}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
def _validate_mm_uuids(self, prompt: PromptType | DictPrompt | TokPrompt) -> None:
"""
Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s).
Only checks lengths; `None` entries are allowed and will be
auto-hashed downstream.
"""
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
if (dec_prompt := prompt["decoder_prompt"]) is not None: # type: ignore[typeddict-item]
self._validate_singleton_mm_uuids(dec_prompt)
else:
self._validate_singleton_mm_uuids(prompt)
def _validate_lora(self, lora_request: LoRARequest | None) -> None: def _validate_lora(self, lora_request: LoRARequest | None) -> None:
if lora_request is None: if lora_request is None:
return return
...@@ -227,47 +145,6 @@ class InputProcessor: ...@@ -227,47 +145,6 @@ class InputProcessor:
"[lora_path]` to use the LoRA tokenizer." "[lora_path]` to use the LoRA tokenizer."
) )
def _extract_singleton_mm_data(
self, prompt: SingletonPrompt
) -> MultiModalDataDict | None:
if not isinstance(prompt, dict):
return None
return prompt.get("multi_modal_data")
def _extract_mm_data(
self, prompt: PromptType | DictPrompt | TokPrompt
) -> MultiModalDataDict | None:
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
return self._extract_singleton_mm_data(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
else:
return self._extract_singleton_mm_data(prompt)
def _maybe_build_mm_uuids(
self,
request_id: str,
prompt: PromptType | DictPrompt | TokPrompt,
) -> MultiModalUUIDDict | None:
"""Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and
index rather than their content.
Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present.
"""
mm_data = self._extract_mm_data(prompt)
if not mm_data:
return None
mm_items = self._parse_mm_items(
{k: v for k, v in mm_data.items() if v is not None}
)
return {
modality: [f"{request_id}-{modality}-{i}" for i in range(data_count)]
for modality, data_count in mm_items.get_all_counts().items()
}
def _get_mm_identifier( def _get_mm_identifier(
self, self,
mm_hash: str, mm_hash: str,
...@@ -309,7 +186,7 @@ class InputProcessor: ...@@ -309,7 +186,7 @@ class InputProcessor:
def process_inputs( def process_inputs(
self, self,
request_id: str, request_id: str,
prompt: PromptType | DictPrompt | TokPrompt, prompt: PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
arrival_time: float | None = None, arrival_time: float | None = None,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
...@@ -333,43 +210,18 @@ class InputProcessor: ...@@ -333,43 +210,18 @@ class InputProcessor:
f"is out of range [0, {num_ranks})." f"is out of range [0, {num_ranks})."
) )
if arrival_time is None: if isinstance(prompt, dict) and "type" in prompt:
arrival_time = time.time() if arrival_time is None:
arrival_time = prompt.get("arrival_time", time.time()) # type: ignore[assignment]
# Optionally generate multimodal hash overrides to avoid hashing processed_inputs: ProcessorInputs = prompt # type: ignore[assignment]
# multimodal data items by their content as their identifiers.
# NOTE: when users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore identifying multimodal data items
# by their content is no longer necessary, and we create uuids with
# request id-modality-index as multimodal hash overrides.
if (
self.model_config.multimodal_config
and self.model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.cache_config.enable_prefix_caching
):
mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
else: else:
# Otherwise, use user-provided uuids as multimodal hash overrides if arrival_time is None:
# if provided. arrival_time = time.time()
self._validate_mm_uuids(prompt)
if isinstance(prompt, dict): processed_inputs = self.input_preprocessor.preprocess(
mm_uuids = cast(
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
)
else:
mm_uuids = None
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
with set_request_id(request_id), set_default_torch_num_threads():
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
......
...@@ -14,7 +14,7 @@ from vllm.config import ParallelConfig, VllmConfig ...@@ -14,7 +14,7 @@ from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.distributed.parallel_state import get_dp_group from vllm.distributed.parallel_state import get_dp_group
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...@@ -22,7 +22,6 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput ...@@ -22,7 +22,6 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import renderer_from_config from vllm.renderers import renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
...@@ -220,7 +219,7 @@ class LLMEngine: ...@@ -220,7 +219,7 @@ class LLMEngine:
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
prompt: EngineCoreRequest | PromptType | DictPrompt | TokPrompt, prompt: EngineCoreRequest | PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
arrival_time: float | None = None, arrival_time: float | None = None,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
...@@ -228,7 +227,7 @@ class LLMEngine: ...@@ -228,7 +227,7 @@ class LLMEngine:
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
prompt_text: str | None = None, prompt_text: str | None = None,
) -> None: ) -> str:
# Validate the request_id type. # Validate the request_id type.
if not isinstance(request_id, str): if not isinstance(request_id, str):
raise TypeError(f"request_id must be a string, got {type(request_id)}") raise TypeError(f"request_id must be a string, got {type(request_id)}")
...@@ -243,7 +242,6 @@ class LLMEngine: ...@@ -243,7 +242,6 @@ class LLMEngine:
"latter will be used, and the former will be ignored." "latter will be used, and the former will be ignored."
) )
else: else:
assert prompt_text is None
request = self.input_processor.process_inputs( request = self.input_processor.process_inputs(
request_id, request_id,
prompt, prompt,
...@@ -259,6 +257,8 @@ class LLMEngine: ...@@ -259,6 +257,8 @@ class LLMEngine:
self.input_processor.assign_request_id(request) self.input_processor.assign_request_id(request)
req_id = request.request_id
# Use cloned params that may have been updated in process_inputs() # Use cloned params that may have been updated in process_inputs()
params = request.params params = request.params
...@@ -269,7 +269,7 @@ class LLMEngine: ...@@ -269,7 +269,7 @@ class LLMEngine:
self.output_processor.add_request(request, prompt_text, None, 0) self.output_processor.add_request(request, prompt_text, None, 0)
# Add the request to EngineCore. # Add the request to EngineCore.
self.engine_core.add_request(request) self.engine_core.add_request(request)
return return req_id
# Fan out child requests (for n>1). # Fan out child requests (for n>1).
parent_req = ParentRequest(request) parent_req = ParentRequest(request)
...@@ -286,6 +286,8 @@ class LLMEngine: ...@@ -286,6 +286,8 @@ class LLMEngine:
# Add the request to EngineCore. # Add the request to EngineCore.
self.engine_core.add_request(child_request) self.engine_core.add_request(child_request)
return req_id
def step(self) -> list[RequestOutput | PoolingRequestOutput]: def step(self) -> list[RequestOutput | PoolingRequestOutput]:
if self.should_execute_dummy_batch: if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False self.should_execute_dummy_batch = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment