serving.py 7.35 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from dataclasses import dataclass
4
from typing import Any, Final
5

6
import jinja2
7
8
from fastapi import Request

9
from vllm.engine.protocol import EngineClient
10
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
11
from vllm.entrypoints.logger import RequestLogger
12
13
14
15
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing
16
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
17
18
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.serve.tokenize.protocol import (
19
20
21
22
23
24
25
    DetokenizeRequest,
    DetokenizeResponse,
    TokenizeChatRequest,
    TokenizeRequest,
    TokenizeResponse,
    TokenizerInfoResponse,
)
26
from vllm.inputs import TokensPrompt
27
from vllm.logger import init_logger
28
from vllm.tokenizers import TokenizerLike
29

30
31
logger = init_logger(__name__)

32
33

class OpenAIServingTokenization(OpenAIServing):
    """Request handlers for tokenize, detokenize and tokenizer-info calls.

    Extends ``OpenAIServing`` and additionally stores the chat-template
    settings that are applied when a ``TokenizeChatRequest`` is tokenized.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Initialise the base serving layer and keep the chat-template settings.

        Args:
            engine_client: Forwarded to ``OpenAIServing.__init__``.
            models: Forwarded to ``OpenAIServing.__init__``.
            request_logger: Forwarded to ``OpenAIServing.__init__``.
            chat_template: Fallback chat template used when a chat request
                does not carry its own template; may be ``None``.
            chat_template_content_format: Content format handed to
                ``_preprocess_chat`` for chat requests.
            trust_request_chat_template: Passed to
                ``_validate_chat_template`` when validating a chat request.
            log_error_stack: Forwarded to ``OpenAIServing.__init__``.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def create_tokenize(
        self,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> TokenizeResponse | ErrorResponse:
        """Tokenize a chat or plain-prompt request.

        Args:
            request: The tokenize request. ``TokenizeChatRequest`` instances
                go through chat preprocessing; anything else is rendered as
                a plain prompt.
            raw_request: The incoming HTTP request, used to derive the
                request id.

        Returns:
            A ``TokenizeResponse`` with the collected token ids, their count,
            ``max_model_len`` and (when ``request.return_token_strs`` is set)
            the token strings, or an ``ErrorResponse`` when the model check,
            chat-template validation or prompt preprocessing fails.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokenize-{self._base_request_id(raw_request)}"

        try:
            lora_request = self._maybe_get_adapters(request)

            if isinstance(request, TokenizeChatRequest):
                tool_dicts = (
                    None
                    if request.tools is None
                    else [tool.model_dump() for tool in request.tools]
                )
                error_check_ret = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret

                _, engine_prompts = await self._preprocess_chat(
                    request,
                    self.renderer,
                    request.messages,
                    tool_dicts=tool_dicts,
                    # A template supplied with the request takes precedence
                    # over the one configured on this instance.
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.chat_template_content_format,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    chat_template_kwargs=request.chat_template_kwargs,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                renderer = self._get_completion_renderer()
                engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=request.prompt,
                    config=self._build_render_config(request),
                )
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            # Only mention the chained cause when there is one; otherwise the
            # message would end with the literal text "None".
            message = str(e) if e.__cause__ is None else f"{e} {e.__cause__}"
            return self.create_error_response(message)

        input_ids: list[int] = []
        for engine_prompt in engine_prompts:
            self._log_inputs(
                request_id, engine_prompt, params=None, lora_request=lora_request
            )

            # Prompts that are not dicts or carry no token ids contribute
            # nothing to the response.
            if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt:
                input_ids.extend(engine_prompt["prompt_token_ids"])

        token_strs = None
        if request.return_token_strs:
            tokenizer = self.renderer.get_tokenizer()
            token_strs = tokenizer.convert_ids_to_tokens(input_ids)

        return TokenizeResponse(
            tokens=input_ids,
            token_strs=token_strs,
            count=len(input_ids),
            max_model_len=self.max_model_len,
        )

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> DetokenizeResponse | ErrorResponse:
        """Convert the token ids in ``request.tokens`` back into text.

        Args:
            request: The detokenize request carrying the token ids.
            raw_request: The incoming HTTP request, used to derive the
                request id.

        Returns:
            A ``DetokenizeResponse`` holding the resulting prompt text, or an
            ``ErrorResponse`` when the model check fails.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # NOTE(review): the id prefix is "tokenize-" here as well; confirm
        # whether sharing the prefix with create_tokenize is intentional.
        request_id = f"tokenize-{self._base_request_id(raw_request)}"

        lora_request = self._maybe_get_adapters(request)
        tokenizer = self.renderer.get_tokenizer()

        self._log_inputs(
            request_id,
            TokensPrompt(prompt_token_ids=request.tokens),
            params=None,
            lora_request=lora_request,
        )

        prompt_input = await self._tokenize_prompt_input_async(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)

    async def get_tokenizer_info(
        self,
    ) -> TokenizerInfoResponse | ErrorResponse:
        """Get comprehensive tokenizer information.

        Returns:
            A ``TokenizerInfoResponse`` built from the ``TokenizerInfo``
            dictionary, or an ``ErrorResponse`` if anything raises while
            gathering it.
        """
        try:
            tokenizer = self.renderer.get_tokenizer()
            info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
            return TokenizerInfoResponse(**info)
        except Exception as e:
            # Broad catch at the API boundary: record the traceback so the
            # failure is not silently reduced to a one-line error response.
            logger.exception("Failed to get tokenizer info")
            return self.create_error_response(f"Failed to get tokenizer info: {e}")

    def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
        """Build the ``RenderConfig`` for a plain-prompt tokenize request."""
        return RenderConfig(add_special_tokens=request.add_special_tokens)

171
172
173

@dataclass
class TokenizerInfo:
    """Collects a tokenizer's configuration as a plain dictionary.

    Attributes:
        tokenizer: The tokenizer whose ``init_kwargs`` (if any) are reported.
        chat_template: Chat template to include in the output; omitted from
            the result when ``None`` or empty.
    """

    tokenizer: TokenizerLike
    chat_template: str | None

    def to_dict(self) -> dict[str, Any]:
        """Return the tokenizer configuration."""
        return self._get_tokenizer_config()

    def _get_tokenizer_config(self) -> dict[str, Any]:
        """Get tokenizer configuration directly from the tokenizer object.

        Returns:
            A copy of the tokenizer's ``init_kwargs`` (empty when the
            attribute is missing or falsy) without the file-path fields,
            plus ``tokenizer_class`` and, when set, ``chat_template``.
        """
        # Copy so the tokenizer's own init_kwargs mapping is never mutated.
        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})

        # Remove file path fields
        config.pop("vocab_file", None)
        config.pop("merges_file", None)

        config = self._make_json_serializable(config)
        config["tokenizer_class"] = type(self.tokenizer).__name__
        if self.chat_template:
            config["chat_template"] = self.chat_template
        return config

    def _make_json_serializable(self, obj: Any) -> Any:
        """Convert any non-JSON-serializable objects to serializable format.

        Objects exposing a ``content`` attribute are replaced by that
        attribute (presumably token objects wrapping their text; verify
        against the tokenizer in use). Dicts, lists and tuples are walked
        recursively; tuples come back as lists, which is how JSON encodes
        them anyway. Everything else is returned unchanged.
        """
        if hasattr(obj, "content"):
            return obj.content
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [self._make_json_serializable(item) for item in obj]
        return obj