serving_tokenization.py 7.52 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from dataclasses import dataclass
from typing import Any, Final, Optional, Union
5

6
import jinja2
7
8
from fastapi import Request

9
from vllm.config import ModelConfig
10
from vllm.engine.protocol import EngineClient
11
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
12
from vllm.entrypoints.logger import RequestLogger
13

14
15
# yapf conflicts with isort for this block
# yapf: disable
16
17
18
19
20
21
22
23
24
25
from vllm.entrypoints.openai.protocol import (
    DetokenizeRequest,
    DetokenizeResponse,
    ErrorResponse,
    TokenizeChatRequest,
    TokenizeRequest,
    TokenizeResponse,
    TokenizerInfoResponse,
)

26
# yapf: enable
27
28
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
29
from vllm.entrypoints.renderer import RenderConfig
30
from vllm.logger import init_logger
31
from vllm.transformers_utils.tokenizer import AnyTokenizer
32

33
34
logger = init_logger(__name__)

35
36

class OpenAIServingTokenization(OpenAIServing):
    """Request handlers for tokenize, detokenize and tokenizer-info calls.

    Builds on ``OpenAIServing`` for model checks, request logging, prompt
    preprocessing and error-response construction.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Store the chat-template settings and initialise the base class.

        Args:
            engine_client: Forwarded to ``OpenAIServing``; also used here to
                obtain the tokenizer.
            model_config: Forwarded to ``OpenAIServing``.
            models: Forwarded to ``OpenAIServing``.
            request_logger: Forwarded to ``OpenAIServing``.
            chat_template: Default chat template, used when a chat tokenize
                request does not supply its own.
            chat_template_content_format: Content format passed to chat
                preprocessing.
            trust_request_chat_template: Passed to the chat-template
                validation step for chat tokenize requests.
            log_error_stack: Forwarded to ``OpenAIServing``.
        """
        super().__init__(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def create_tokenize(
        self,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Tokenize a plain prompt or a chat conversation.

        Chat requests (``TokenizeChatRequest``) have their chat template
        validated and are run through chat preprocessing; any other request
        has ``request.prompt`` rendered by the renderer.

        Args:
            request: The tokenize request.
            raw_request: Incoming HTTP request, used to derive the request id.

        Returns:
            A ``TokenizeResponse`` carrying the token ids, their count, the
            maximum model length and, when ``request.return_token_strs`` is
            set, the token strings. An ``ErrorResponse`` is returned when the
            model check or chat-template validation fails, or when
            preprocessing raises ``ValueError``, ``TypeError`` or
            ``jinja2.TemplateError``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        try:
            lora_request = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            if isinstance(request, TokenizeChatRequest):
                tool_dicts = (
                    None
                    if request.tools is None
                    else [tool.model_dump() for tool in request.tools]
                )
                error_check_ret = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret

                # Only the engine prompts (third element) are needed here.
                (
                    _,
                    _,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    tool_dicts=tool_dicts,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.chat_template_content_format,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    chat_template_kwargs=request.chat_template_kwargs,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=request.prompt,
                    config=self._build_render_config(request),
                )
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            # Append the chained cause only when one exists; formatting it
            # unconditionally would end the message with the text "None".
            cause = e.__cause__
            message = str(e) if cause is None else f"{e} {cause}"
            return self.create_error_response(message)

        input_ids: list[int] = []
        for engine_prompt in engine_prompts:
            self._log_inputs(
                request_id, engine_prompt, params=None, lora_request=lora_request
            )

            # Prompts that are not dicts, or that lack token ids, contribute
            # nothing to the result.
            if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt:
                input_ids.extend(engine_prompt["prompt_token_ids"])

        token_strs = None
        if request.return_token_strs:
            token_strs = tokenizer.convert_ids_to_tokens(input_ids)

        return TokenizeResponse(
            tokens=input_ids,
            token_strs=token_strs,
            count=len(input_ids),
            max_model_len=self.max_model_len,
        )

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Convert the token ids in ``request.tokens`` back into text.

        Args:
            request: The detokenize request.
            raw_request: Incoming HTTP request, used to derive the request id.

        Returns:
            A ``DetokenizeResponse`` whose ``prompt`` is the ``"prompt"``
            entry produced by ``_tokenize_prompt_input_async``, or the
            ``ErrorResponse`` from the model check. Exceptions raised while
            processing the tokens are not caught here and propagate to the
            caller.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        lora_request = self._maybe_get_adapters(request)

        tokenizer = await self.engine_client.get_tokenizer()

        self._log_inputs(
            request_id, request.tokens, params=None, lora_request=lora_request
        )

        prompt_input = await self._tokenize_prompt_input_async(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)

    async def get_tokenizer_info(
        self,
    ) -> Union[TokenizerInfoResponse, ErrorResponse]:
        """Get comprehensive tokenizer information.

        Returns:
            A ``TokenizerInfoResponse`` built from ``TokenizerInfo.to_dict()``,
            or an ``ErrorResponse`` describing the failure if anything raises.
        """
        try:
            tokenizer = await self.engine_client.get_tokenizer()
            info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
            return TokenizerInfoResponse(**info)
        except Exception as e:
            # Endpoint boundary: report the failure to the client, but keep
            # the traceback in the logs, as create_tokenize does.
            logger.exception("Error in getting tokenizer info")
            return self.create_error_response(f"Failed to get tokenizer info: {e!s}")

    def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
        """Build the render config, carrying over ``add_special_tokens``."""
        return RenderConfig(add_special_tokens=request.add_special_tokens)

178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211

@dataclass
class TokenizerInfo:
    """Build a JSON-friendly description of a tokenizer's configuration."""

    # Tokenizer whose ``init_kwargs`` attribute (when present and non-empty)
    # supplies the base configuration.
    tokenizer: AnyTokenizer
    # Chat template to report; left out of the result when None or empty.
    chat_template: Optional[str]

    def to_dict(self) -> dict[str, Any]:
        """Return the tokenizer configuration."""
        return self._get_tokenizer_config()

    def _get_tokenizer_config(self) -> dict[str, Any]:
        """Get tokenizer configuration directly from the tokenizer object.

        Returns:
            A copy of the tokenizer's ``init_kwargs`` (empty when the
            attribute is missing or falsy) without the ``vocab_file`` and
            ``merges_file`` entries, with nested values converted by
            ``_make_json_serializable``, plus ``tokenizer_class`` and, when
            a chat template is set, ``chat_template``.
        """
        # Copy so the pops below never mutate the tokenizer's own kwargs.
        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})

        # Remove file path fields
        config.pop("vocab_file", None)
        config.pop("merges_file", None)

        config = self._make_json_serializable(config)
        config["tokenizer_class"] = type(self.tokenizer).__name__
        if self.chat_template:
            config["chat_template"] = self.chat_template
        return config

    def _make_json_serializable(self, obj: Any) -> Any:
        """Convert any non-JSON-serializable objects to serializable format.

        Objects exposing a ``content`` attribute are replaced by that
        attribute. Dicts, lists and tuples are rebuilt with their values
        converted recursively; tuples stay tuples. Anything else is returned
        unchanged.
        """
        if hasattr(obj, "content"):
            return obj.content
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self._make_json_serializable(item) for item in obj]
        if isinstance(obj, tuple):
            # Tuples encode as JSON arrays, but their elements still need
            # converting or the encoder fails on the nested objects.
            return tuple(self._make_json_serializable(item) for item in obj)
        return obj