serving.py 6.88 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from dataclasses import dataclass
4
from typing import Any, Final
5

6
import jinja2
7
8
from fastapi import Request

9
from vllm.engine.protocol import EngineClient
10
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
11
from vllm.entrypoints.logger import RequestLogger
12
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
13
from vllm.entrypoints.openai.engine.serving import OpenAIServing
14
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
15
from vllm.entrypoints.serve.tokenize.protocol import (
16
17
18
19
20
21
22
    DetokenizeRequest,
    DetokenizeResponse,
    TokenizeChatRequest,
    TokenizeRequest,
    TokenizeResponse,
    TokenizerInfoResponse,
)
23
from vllm.inputs import TokensPrompt
24
from vllm.logger import init_logger
25
from vllm.tokenizers import TokenizerLike
26

27
28
logger = init_logger(__name__)

29
30

class OpenAIServingTokenization(OpenAIServing):
    """Serving handler for tokenize, detokenize and tokenizer-info requests.

    Builds on ``OpenAIServing``; model checking, adapter lookup, prompt
    preprocessing, input logging and error-response construction are all
    delegated to helpers that are not defined in this class.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Store the chat-template settings and initialise the base class.

        Args:
            engine_client: Forwarded to the base-class constructor.
            models: Forwarded to the base-class constructor.
            request_logger: Forwarded to the base-class constructor.
            chat_template: Default chat template passed to chat
                preprocessing and reported by ``get_tokenizer_info``.
            chat_template_content_format: Default content format passed to
                chat preprocessing.
            trust_request_chat_template: Passed to the chat-template
                validation step in ``create_tokenize``.
            log_error_stack: Forwarded to the base-class constructor.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def create_tokenize(
        self,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> TokenizeResponse | ErrorResponse:
        """Tokenize a completion-style or chat-style request.

        Args:
            request: The tokenize request. ``TokenizeChatRequest`` instances
                go through chat preprocessing; anything else is treated as a
                completion-style request and its ``prompt`` is preprocessed.
            raw_request: The incoming request, used to derive the request id.

        Returns:
            A ``TokenizeResponse`` with the concatenated token ids of every
            preprocessed prompt, their count, ``max_model_len`` and, when
            ``request.return_token_strs`` is set, the token strings. An
            ``ErrorResponse`` is returned when the model check or the chat
            template validation fails, or when preprocessing raises
            ``ValueError``, ``TypeError`` or ``jinja2.TemplateError``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokenize-{self._base_request_id(raw_request)}"

        try:
            lora_request = self._maybe_get_adapters(request)

            if isinstance(request, TokenizeChatRequest):
                tool_dicts = (
                    None
                    if request.tools is None
                    else [tool.model_dump() for tool in request.tools]
                )
                error_check_ret = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret

                _, engine_prompts = await self._preprocess_chat(
                    request,
                    request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=None,
                    tool_dicts=tool_dicts,
                )
            else:
                engine_prompts = await self._preprocess_completion(
                    request,
                    prompt_input=request.prompt,
                    prompt_embeds=None,
                )
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            # ``__cause__`` is None unless the error was raised with
            # ``raise ... from ...``; append it only when present so the
            # client-facing message does not end in a literal "None".
            cause = e.__cause__
            message = str(e) if cause is None else f"{e} {cause}"
            return self.create_error_response(message)

        input_ids: list[int] = []
        for engine_prompt in engine_prompts:
            self._log_inputs(
                request_id,
                engine_prompt,
                params=None,
                lora_request=lora_request,
            )

            # Prompts without token ids contribute nothing to the response.
            if "prompt_token_ids" in engine_prompt:
                input_ids.extend(engine_prompt["prompt_token_ids"])  # type: ignore[typeddict-item]

        token_strs = None
        if request.return_token_strs:
            tokenizer = self.renderer.get_tokenizer()
            token_strs = tokenizer.convert_ids_to_tokens(input_ids)

        return TokenizeResponse(
            tokens=input_ids,
            token_strs=token_strs,
            count=len(input_ids),
            max_model_len=self.max_model_len,
        )

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> DetokenizeResponse | ErrorResponse:
        """Convert the token ids in ``request.tokens`` back into text.

        Args:
            request: The detokenize request carrying ``tokens``.
            raw_request: The incoming request, used to derive the request id.

        Returns:
            A ``DetokenizeResponse`` whose ``prompt`` is the ``"prompt"``
            entry produced by the renderer, or the ``ErrorResponse`` returned
            by the model check.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # NOTE(review): the id uses the same "tokenize-" prefix as
        # create_tokenize; confirm whether that is intentional.
        request_id = f"tokenize-{self._base_request_id(raw_request)}"

        lora_request = self._maybe_get_adapters(request)

        self._log_inputs(
            request_id,
            TokensPrompt(prompt_token_ids=request.tokens),
            params=None,
            lora_request=lora_request,
        )

        engine_prompt = await self.renderer.tokenize_prompt_async(
            TokensPrompt(prompt_token_ids=request.tokens),
            request.build_tok_params(self.model_config),
        )
        prompt_text = engine_prompt["prompt"]  # type: ignore[typeddict-item]

        return DetokenizeResponse(prompt=prompt_text)

    async def get_tokenizer_info(
        self,
    ) -> TokenizerInfoResponse | ErrorResponse:
        """Get comprehensive tokenizer information.

        Returns:
            A ``TokenizerInfoResponse`` built from ``TokenizerInfo.to_dict()``
            for the renderer's tokenizer and the configured chat template, or
            an ``ErrorResponse`` describing the failure if anything raises.
        """
        try:
            tokenizer = self.renderer.get_tokenizer()
            info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
            return TokenizerInfoResponse(**info)
        except Exception as e:
            # Endpoint boundary: report the failure to the client, but keep
            # the traceback in the logs (as create_tokenize does) instead of
            # discarding it.
            logger.exception("Error in getting tokenizer info")
            return self.create_error_response(f"Failed to get tokenizer info: {str(e)}")
161
162
163
164


@dataclass
class TokenizerInfo:
    """Build a JSON-friendly configuration dict for a tokenizer.

    The configuration is taken from the tokenizer's ``init_kwargs`` attribute
    when it exists and is non-empty, with the ``vocab_file`` and
    ``merges_file`` entries removed and the tokenizer class name (plus the
    chat template, when set) added.
    """

    # Tokenizer whose ``init_kwargs`` (if any) supply the configuration.
    tokenizer: TokenizerLike
    # Chat template added to the configuration; omitted when None or empty.
    chat_template: str | None

    def to_dict(self) -> dict[str, Any]:
        """Return the tokenizer configuration."""
        return self._get_tokenizer_config()

    def _get_tokenizer_config(self) -> dict[str, Any]:
        """Get tokenizer configuration directly from the tokenizer object.

        Returns:
            A new dict; the tokenizer's own ``init_kwargs`` is copied first
            and is not modified.
        """
        # A missing, None or empty ``init_kwargs`` all yield an empty config.
        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})

        # Remove file path fields
        config.pop("vocab_file", None)
        config.pop("merges_file", None)

        config = self._make_json_serializable(config)
        config["tokenizer_class"] = type(self.tokenizer).__name__
        if self.chat_template:
            config["chat_template"] = self.chat_template
        return config

    def _make_json_serializable(self, obj: Any) -> Any:
        """Convert any non-JSON-serializable objects to serializable format.

        Any object exposing a ``content`` attribute is replaced by that
        attribute's value. Dicts, lists and tuples are rebuilt with their
        values converted recursively (dict keys are kept as-is). Everything
        else is returned unchanged.

        Args:
            obj: The value to convert.

        Returns:
            The converted value; containers are new objects of the same
            basic kind (dict, list or tuple).
        """
        # The ``content`` check comes first so it wins over container checks.
        if hasattr(obj, "content"):
            return obj.content
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self._make_json_serializable(item) for item in obj]
        if isinstance(obj, tuple):
            # Tuples are recursed into as well, so objects nested inside
            # them are converted; the tuple type is kept so plain tuples
            # still compare equal to the input.
            return tuple(self._make_json_serializable(item) for item in obj)
        return obj