serving_tokenization.py 5.36 KB
Newer Older
1
from typing import List, Optional, Union
2
3

from vllm.config import ModelConfig
4
from vllm.engine.protocol import EngineClient
5
from vllm.entrypoints.chat_utils import load_chat_template
6
from vllm.entrypoints.logger import RequestLogger
7
8
# yapf conflicts with isort for this block
# yapf: disable
9
10
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              DetokenizeResponse,
11
12
                                              ErrorResponse,
                                              TokenizeChatRequest,
13
14
                                              TokenizeRequest,
                                              TokenizeResponse)
15
# yapf: enable
16
17
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                    LoRAModulePath,
18
                                                    OpenAIServing)
19
from vllm.logger import init_logger
20
from vllm.utils import random_uuid
21

22
23
# Module-level logger; used to record the traceback of prompt preprocessing
# failures (via ``logger.exception``) before an error response is returned.
logger = init_logger(__name__)

24
25
26

class OpenAIServingTokenization(OpenAIServing):
    """Serve the OpenAI-compatible tokenize and detokenize requests.

    Both handlers resolve the tokenizer for the request (optionally the one
    tied to a LoRA adapter selected via ``request.model``), log the inputs,
    and convert between text and token ids using only that tokenizer.
    """

    # Chat template names kept verbatim as HF named templates instead of being
    # resolved through ``load_chat_template``.
    _HF_CHAT_TEMPLATE_NAMES = ("default", "tool_use")

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
    ) -> None:
        """Initialize the tokenization handler.

        Args:
            engine_client: Client used to obtain the tokenizer.
            model_config: Model configuration, forwarded to the base class.
            base_model_paths: Served base models, forwarded to the base class.
            lora_modules: Optional LoRA adapters, forwarded to the base class.
            request_logger: Optional request logger, forwarded to the base
                class.
            chat_template: Chat template used for chat tokenize requests.
                ``"default"`` and ``"tool_use"`` are stored as-is (HF named
                templates); any other value, including ``None``, is passed
                through ``load_chat_template``. When no template is given,
                the tokenizer's default chat template is used.
        """
        # Prompt adapters are not supported for tokenization, so none are
        # registered with the base class.
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=lora_modules,
                         prompt_adapters=None,
                         request_logger=request_logger)

        if chat_template in self._HF_CHAT_TEMPLATE_NAMES:
            self.chat_template = chat_template
        else:
            self.chat_template = load_chat_template(chat_template)

    async def create_tokenize(
        self,
        request: TokenizeRequest,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Tokenize a completion prompt or a chat conversation.

        A ``TokenizeChatRequest`` has its ``messages`` preprocessed with the
        configured chat template; any other request has ``request.prompt``
        preprocessed directly. The token ids of all resulting engine prompts
        are concatenated into a single list.

        Returns:
            ``TokenizeResponse`` with the token ids, their count and the
            model's ``max_model_len``; or an ``ErrorResponse`` if the model
            check fails or preprocessing raises ``ValueError``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if isinstance(request, TokenizeChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                request_prompts, engine_prompts = self._preprocess_completion(
                    request,
                    tokenizer,
                    request.prompt,
                    add_special_tokens=request.add_special_tokens,
                )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Silently ignore prompt adapter since it does not affect
        # tokenization (Unlike in Embeddings API where an error is raised)
        input_ids: List[int] = []
        for request_prompt, engine_prompt in zip(request_prompts,
                                                 engine_prompts):
            self._log_inputs(request_id,
                             request_prompt,
                             params=None,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)
            input_ids.extend(engine_prompt["prompt_token_ids"])

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Convert ``request.tokens`` back into text.

        Adapter lookup, tokenizer retrieval and detokenization are guarded by
        the same ``ValueError`` handling as ``create_tokenize``, so such
        failures are reported as an ``ErrorResponse`` rather than escaping
        the handler.

        Returns:
            ``DetokenizeResponse`` with the decoded prompt text; or an
            ``ErrorResponse`` if the model check fails or processing raises
            ``ValueError``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            # Log before decoding so that inputs of failing requests are
            # still recorded.
            self._log_inputs(request_id,
                             request.tokens,
                             params=None,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            # Silently ignore prompt adapter since it does not affect
            # tokenization (Unlike in Embeddings API where an error is raised)
            prompt_input = self._tokenize_prompt_input(
                request,
                tokenizer,
                request.tokens,
            )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)