Unverified commit 8c025fa7, authored by Cyrus Leung and committed by GitHub
Browse files

[Frontend] Factor out chat message parsing (#7055)

parent 69ea15e5
import codecs import codecs
from dataclasses import dataclass, field from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import Awaitable, Iterable, List, Optional, Union, cast, final from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
final)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict): ...@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
@dataclass(frozen=True) @dataclass(frozen=True)
class ChatMessageParseResult: class ChatMessageParseResult:
messages: List[ConversationMessage] messages: List[ConversationMessage]
mm_futures: List[Awaitable[MultiModalDataDict]] = field( mm_futures: List[Awaitable[MultiModalDataDict]]
default_factory=list)
def load_chat_template(chat_template: Optional[str]) -> Optional[str]: def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
...@@ -174,7 +174,7 @@ def _parse_chat_message_content_parts( ...@@ -174,7 +174,7 @@ def _parse_chat_message_content_parts(
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
def parse_chat_message_content( def _parse_chat_message_content(
message: ChatCompletionMessageParam, message: ChatCompletionMessageParam,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
...@@ -190,3 +190,21 @@ def parse_chat_message_content( ...@@ -190,3 +190,21 @@ def parse_chat_message_content(
return _parse_chat_message_content_parts(role, content, model_config, return _parse_chat_message_content_parts(role, content, model_config,
tokenizer) tokenizer)
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
    """Parse a list of chat messages into one flat conversation.

    Each entry of ``messages`` is handed, in order, to
    ``_parse_chat_message_content`` together with ``model_config`` and
    ``tokenizer``. The ``messages`` and ``mm_futures`` of every per-message
    result are then concatenated, preserving the input order.

    Returns:
        A ``(conversation, mm_futures)`` tuple holding the concatenated
        conversation messages and the concatenated multi-modal awaitables.
    """
    # Parse every message first; results stay in the order of the input.
    parsed_results = [
        _parse_chat_message_content(raw_message, model_config, tokenizer)
        for raw_message in messages
    ]

    # Flatten the per-message lists into the two aggregate lists.
    flat_conversation: List[ConversationMessage] = [
        entry for parsed in parsed_results for entry in parsed.messages
    ]
    flat_futures: List[Awaitable[MultiModalDataDict]] = [
        future for parsed in parsed_results for future in parsed.mm_futures
    ]

    return flat_conversation, flat_futures
import time import time
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Union from typing import Union
...@@ -11,7 +10,7 @@ from vllm.config import ModelConfig ...@@ -11,7 +10,7 @@ from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage, from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template, load_chat_template,
parse_chat_message_content) parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProb, ChatCompletionLogProbs,
...@@ -92,15 +91,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -92,15 +91,8 @@ class OpenAIServingChat(OpenAIServing):
tokenizer = await self.async_engine_client.get_tokenizer( tokenizer = await self.async_engine_client.get_tokenizer(
lora_request) lora_request)
conversation: List[ConversationMessage] = [] conversation, mm_futures = parse_chat_messages(
mm_futures: List[Awaitable[MultiModalDataDict]] = [] request.messages, model_config, tokenizer)
for msg in request.messages:
chat_parsed_result = parse_chat_message_content(
msg, model_config, tokenizer)
conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures)
tool_dicts = None if request.tools is None else [ tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools tool.model_dump() for tool in request.tools
...@@ -115,6 +107,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -115,6 +107,7 @@ class OpenAIServingChat(OpenAIServing):
chat_template=request.chat_template or self.chat_template, chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}), **(request.chat_template_kwargs or {}),
) )
assert isinstance(prompt, str)
except Exception as e: except Exception as e:
logger.error("Error in applying chat template from request: %s", e) logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e)) return self.create_error_response(str(e))
......
from typing import List, Optional, Union from typing import List, Optional, Union
from vllm.config import ModelConfig from vllm.config import ModelConfig
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage, from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
load_chat_template,
parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (DetokenizeRequest, from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse, DetokenizeResponse,
ErrorResponse, ErrorResponse,
...@@ -17,8 +15,11 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest, ...@@ -17,8 +15,11 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing) OpenAIServing)
from vllm.logger import init_logger
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__)
class OpenAIServingTokenization(OpenAIServing): class OpenAIServingTokenization(OpenAIServing):
...@@ -62,12 +63,12 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -62,12 +63,12 @@ class OpenAIServingTokenization(OpenAIServing):
if isinstance(request, TokenizeChatRequest): if isinstance(request, TokenizeChatRequest):
model_config = self.model_config model_config = self.model_config
conversation: List[ConversationMessage] = [] conversation, mm_futures = parse_chat_messages(
request.messages, model_config, tokenizer)
for message in request.messages: if mm_futures:
result = parse_chat_message_content(message, model_config, logger.warning(
tokenizer) "Multi-modal inputs are ignored during tokenization")
conversation.extend(result.messages)
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment