Unverified commit 8c025fa7, authored by Cyrus Leung and committed by GitHub
Browse files

[Frontend] Factor out chat message parsing (#7055)

parent 69ea15e5
import codecs
from dataclasses import dataclass, field
from dataclasses import dataclass
from functools import lru_cache
from typing import Awaitable, Iterable, List, Optional, Union, cast, final
from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
final)
# yapf conflicts with isort for this block
# yapf: disable
......@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
@dataclass(frozen=True)
class ChatMessageParseResult:
    """Immutable result of parsing a single chat message.

    Instances are frozen, so attributes cannot be rebound after
    construction.
    """

    # Conversation entries produced from the parsed message.
    messages: List[ConversationMessage]
    # Awaitables that resolve to multi-modal data for the message.
    # Declared exactly once; each instance gets its own empty list when the
    # caller does not supply one.
    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
        default_factory=list)
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
......@@ -174,7 +174,7 @@ def _parse_chat_message_content_parts(
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
def parse_chat_message_content(
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
......@@ -190,3 +190,21 @@ def parse_chat_message_content(
return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
    """Parse a list of chat messages into one flat conversation.

    Each message is passed to ``_parse_chat_message_content`` together with
    ``model_config`` and ``tokenizer``; the per-message results are
    concatenated in input order.

    Returns:
        A ``(conversation, mm_futures)`` tuple: the accumulated conversation
        entries and the accumulated multi-modal awaitables.
    """
    all_entries: List[ConversationMessage] = []
    all_mm_futures: List[Awaitable[MultiModalDataDict]] = []

    for message in messages:
        parsed = _parse_chat_message_content(message, model_config,
                                             tokenizer)
        all_entries += parsed.messages
        all_mm_futures += parsed.mm_futures

    return all_entries, all_mm_futures
import time
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
Optional)
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Union
......@@ -11,7 +10,7 @@ from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs,
......@@ -92,15 +91,8 @@ class OpenAIServingChat(OpenAIServing):
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
for msg in request.messages:
chat_parsed_result = parse_chat_message_content(
msg, model_config, tokenizer)
conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures)
conversation, mm_futures = parse_chat_messages(
request.messages, model_config, tokenizer)
tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools
......@@ -115,6 +107,7 @@ class OpenAIServingChat(OpenAIServing):
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
assert isinstance(prompt, str)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
......
from typing import List, Optional, Union
from vllm.config import ModelConfig
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
......@@ -17,8 +15,11 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__)
class OpenAIServingTokenization(OpenAIServing):
......@@ -62,12 +63,12 @@ class OpenAIServingTokenization(OpenAIServing):
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
conversation: List[ConversationMessage] = []
conversation, mm_futures = parse_chat_messages(
request.messages, model_config, tokenizer)
for message in request.messages:
result = parse_chat_message_content(message, model_config,
tokenizer)
conversation.extend(result.messages)
if mm_futures:
logger.warning(
"Multi-modal inputs are ignored during tokenization")
prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment