Unverified Commit 739b61a3 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Refactor prompt processing (#4028)


Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent 89c1c6a1
from typing import List, Optional from typing import List, Optional, Union
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ConversationMessage, from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template, load_chat_template,
parse_chat_message_content) parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (DetokenizeRequest, from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse, DetokenizeResponse,
ErrorResponse,
TokenizeChatRequest,
TokenizeRequest, TokenizeRequest,
TokenizeResponse) TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing) OpenAIServing)
from vllm.utils import random_uuid
class OpenAIServingTokenization(OpenAIServing): class OpenAIServingTokenization(OpenAIServing):
def __init__(self, def __init__(
engine: AsyncLLMEngine, self,
model_config: ModelConfig, engine: AsyncLLMEngine,
served_model_names: List[str], model_config: ModelConfig,
lora_modules: Optional[List[LoRAModulePath]] = None, served_model_names: List[str],
chat_template: Optional[str] = None): *,
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine=engine, super().__init__(engine=engine,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules) lora_modules=lora_modules,
prompt_adapters=None,
request_logger=request_logger)
# If this is None we use the tokenizer's default chat template # If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template) self.chat_template = load_chat_template(chat_template)
async def create_tokenize(self, async def create_tokenize(
request: TokenizeRequest) -> TokenizeResponse: self,
request: TokenizeRequest,
) -> Union[TokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
if not (request.prompt or request.messages): request_id = f"tokn-{random_uuid()}"
return self.create_error_response(
"Either `prompt` or `messages` should be provided.")
if (request.prompt and request.messages): (
return self.create_error_response( lora_request,
"Only one of `prompt` or `messages` should be provided.") prompt_adapter_request,
) = self._maybe_get_adapters(request)
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.engine.get_tokenizer(lora_request)
if request.messages:
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
conversation: List[ConversationMessage] = [] conversation: List[ConversationMessage] = []
for message in request.messages: for message in request.messages:
result = parse_chat_message_content(message, self.model_config, result = parse_chat_message_content(message, model_config,
tokenizer) tokenizer)
conversation.extend(result.messages) conversation.extend(result.messages)
request.prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
conversation=conversation, conversation=conversation,
tokenize=False, tokenize=False,
chat_template=self.chat_template) chat_template=self.chat_template)
assert isinstance(prompt, str)
else:
prompt = request.prompt
self._log_inputs(request_id,
prompt,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
(input_ids, input_text) = await self._validate_prompt_and_tokenize( # Silently ignore prompt adapter since it does not affect tokenization
prompt_input = self._tokenize_prompt_input(
request, request,
tokenizer, tokenizer,
prompt=request.prompt, prompt,
add_special_tokens=request.add_special_tokens) add_special_tokens=request.add_special_tokens,
)
input_ids = prompt_input["prompt_token_ids"]
return TokenizeResponse(tokens=input_ids, return TokenizeResponse(tokens=input_ids,
count=len(input_ids), count=len(input_ids),
max_model_len=self.max_model_len) max_model_len=self.max_model_len)
async def create_detokenize( async def create_detokenize(
self, request: DetokenizeRequest) -> DetokenizeResponse: self,
request: DetokenizeRequest,
) -> Union[DetokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
_, lora_request = self._maybe_get_adapter(request) request_id = f"tokn-{random_uuid()}"
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.engine.get_tokenizer(lora_request)
(input_ids, input_text) = await self._validate_prompt_and_tokenize(
request, tokenizer, prompt_ids=request.tokens) self._log_inputs(request_id,
request.tokens,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for tokenization")
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
request.tokens,
)
input_text = prompt_input["prompt"]
return DetokenizeResponse(prompt=input_text) return DetokenizeResponse(prompt=input_text)
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
PromptStrictInputs, TextPrompt, TextTokensPrompt, TextPrompt, TokensPrompt, parse_and_batch_prompt)
TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
...@@ -14,6 +13,6 @@ See also: ...@@ -14,6 +13,6 @@ See also:
__all__ = [ __all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
"TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs", "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
"LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry" "InputContext", "InputRegistry"
] ]
...@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict): ...@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
""" """
class TextTokensPrompt(TypedDict): PromptInputs = Union[str, TextPrompt, TokensPrompt]
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt: str
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
""" """
The inputs to the LLM, which can take one of the following forms: The inputs to the LLM, which can take one of the following forms:
...@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms: ...@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
- A tokenized prompt (:class:`TokensPrompt`) - A tokenized prompt (:class:`TokensPrompt`)
""" """
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict): class LLMInputs(TypedDict):
""" """
......
...@@ -5,7 +5,8 @@ import math ...@@ -5,7 +5,8 @@ import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union)
import torch import torch
...@@ -438,7 +439,7 @@ class SequenceGroup: ...@@ -438,7 +439,7 @@ class SequenceGroup:
embeddings: Optional[List[float]] = None, embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None, pooling_params: Optional[PoolingParams] = None,
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment