Unverified Commit d78fda7c authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Renderer] Move Processor out of LLMEngine (#26165)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 73a99cc2
......@@ -37,6 +37,7 @@ from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args)
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -49,10 +50,13 @@ from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
SamplingParams)
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
get_cached_tokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, as_iter, is_list_of
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.engine.processor import Processor
from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING:
......@@ -312,6 +316,10 @@ class LLM:
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin)
@property
def model_config(self):
return self.llm_engine.model_config
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer()
......@@ -324,6 +332,16 @@ class LLM:
else:
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def _get_processor(self) -> Processor:
if not hasattr(self, "_processor"):
vllm_config = self.llm_engine.vllm_config
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_configs(self.model_config)
self._processor = Processor(vllm_config, tokenizer)
return self._processor
def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None:
self.default_sampling_params = (
......@@ -1497,8 +1515,6 @@ class LLM:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
model_config = self.llm_engine.model_config
for i, prompt in enumerate(it):
if isinstance(prompt, dict):
......@@ -1506,17 +1522,9 @@ class LLM:
prompt.get("multi_modal_data"),
prompt.get("multi_modal_uuids"))
param = params[i] if isinstance(params, Sequence) else params
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(model_config.max_model_len,
param.truncate_prompt_tokens,
tokenization_kwargs)
self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
priority=priority[i] if priority else 0,
......@@ -1557,22 +1565,58 @@ class LLM:
raise ValueError(f"Multi-modal data for {modality} is None"
f" but UUID is not provided")
def _process_inputs(
self,
request_id: str,
engine_prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest],
priority: int,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for LLMEngine."""
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.model_config.max_model_len,
params.truncate_prompt_tokens,
tokenization_kwargs)
processor = self._get_processor()
engine_request = processor.process_inputs(
request_id,
engine_prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
return engine_request, tokenization_kwargs
def _add_request(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
priority: int = 0,
) -> None:
prompt_text, _, _ = get_prompt_components(prompt)
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
engine_request, tokenization_kwargs = self._process_inputs(
request_id,
prompt,
params,
lora_request=lora_request,
priority=priority,
)
self.llm_engine.add_request(
request_id,
engine_request,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
prompt_text=prompt_text,
)
def _run_engine(
......
......@@ -7,8 +7,7 @@ import traceback
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, Optional,
TypeVar, Union)
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
import torch
from fastapi import Request
......@@ -69,6 +68,7 @@ from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
# yapf: enable
from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import PromptComponents, get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
......@@ -140,12 +140,6 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
and "prompt_embeds" in prompt)
class PromptComponents(NamedTuple):
text: Optional[str] = None
token_ids: Optional[list[int]] = None
embeds: Optional[torch.Tensor] = None
RequestT = TypeVar("RequestT", bound=AnyRequest)
......@@ -876,25 +870,23 @@ class OpenAIServing:
self,
request_id: str,
engine_prompt: PromptType,
sampling_params: SamplingParams,
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]],
priority: int,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""
using the Processor to process inputs for AsyncLLM
"""
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.max_model_len,
sampling_params.truncate_prompt_tokens,
params.truncate_prompt_tokens,
tokenization_kwargs)
processor = await self._get_processor()
engine_request = processor.process_inputs(
request_id,
engine_prompt,
sampling_params,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
......@@ -973,25 +965,12 @@ class OpenAIServing:
def _get_prompt_components(
self,
inputs: Union[RequestPrompt, PromptType],
prompt: Union[RequestPrompt, PromptType],
) -> PromptComponents:
if isinstance(inputs, str):
return PromptComponents(text=inputs)
if isinstance(inputs, list):
return PromptComponents(token_ids=inputs)
if isinstance(inputs, dict):
return PromptComponents(
text=inputs.get("prompt"), # type: ignore[arg-type]
token_ids=inputs.get(
"prompt_token_ids"), # type: ignore[arg-type]
embeds=inputs.get("prompt_embeds"),
)
if isinstance(prompt, list):
return PromptComponents(token_ids=prompt)
return PromptComponents(
text=getattr(inputs, "prompt", None),
token_ids=getattr(inputs, "prompt_token_ids", None),
embeds=getattr(inputs, "prompt_embeds", None),
)
return get_prompt_components(prompt) # type: ignore[arg-type]
def _log_inputs(
self,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Literal, Optional, TypedDict, Union, cast, overload
from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
Union, cast, overload)
from typing_extensions import TypeIs
......@@ -11,6 +12,9 @@ from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
TokensPrompt)
if TYPE_CHECKING:
import torch
class ParsedText(TypedDict):
content: str
......@@ -149,3 +153,23 @@ def split_enc_dec_inputs(
)
return None, inputs
class PromptComponents(NamedTuple):
text: Optional[str] = None
token_ids: Optional[list[int]] = None
embeds: Optional["torch.Tensor"] = None
def get_prompt_components(prompt: PromptType) -> PromptComponents:
if isinstance(prompt, str):
return PromptComponents(text=prompt)
if (encoder_prompt := prompt.get("encoder_prompt")):
return get_prompt_components(encoder_prompt) # type: ignore[arg-type]
return PromptComponents(
text=prompt.get("prompt"), # type: ignore[arg-type]
token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type]
embeds=prompt.get("prompt_embeds"),
)
......@@ -27,6 +27,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
......@@ -213,13 +214,14 @@ class LLMEngine:
def add_request(
self,
request_id: str,
prompt: PromptType,
prompt: Union[EngineCoreRequest, PromptType],
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
prompt_text: Optional[str] = None,
) -> None:
# Validate the request_id type.
if not isinstance(request_id, str):
......@@ -227,12 +229,18 @@ class LLMEngine:
f"request_id must be a string, got {type(request_id)}")
# Process raw inputs into the request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority)
prompt_text = prompt if isinstance(prompt,
str) else prompt.get("prompt")
if isinstance(prompt, EngineCoreRequest):
request = prompt
else:
assert prompt_text is None
logger.warning_once("Processor has been moved under LLM and will "
"be removed from LLMEngine in v0.13.")
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority)
prompt_text = (prompt if isinstance(prompt, str) else
prompt.get("prompt"))
n = params.n if isinstance(params, SamplingParams) else 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment