serving_tokenization.py 7.96 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from dataclasses import dataclass
from typing import Any, Final, Optional, Union
5

6
import jinja2
7
8
from fastapi import Request

9
from vllm.config import ModelConfig
10
from vllm.engine.protocol import EngineClient
11
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
12
from vllm.entrypoints.logger import RequestLogger
13
14
# yapf conflicts with isort for this block
# yapf: disable
15
16
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              DetokenizeResponse,
17
18
                                              ErrorResponse,
                                              TokenizeChatRequest,
19
                                              TokenizeRequest,
20
21
                                              TokenizeResponse,
                                              TokenizerInfoResponse)
22
# yapf: enable
23
24
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
25
from vllm.entrypoints.renderer import RenderConfig
26
from vllm.logger import init_logger
27
from vllm.transformers_utils.tokenizer import AnyTokenizer
28

29
30
logger = init_logger(__name__)

31
32
33

class OpenAIServingTokenization(OpenAIServing):
    """Handlers for the tokenize, detokenize and tokenizer-info operations.

    Every handler either returns its response model or an ``ErrorResponse``
    built through ``create_error_response``.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Store the chat-template settings and initialise the base class.

        Args:
            engine_client: Forwarded to the base class; also used by the
                handlers to obtain the tokenizer.
            model_config: Forwarded to the base class.
            models: Forwarded to the base class.
            request_logger: Forwarded to the base class.
            chat_template: Default chat template, used when a chat
                tokenize request does not supply its own.
            chat_template_content_format: Content format passed to chat
                preprocessing.
            trust_request_chat_template: Passed to
                ``_validate_chat_template`` for chat tokenize requests.
            log_error_stack: Forwarded to the base class.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         log_error_stack=log_error_stack)

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def create_tokenize(
        self,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Tokenize a plain or chat request and return the token ids.

        Args:
            request: The tokenize request. ``TokenizeChatRequest`` instances
                go through chat preprocessing; anything else is rendered
                from ``request.prompt``.
            raw_request: The incoming HTTP request, used to derive the
                request id.

        Returns:
            A ``TokenizeResponse`` with the concatenated token ids of all
            engine prompts, or an ``ErrorResponse`` when the model check,
            chat-template validation or preprocessing fails.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        try:
            lora_request = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            if isinstance(request, TokenizeChatRequest):
                tool_dicts = (None if request.tools is None else
                              [tool.model_dump() for tool in request.tools])
                error_check_ret = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.
                    trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret
                (
                    _,
                    _,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    tool_dicts=tool_dicts,
                    # A template supplied with the request takes precedence
                    # over the server-level default.
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    chat_template_kwargs=request.chat_template_kwargs,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=request.prompt,
                    config=self._build_render_config(request),
                )
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            # Mention the chained cause only when there is one; formatting
            # it unconditionally appends a literal " None" to the message.
            cause = e.__cause__
            message = str(e) if cause is None else f"{e} {cause}"
            return self.create_error_response(message)

        input_ids: list[int] = []
        for engine_prompt in engine_prompts:
            self._log_inputs(request_id,
                             engine_prompt,
                             params=None,
                             lora_request=lora_request)

            # Prompts that do not carry token ids are logged but contribute
            # nothing to the result.
            if isinstance(engine_prompt,
                          dict) and "prompt_token_ids" in engine_prompt:
                input_ids.extend(engine_prompt["prompt_token_ids"])

        token_strs = None
        if request.return_token_strs:
            token_strs = tokenizer.convert_ids_to_tokens(input_ids)

        return TokenizeResponse(tokens=input_ids,
                                token_strs=token_strs,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Convert the token ids in ``request.tokens`` back into text.

        Exceptions raised while resolving adapters, fetching the tokenizer
        or processing the tokens are not caught here and propagate to the
        caller.

        Args:
            request: The detokenize request carrying ``tokens``.
            raw_request: The incoming HTTP request, used to derive the
                request id.

        Returns:
            A ``DetokenizeResponse`` holding the ``"prompt"`` entry of the
            processed input, or the ``ErrorResponse`` from the model check.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        lora_request = self._maybe_get_adapters(request)

        tokenizer = await self.engine_client.get_tokenizer()

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request)

        prompt_input = await self._tokenize_prompt_input_async(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)

    async def get_tokenizer_info(
        self, ) -> Union[TokenizerInfoResponse, ErrorResponse]:
        """Get comprehensive tokenizer information.

        Returns:
            A ``TokenizerInfoResponse`` built from the tokenizer
            configuration, or an ``ErrorResponse`` if anything fails while
            collecting it.
        """
        try:
            tokenizer = await self.engine_client.get_tokenizer()
            info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
            return TokenizerInfoResponse(**info)
        except Exception as e:
            # Broad catch at the handler boundary; keep the traceback in
            # the logs since only the message reaches the client.
            logger.exception("Error in getting tokenizer info")
            return self.create_error_response(
                f"Failed to get tokenizer info: {e}")

    def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
        """Build the render config from ``request.add_special_tokens``."""
        return RenderConfig(add_special_tokens=request.add_special_tokens)

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206

@dataclass
class TokenizerInfo:
    """Builds a JSON-friendly dictionary describing a tokenizer."""

    # Tokenizer whose ``init_kwargs`` attribute (if present) is reported.
    tokenizer: AnyTokenizer
    # Chat template added to the output; left out when empty or None.
    chat_template: Optional[str]

    def to_dict(self) -> dict[str, Any]:
        """Return the tokenizer configuration."""
        return self._get_tokenizer_config()

    def _get_tokenizer_config(self) -> dict[str, Any]:
        """Get tokenizer configuration directly from the tokenizer object.

        Works on a shallow copy of ``init_kwargs``, so the tokenizer's own
        mapping is not modified. A missing or empty ``init_kwargs`` yields
        a config that only holds the entries added here.
        """
        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})

        # Remove file path fields
        config.pop("vocab_file", None)
        config.pop("merges_file", None)

        config = self._make_json_serializable(config)
        config["tokenizer_class"] = type(self.tokenizer).__name__
        if self.chat_template:
            config["chat_template"] = self.chat_template
        return config

    def _make_json_serializable(self, obj: Any) -> Any:
        """Replace objects exposing ``content`` with that attribute's value.

        Dicts, lists and tuples are rebuilt recursively (keeping their
        container kind); every other value is returned unchanged.
        """
        if hasattr(obj, "content"):
            return obj.content
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self._make_json_serializable(item) for item in obj]
        if isinstance(obj, tuple):
            # Tuples may nest token objects as well; convert their items
            # instead of passing the tuple through untouched.
            return tuple(self._make_json_serializable(item) for item in obj)
        return obj