serving.py 7.28 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from dataclasses import dataclass
4
from typing import Any, Final
5

6
import jinja2
7
8
from fastapi import Request

9
from vllm.engine.protocol import EngineClient
10
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
11
from vllm.entrypoints.logger import RequestLogger
12
13
14
15
16
17
18
19
20
from vllm.entrypoints.openai.protocol import (
    DetokenizeRequest,
    DetokenizeResponse,
    ErrorResponse,
    TokenizeChatRequest,
    TokenizeRequest,
    TokenizeResponse,
    TokenizerInfoResponse,
)
21
22
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
23
from vllm.entrypoints.renderer import RenderConfig
24
from vllm.logger import init_logger
25
from vllm.tokenizers import TokenizerLike
26

27
28
# Module-level logger created through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)

29
30

class OpenAIServingTokenization(OpenAIServing):
    """Handlers for the tokenize, detokenize and tokenizer-info endpoints.

    Each handler obtains the tokenizer from the engine client and returns
    either its response model or an ``ErrorResponse``.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Initialize the base serving state and store chat-template settings.

        Args:
            engine_client: Engine client, forwarded to ``OpenAIServing``.
            models: Model registry, forwarded to ``OpenAIServing``.
            request_logger: Optional request logger, forwarded to the base.
            chat_template: Server-side chat template; a request-supplied
                ``chat_template`` takes precedence in ``create_tokenize``.
            chat_template_content_format: Content format passed to chat
                preprocessing.
            trust_request_chat_template: Forwarded to
                ``_validate_chat_template`` when a request supplies its own
                chat template.
            log_error_stack: Forwarded to the base class.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def create_tokenize(
        self,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> TokenizeResponse | ErrorResponse:
        """Tokenize a plain prompt or a chat conversation.

        ``TokenizeChatRequest`` messages are rendered through the chat template
        (request template first, then the server default); any other request
        is rendered through the renderer.

        Returns:
            A ``TokenizeResponse`` holding the concatenated token ids of every
            rendered prompt (plus token strings when ``return_token_strs`` is
            set), or an ``ErrorResponse`` if model validation, chat-template
            validation or prompt preprocessing fails.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        try:
            lora_request = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            if isinstance(request, TokenizeChatRequest):
                tool_dicts = (
                    None
                    if request.tools is None
                    else [tool.model_dump() for tool in request.tools]
                )

                error_check_ret = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret

                # Only the engine-ready prompts are needed for tokenization.
                (
                    _,
                    _,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    tool_dicts=tool_dicts,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.chat_template_content_format,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    chat_template_kwargs=request.chat_template_kwargs,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=request.prompt,
                    config=self._build_render_config(request),
                )
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            # Append the chained cause only when one exists, so the message
            # does not end with a literal "None".
            cause = e.__cause__
            message = f"{e} {cause}" if cause is not None else str(e)
            return self.create_error_response(message)

        input_ids: list[int] = []
        for engine_prompt in engine_prompts:
            self._log_inputs(
                request_id, engine_prompt, params=None, lora_request=lora_request
            )

            # Only dict prompts that carry token ids contribute to the result.
            if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt:
                input_ids.extend(engine_prompt["prompt_token_ids"])

        token_strs = None
        if request.return_token_strs:
            token_strs = tokenizer.convert_ids_to_tokens(input_ids)

        return TokenizeResponse(
            tokens=input_ids,
            token_strs=token_strs,
            count=len(input_ids),
            max_model_len=self.max_model_len,
        )

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> DetokenizeResponse | ErrorResponse:
        """Convert ``request.tokens`` back into text.

        Returns:
            A ``DetokenizeResponse`` whose ``prompt`` is the decoded text, or
            the ``ErrorResponse`` produced by model validation.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        lora_request = self._maybe_get_adapters(request)

        tokenizer = await self.engine_client.get_tokenizer()

        self._log_inputs(
            request_id, request.tokens, params=None, lora_request=lora_request
        )

        prompt_input = await self._tokenize_prompt_input_async(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)

    async def get_tokenizer_info(
        self,
    ) -> TokenizerInfoResponse | ErrorResponse:
        """Get comprehensive tokenizer information.

        Returns:
            A ``TokenizerInfoResponse`` built from ``TokenizerInfo.to_dict()``,
            or an ``ErrorResponse`` describing any failure.
        """
        try:
            tokenizer = await self.engine_client.get_tokenizer()
            info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
            return TokenizerInfoResponse(**info)
        except Exception as e:
            # Broad catch at the endpoint boundary: record the traceback so the
            # failure is diagnosable server-side, then report it to the client.
            logger.exception("Failed to get tokenizer info")
            return self.create_error_response(f"Failed to get tokenizer info: {e}")

    def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
        """Build the renderer config for a non-chat tokenize request."""
        return RenderConfig(add_special_tokens=request.add_special_tokens)

170
171
172

@dataclass
class TokenizerInfo:
    """JSON-friendly view of a tokenizer's configuration."""

    # Tokenizer whose ``init_kwargs`` and class name are reported.
    tokenizer: TokenizerLike
    # Optional chat template, emitted under "chat_template" when truthy.
    chat_template: str | None

    def to_dict(self) -> dict[str, Any]:
        """Return the tokenizer configuration."""
        return self._get_tokenizer_config()

    def _get_tokenizer_config(self) -> dict[str, Any]:
        """Get tokenizer configuration directly from the tokenizer object.

        Starts from a shallow copy of ``tokenizer.init_kwargs`` (empty when the
        attribute is missing or falsy), removes file-path entries, converts
        values to JSON-friendly forms and adds ``tokenizer_class`` plus the
        chat template when one is set.
        """
        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})

        # Remove file path fields
        config.pop("vocab_file", None)
        config.pop("merges_file", None)

        config = self._make_json_serializable(config)
        config["tokenizer_class"] = type(self.tokenizer).__name__
        if self.chat_template:
            config["chat_template"] = self.chat_template
        return config

    def _make_json_serializable(self, obj: Any) -> Any:
        """Convert any non-JSON-serializable objects to serializable format.

        Objects exposing a ``content`` attribute (presumably special-token
        wrappers such as ``AddedToken`` -- confirm) are replaced by that
        content. Dicts, lists and tuples are traversed recursively; tuples
        become lists, which JSON encodes identically. Anything else is
        returned unchanged.
        """
        if hasattr(obj, "content"):
            return obj.content
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        # Tuples are traversed too so that token objects nested inside them
        # are converted rather than leaking through unserialized.
        if isinstance(obj, (list, tuple)):
            return [self._make_json_serializable(item) for item in obj]
        return obj