serving.py 6.84 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from dataclasses import dataclass
4
from typing import Any, Final
5

6
7
from fastapi import Request

8
from vllm.engine.protocol import EngineClient
9
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
10
from vllm.entrypoints.logger import RequestLogger
11
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
12
from vllm.entrypoints.openai.engine.serving import OpenAIServing
13
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
14
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
15
from vllm.entrypoints.serve.tokenize.protocol import (
16
17
18
19
20
21
22
    DetokenizeRequest,
    DetokenizeResponse,
    TokenizeChatRequest,
    TokenizeRequest,
    TokenizeResponse,
    TokenizerInfoResponse,
)
23
from vllm.inputs import TokensPrompt, tokens_input
24
from vllm.logger import init_logger
25
from vllm.tokenizers import TokenizerLike
26

27
28
# Module-level logger, created by ``init_logger`` with this module's ``__name__``.
logger = init_logger(__name__)

29
30

class OpenAIServingTokenization(OpenAIServing):
    """Handlers for tokenize, detokenize and tokenizer-info requests.

    Prompt preprocessing is delegated to the injected ``OpenAIServingRender``
    instance; the chat-template settings given at construction time are used
    as defaults for chat-style tokenize requests.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        openai_serving_render: OpenAIServingRender,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        trust_request_chat_template: bool = False,
    ) -> None:
        """Initialise the base serving class and store chat-template defaults.

        Args:
            engine_client: Forwarded to ``OpenAIServing.__init__``.
            models: Forwarded to ``OpenAIServing.__init__``.
            openai_serving_render: Used to validate chat templates and to
                preprocess chat / completion prompts.
            request_logger: Forwarded to ``OpenAIServing.__init__``.
            chat_template: Default template passed to ``preprocess_chat`` and
                reported by ``get_tokenizer_info``.
            chat_template_content_format: Default content format passed to
                ``preprocess_chat``.
            default_chat_template_kwargs: Default template kwargs; a falsy
                value is stored as an empty dict.
            trust_request_chat_template: Passed through to
                ``validate_chat_template``.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
        )

        # Collaborator that performs template validation and prompt rendering.
        self.openai_serving_render = openai_serving_render

        # Chat-template defaults applied to chat-style tokenize requests.
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.default_chat_template_kwargs = default_chat_template_kwargs or {}
        self.trust_request_chat_template = trust_request_chat_template

    async def create_tokenize(
        self,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> TokenizeResponse | ErrorResponse:
        """Tokenize a chat or completion-style request.

        Returns the error produced by ``_check_model`` or, for chat requests,
        by ``validate_chat_template`` when either reports one. Otherwise the
        token ids of every preprocessed engine input are concatenated and
        returned together with their count and the model's ``max_model_len``.
        ``token_strs`` is filled in only when ``request.return_token_strs`` is
        truthy.
        """
        if (model_error := await self._check_model(request)) is not None:
            return model_error

        request_id = f"tokenize-{self._base_request_id(raw_request)}"
        lora_request = self._maybe_get_adapters(request)

        if not isinstance(request, TokenizeChatRequest):
            # Completion-style request: render the raw prompt.
            engine_inputs = await self.openai_serving_render.preprocess_completion(
                request,
                prompt_input=request.prompt,
                prompt_embeds=None,
                skip_mm_cache=True,
            )
        else:
            # Chat-style request: dump tools to plain dicts, validate the
            # requested template, then render the conversation.
            tool_dicts = None
            if request.tools is not None:
                tool_dicts = [tool.model_dump() for tool in request.tools]

            template_error = self.openai_serving_render.validate_chat_template(
                request_chat_template=request.chat_template,
                chat_template_kwargs=request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            if template_error is not None:
                return template_error

            _, engine_inputs = await self.openai_serving_render.preprocess_chat(
                request,
                request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=self.default_chat_template_kwargs,
                tool_dicts=tool_dicts,
                skip_mm_cache=True,
            )

        # Log each engine input and accumulate its token ids, if any.
        input_ids: list[int] = []
        for engine_input in engine_inputs:
            self._log_inputs(
                request_id,
                engine_input,
                params=None,
                lora_request=lora_request,
            )

            components = self._extract_prompt_components(engine_input)
            if components.token_ids is None:
                continue
            input_ids.extend(components.token_ids)

        if request.return_token_strs:
            token_strs = self.renderer.get_tokenizer().convert_ids_to_tokens(
                input_ids
            )
        else:
            token_strs = None

        return TokenizeResponse(
            tokens=input_ids,
            token_strs=token_strs,
            count=len(input_ids),
            max_model_len=self.model_config.max_model_len,
        )

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> DetokenizeResponse | ErrorResponse:
        """Convert the token ids in ``request.tokens`` back to prompt text.

        Returns the error produced by ``_check_model`` when it reports one;
        otherwise the ``"prompt"`` entry of the renderer's result wrapped in a
        ``DetokenizeResponse``.
        """
        if (model_error := await self._check_model(request)) is not None:
            return model_error

        # NOTE(review): the request id reuses the "tokenize-" prefix for
        # detokenize requests - confirm this is intentional.
        request_id = f"tokenize-{self._base_request_id(raw_request)}"
        adapters = self._maybe_get_adapters(request)
        logged_prompt = tokens_input(request.tokens)

        self._log_inputs(
            request_id,
            logged_prompt,
            params=None,
            lora_request=adapters,
        )

        rendered = await self.renderer.tokenize_prompt_async(
            TokensPrompt(prompt_token_ids=request.tokens),
            request.build_tok_params(self.model_config),
        )

        return DetokenizeResponse(
            prompt=rendered["prompt"],  # type: ignore[typeddict-item]
        )

    async def get_tokenizer_info(
        self,
    ) -> TokenizerInfoResponse | ErrorResponse:
        """Get comprehensive tokenizer information.

        Builds a ``TokenizerInfo`` from the renderer's tokenizer and the
        default chat template, and expands its dict into the response.
        """
        tokenizer_info = TokenizerInfo(
            tokenizer=self.renderer.get_tokenizer(),
            chat_template=self.chat_template,
        )
        return TokenizerInfoResponse(**tokenizer_info.to_dict())
159
160
161
162


@dataclass
class TokenizerInfo:
163
    tokenizer: TokenizerLike
164
    chat_template: str | None
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193

    def to_dict(self) -> dict[str, Any]:
        """Return the tokenizer configuration."""
        return self._get_tokenizer_config()

    def _get_tokenizer_config(self) -> dict[str, Any]:
        """Get tokenizer configuration directly from the tokenizer object."""
        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})

        # Remove file path fields
        config.pop("vocab_file", None)
        config.pop("merges_file", None)

        config = self._make_json_serializable(config)
        config["tokenizer_class"] = type(self.tokenizer).__name__
        if self.chat_template:
            config["chat_template"] = self.chat_template
        return config

    def _make_json_serializable(self, obj):
        """Convert any non-JSON-serializable objects to serializable format."""
        if hasattr(obj, "content"):
            return obj.content
        elif isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._make_json_serializable(item) for item in obj]
        else:
            return obj