Unverified Commit 739b61a3 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Refactor prompt processing (#4028)


Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent 89c1c6a1
from typing import List, Optional
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
TokenizeChatRequest,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.utils import random_uuid
class OpenAIServingTokenization(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None):
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules)
lora_modules=lora_modules,
prompt_adapters=None,
request_logger=request_logger)
# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)
async def create_tokenize(self,
request: TokenizeRequest) -> TokenizeResponse:
async def create_tokenize(
self,
request: TokenizeRequest,
) -> Union[TokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
if not (request.prompt or request.messages):
return self.create_error_response(
"Either `prompt` or `messages` should be provided.")
request_id = f"tokn-{random_uuid()}"
if (request.prompt and request.messages):
return self.create_error_response(
"Only one of `prompt` or `messages` should be provided.")
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
if request.messages:
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
conversation: List[ConversationMessage] = []
for message in request.messages:
result = parse_chat_message_content(message, self.model_config,
result = parse_chat_message_content(message, model_config,
tokenizer)
conversation.extend(result.messages)
request.prompt = tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt,
conversation=conversation,
tokenize=False,
chat_template=self.chat_template)
assert isinstance(prompt, str)
else:
prompt = request.prompt
self._log_inputs(request_id,
prompt,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
(input_ids, input_text) = await self._validate_prompt_and_tokenize(
# Silently ignore prompt adapter since it does not affect tokenization
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
prompt=request.prompt,
add_special_tokens=request.add_special_tokens)
prompt,
add_special_tokens=request.add_special_tokens,
)
input_ids = prompt_input["prompt_token_ids"]
return TokenizeResponse(tokens=input_ids,
count=len(input_ids),
max_model_len=self.max_model_len)
async def create_detokenize(
self, request: DetokenizeRequest) -> DetokenizeResponse:
self,
request: DetokenizeRequest,
) -> Union[DetokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
_, lora_request = self._maybe_get_adapter(request)
request_id = f"tokn-{random_uuid()}"
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
(input_ids, input_text) = await self._validate_prompt_and_tokenize(
request, tokenizer, prompt_ids=request.tokens)
self._log_inputs(request_id,
request.tokens,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for tokenization")
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
request.tokens,
)
input_text = prompt_input["prompt"]
return DetokenizeResponse(prompt=input_text)
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
PromptStrictInputs, TextPrompt, TextTokensPrompt,
TokensPrompt, parse_and_batch_prompt)
TextPrompt, TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry()
......@@ -14,6 +13,6 @@ See also:
__all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
"TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs",
"LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry"
"TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
"InputContext", "InputRegistry"
]
......@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
"""
class TextTokensPrompt(TypedDict):
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt: str
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
PromptInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:
......@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
- A tokenized prompt (:class:`TokensPrompt`)
"""
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict):
"""
......
......@@ -5,7 +5,8 @@ import math
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union)
import torch
......@@ -438,7 +439,7 @@ class SequenceGroup:
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None,
encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.request_id = request_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment