serving_tokenization.py 7.46 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from dataclasses import dataclass
from typing import Any, Final, Optional, Union
5

6
import jinja2
7
8
from fastapi import Request

9
from vllm.config import ModelConfig
10
from vllm.engine.protocol import EngineClient
11
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
12
from vllm.entrypoints.logger import RequestLogger
13
14
# yapf conflicts with isort for this block
# yapf: disable
15
16
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              DetokenizeResponse,
17
18
                                              ErrorResponse,
                                              TokenizeChatRequest,
19
                                              TokenizeRequest,
20
21
                                              TokenizeResponse,
                                              TokenizerInfoResponse)
22
# yapf: enable
23
24
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
25
from vllm.entrypoints.renderer import RenderConfig
26
from vllm.logger import init_logger
27
from vllm.transformers_utils.tokenizer import AnyTokenizer
28

29
30
logger = init_logger(__name__)

31
32
33

class OpenAIServingTokenization(OpenAIServing):
    """Request handlers for tokenize, detokenize and tokenizer-info calls.

    Builds on ``OpenAIServing`` for model checking, adapter lookup, prompt
    preprocessing, input logging and error-response construction.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        log_error_stack: bool = False,
    ) -> None:
        """Initialise the base serving layer and store chat-template settings.

        Args:
            engine_client: Forwarded to ``OpenAIServing``; also used here to
                obtain tokenizers.
            model_config: Forwarded to ``OpenAIServing``.
            models: Forwarded to ``OpenAIServing``.
            request_logger: Forwarded to ``OpenAIServing``.
            chat_template: Default chat template, used when a chat request
                does not carry its own.
            chat_template_content_format: Content format passed to chat
                preprocessing.
            log_error_stack: Forwarded to ``OpenAIServing``.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         log_error_stack=log_error_stack)

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

    async def create_tokenize(
        self,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Tokenize a plain or chat prompt and return its token ids.

        Args:
            request: The tokenize request; chat requests are detected with
                ``isinstance(request, TokenizeChatRequest)``.
            raw_request: The HTTP request, used to derive the request id.

        Returns:
            A ``TokenizeResponse`` with the collected token ids, or an
            ``ErrorResponse`` when the model check fails or preprocessing
            raises ``ValueError``, ``TypeError`` or ``jinja2.TemplateError``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        try:
            lora_request = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)
            renderer = self._get_renderer(tokenizer)

            if isinstance(request, TokenizeChatRequest):
                tool_dicts = (None if request.tools is None else
                              [tool.model_dump() for tool in request.tools])
                # Only the third element (the engine prompts) is needed here.
                (
                    _,
                    _,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    tool_dicts=tool_dicts,
                    # A per-request template takes precedence over the
                    # server-wide default.
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    chat_template_kwargs=request.chat_template_kwargs,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=request.prompt,
                    config=self._build_render_config(request),
                )
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            # NOTE: ``e.__cause__`` is None for exceptions that were not
            # raised with ``from``, in which case the message ends in "None".
            return self.create_error_response(f"{e} {e.__cause__}")

        # Concatenate the token ids of every rendered prompt, in order.
        input_ids: list[int] = []
        for engine_prompt in engine_prompts:
            self._log_inputs(request_id,
                             engine_prompt,
                             params=None,
                             lora_request=lora_request)

            # Prompts without a "prompt_token_ids" entry contribute nothing.
            if isinstance(engine_prompt,
                          dict) and "prompt_token_ids" in engine_prompt:
                input_ids.extend(engine_prompt["prompt_token_ids"])

        token_strs = None
        if request.return_token_strs:
            token_strs = tokenizer.convert_ids_to_tokens(input_ids)

        return TokenizeResponse(tokens=input_ids,
                                token_strs=token_strs,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Convert the token ids in ``request.tokens`` back into text.

        Args:
            request: The detokenize request carrying ``tokens``.
            raw_request: The HTTP request, used to derive the request id.

        Returns:
            A ``DetokenizeResponse`` whose ``prompt`` is the ``"prompt"``
            entry produced by ``_tokenize_prompt_input_async``, or the
            ``ErrorResponse`` from the model check.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        lora_request = self._maybe_get_adapters(request)

        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request)

        prompt_input = await self._tokenize_prompt_input_async(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)

    async def get_tokenizer_info(
        self, ) -> Union[TokenizerInfoResponse, ErrorResponse]:
        """Get comprehensive tokenizer information.

        Returns:
            A ``TokenizerInfoResponse`` built from ``TokenizerInfo.to_dict()``
            for the engine's default tokenizer, or an ``ErrorResponse`` if
            anything fails while collecting it.
        """
        try:
            tokenizer = await self.engine_client.get_tokenizer()
            info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
            return TokenizerInfoResponse(**info)
        except Exception as e:
            # Broad on purpose: this is the endpoint boundary. Log the
            # traceback so the failure is diagnosable server-side, matching
            # the handling in create_tokenize.
            logger.exception("Error in getting tokenizer info")
            return self.create_error_response(
                f"Failed to get tokenizer info: {e!s}")

    def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
        """Build the render config for a plain (non-chat) tokenize request."""
        return RenderConfig(add_special_tokens=request.add_special_tokens)

163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196

@dataclass
class TokenizerInfo:
    """Produce a JSON-friendly description of a tokenizer's configuration."""

    # Tokenizer to describe; its ``init_kwargs`` attribute, when present and
    # non-empty, seeds the configuration.
    tokenizer: AnyTokenizer
    # Chat template to report; left out of the result when None or empty.
    chat_template: Optional[str]

    def to_dict(self) -> dict[str, Any]:
        """Return the tokenizer configuration."""
        return self._get_tokenizer_config()

    def _get_tokenizer_config(self) -> dict[str, Any]:
        """Get tokenizer configuration directly from the tokenizer object.

        Returns:
            A new dict: a copy of the tokenizer's ``init_kwargs`` without the
            ``vocab_file`` / ``merges_file`` entries, with values passed
            through ``_make_json_serializable``, plus ``tokenizer_class`` and
            (when set) ``chat_template``.
        """
        # Copy so that popping keys never mutates the tokenizer's own dict.
        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})

        # Remove file path fields
        config.pop("vocab_file", None)
        config.pop("merges_file", None)

        config = self._make_json_serializable(config)
        config["tokenizer_class"] = type(self.tokenizer).__name__
        if self.chat_template:
            config["chat_template"] = self.chat_template
        return config

    def _make_json_serializable(self, obj: Any) -> Any:
        """Convert any non-JSON-serializable objects to serializable format.

        Objects exposing a ``content`` attribute are replaced by that
        attribute; dicts, lists and tuples are converted recursively (tuples
        become lists, which is also how JSON encodes them); everything else
        is returned unchanged.
        """
        if hasattr(obj, "content"):
            return obj.content
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        # Tuples must be walked too, otherwise token objects nested inside
        # them would reach the JSON encoder unconverted.
        if isinstance(obj, (list, tuple)):
            return [self._make_json_serializable(item) for item in obj]
        return obj