# chat_utils.py
1
2
3
4
5
6
7
import codecs
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final

from openai.types.chat import (ChatCompletionContentPartImageParam,
                               ChatCompletionContentPartTextParam)
8
from transformers import PreTrainedTokenizer
9

10
from vllm.config import ModelConfig
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam,
                                              ChatCompletionMessageParam)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image

logger = init_logger(__name__)


@final  # So that it should be compatible with Dict[str, str]
class ConversationMessage(TypedDict):
    """A single conversation message made of a role and string content."""
    # Role associated with the message.
    role: str
    # Text content of the message.
    content: str


@dataclass(frozen=True)
class ChatMessageParseResult:
    """Frozen container for the outcome of parsing chat message content."""
    # Conversation messages produced by the parse.
    messages: List[ConversationMessage]
    # Awaitables that resolve to multi-modal data; each instance gets its
    # own empty list when none is supplied.
    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
        default_factory=list)


33
34
35
36
37
38
39
40
41
42
43
44
45
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
    """Resolve a chat template given either as a file path or inline text.

    The argument is first treated as a path. If it cannot be opened and it
    contains none of ``{``, ``}`` or a newline character, it is assumed to
    be a mistyped path and an error is raised. Otherwise it is taken to be
    the template text itself and its backslash escapes (e.g. a literal
    backslash followed by ``n``) are interpreted.

    Args:
        chat_template: Path to a template file, inline template text, or
            ``None``.

    Returns:
        The resolved template text, or ``None`` if ``chat_template`` is
        ``None``.

    Raises:
        ValueError: If the argument looks like a file path but the file
            could not be opened.
    """
    if chat_template is None:
        return None

    try:
        with open(chat_template, "r") as f:
            resolved_chat_template = f.read()
    except OSError as e:
        # Characters that mark the string as template text, not a path.
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {e}")
            raise ValueError(msg) from e

        # If opening a file fails, treat the argument as the template
        # itself and decode it so that escape sequences are interpreted.
        # Decoding the str directly would corrupt non-ASCII characters,
        # so encode to latin-1 first, turning anything outside latin-1
        # into a backslash escape that "unicode_escape" restores.
        resolved_chat_template = codecs.decode(
            chat_template.encode("latin-1", "backslashreplace"),
            "unicode_escape")

    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
    return resolved_chat_template
53
54
55


@lru_cache(maxsize=None)
56
57
def _image_token_str(model_config: ModelConfig,
                     tokenizer: PreTrainedTokenizer) -> Optional[str]:
58
59
    # TODO: Let user specify how to insert image tokens into prompt
    # (similar to chat template)
60
    model_type = model_config.hf_config.model_type
61
62
63
64
65
66
67
    if model_type == "phi3_v":
        # Workaround since this token is not defined in the tokenizer
        return "<|image_1|>"
    if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"):
        # These models do not use image tokens in the prompt
        return None
    if model_type.startswith("llava"):
68
        return tokenizer.decode(model_config.hf_config.image_token_index)
69

70
    raise TypeError("Unknown model type: {model_type}")
71
72
73
74


# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
75
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
76
77
78
79
80
81
82
83
84
85
    """Combine image and text prompts for vision language model"""

    # NOTE: For now we assume all model architectures use the same
    # image + text prompt format. This may change in the future.
    return f"{image_token_str}\n{text_prompt}"


def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
    """Parse a list of content parts into one message plus image futures.

    Text parts are joined with newlines. At most one ``image_url`` part is
    accepted; when present, the model's image token string (if it has one)
    is prepended to the text unless the text already contains it.

    Args:
        role: Role to assign to the resulting message.
        parts: Content parts, each a mapping with a ``"type"`` key of
            ``"text"`` or ``"image_url"``.
        model_config: Passed to ``_image_token_str``.
        tokenizer: Passed to ``_image_token_str``.

    Returns:
        A result holding a single message and zero or one awaitable for
        the image data.

    Raises:
        NotImplementedError: If more than one ``image_url`` part is given
            or a part has an unknown type.
    """
    texts: List[str] = []
    # Only the URLs are collected while validating. The awaitables are
    # created at the very end so that an error raised part-way through
    # does not abandon an already-created awaitable (which, if
    # async_get_and_parse_image returns a coroutine, would be left
    # un-awaited).
    image_urls: List[str] = []

    for part in parts:
        part_type = part["type"]
        if part_type == "text":
            text = cast(ChatCompletionContentPartTextParam, part)["text"]
            texts.append(text)
        elif part_type == "image_url":
            if image_urls:
                raise NotImplementedError(
                    "Multiple 'image_url' input is currently not supported.")

            image_url = cast(ChatCompletionContentPartImageParam,
                             part)["image_url"]

            if image_url.get("detail", "auto") != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported and "
                    "will be ignored.")

            image_urls.append(image_url["url"])
        else:
            raise NotImplementedError(f"Unknown part type: {part_type}")

    text_prompt = "\n".join(texts)

    if image_urls:
        image_token_str = _image_token_str(model_config, tokenizer)
        if image_token_str is not None:
            if image_token_str in text_prompt:
                # The caller already placed the token; leave the prompt
                # untouched.
                logger.warning(
                    "Detected image token string in the text prompt. "
                    "Skipping prompt formatting.")
            else:
                text_prompt = _get_full_image_text_prompt(
                    image_token_str=image_token_str,
                    text_prompt=text_prompt,
                )

    mm_futures: List[Awaitable[MultiModalDataDict]] = [
        async_get_and_parse_image(url) for url in image_urls
    ]
    messages = [ConversationMessage(role=role, content=text_prompt)]

    return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)


def parse_chat_message_content(
    message: ChatCompletionMessageParam,
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
    """Parse one chat message into conversation messages and image futures.

    Args:
        message: Mapping with a required ``"role"`` and an optional
            ``"content"`` that is either a string or an iterable of
            content parts.
        model_config: Forwarded to the content-part parser.
        tokenizer: Forwarded to the content-part parser.

    Returns:
        An empty result when there is no content, a single-message result
        for string content, or whatever
        ``_parse_chat_message_content_parts`` returns for part lists.
    """
    role = message["role"]
    content = message.get("content")

    if content is None:
        return ChatMessageParseResult(messages=[], mm_futures=[])
    if isinstance(content, str):
        messages = [ConversationMessage(role=role, content=content)]
        return ChatMessageParseResult(messages=messages, mm_futures=[])

    # Anything else is treated as a sequence of content parts.
    return _parse_chat_message_content_parts(role, content, model_config,
                                             tokenizer)