serving_tokenization.py 5.95 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Final, Optional, Union
5

6
import jinja2
7
8
from fastapi import Request

9
from vllm.config import ModelConfig
10
from vllm.engine.protocol import EngineClient
11
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
12
from vllm.entrypoints.logger import RequestLogger
13
14
# yapf conflicts with isort for this block
# yapf: disable
15
16
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              DetokenizeResponse,
17
18
                                              ErrorResponse,
                                              TokenizeChatRequest,
19
20
                                              TokenizeRequest,
                                              TokenizeResponse)
21
# yapf: enable
22
23
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
24
from vllm.logger import init_logger
25

26
27
# Module-level logger; used by the tokenization handler to report failures
# while preprocessing prompt inputs.
logger = init_logger(__name__)

28
29
30

class OpenAIServingTokenization(OpenAIServing):
    """Handle tokenize / detokenize requests for the OpenAI-compatible server.

    No generation happens here: both handlers only resolve a tokenizer from
    the engine client and run the shared prompt-preprocessing helpers
    inherited from ``OpenAIServing``.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
    ) -> None:
        """Initialise the base server and store the chat-template settings.

        Args:
            engine_client: Client used to fetch the tokenizer for a request
                (optionally LoRA-specific).
            model_config: Forwarded unchanged to ``OpenAIServing``.
            models: Forwarded unchanged to ``OpenAIServing``.
            request_logger: Optional request logger, forwarded to the base
                class.
            chat_template: Server-wide default chat template, used when a
                chat request does not provide its own.
            chat_template_content_format: Passed to ``_preprocess_chat`` when
                rendering chat requests.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

    async def create_tokenize(
        self,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Tokenize a completion-style prompt or a rendered chat conversation.

        ``TokenizeChatRequest`` instances are rendered with the request's
        chat template, falling back to the server default; every other
        request has ``request.prompt`` tokenized directly.

        Returns:
            A ``TokenizeResponse`` with the concatenated token ids of all
            engine prompts, their count, ``self.max_model_len`` and, when
            ``request.return_token_strs`` is set, the token strings.
            An ``ErrorResponse`` if the model check fails or preprocessing
            raises ``ValueError``, ``TypeError`` or ``jinja2.TemplateError``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{self._base_request_id(raw_request)}"

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if isinstance(request, TokenizeChatRequest):
                tool_dicts = (None if request.tools is None else
                              [tool.model_dump() for tool in request.tools])

                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    tool_dicts=tool_dicts,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    chat_template_kwargs=request.chat_template_kwargs,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                (request_prompts,
                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.prompt,
                     add_special_tokens=request.add_special_tokens,
                 )
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            # Append the chained cause only when there is one; formatting
            # ``e.__cause__`` unconditionally put a literal "None" at the end
            # of the client-facing message.
            cause = e.__cause__
            message = str(e) if cause is None else f"{e} {cause}"
            return self.create_error_response(message)

        input_ids: list[int] = []
        # Request prompts and engine prompts are treated as parallel
        # sequences (one request prompt per engine prompt).
        # NOTE(review): zip stops at the shorter sequence; confirm the
        # preprocessing helpers always return equal lengths.
        for request_prompt, engine_prompt in zip(request_prompts,
                                                 engine_prompts):
            self._log_inputs(request_id,
                             request_prompt,
                             params=None,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            # Silently ignore prompt adapter since it does not affect
            # tokenization (Unlike in Embeddings API where an error is raised)
            # Only dict prompts that carry token ids contribute to the output.
            if (isinstance(engine_prompt, dict)
                    and "prompt_token_ids" in engine_prompt):
                input_ids.extend(engine_prompt["prompt_token_ids"])

        token_strs = (tokenizer.convert_ids_to_tokens(input_ids)
                      if request.return_token_strs else None)

        return TokenizeResponse(tokens=input_ids,
                                token_strs=token_strs,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Convert ``request.tokens`` back into text.

        Unlike ``create_tokenize``, exceptions raised while processing the
        token ids are not converted into an ``ErrorResponse`` here; they
        propagate to the caller.

        Returns:
            A ``DetokenizeResponse`` whose ``prompt`` is the ``"prompt"``
            entry produced by ``_tokenize_prompt_input_async`` for the token
            ids, or the ``ErrorResponse`` returned by the model check.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # NOTE(review): shares the "tokn-" prefix with create_tokenize;
        # confirm this is intended.
        request_id = f"tokn-{self._base_request_id(raw_request)}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # Silently ignore prompt adapter since it does not affect tokenization
        # (Unlike in Embeddings API where an error is raised)

        prompt_input = await self._tokenize_prompt_input_async(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)