"tests/v1/entrypoints/openai/test_completion.py" did not exist on "7a3d2a5b957063e911c25401c593dbc7798a5536"
serving_tokenization.py 5.83 KB
Newer Older
1
from typing import List, Optional, Union
2
3

from vllm.config import ModelConfig
4
from vllm.engine.protocol import EngineClient
5
6
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                         apply_mistral_chat_template,
7
                                         load_chat_template,
8
                                         parse_chat_messages_futures)
9
from vllm.entrypoints.logger import RequestLogger
10
11
# yapf conflicts with isort for this block
# yapf: disable
12
13
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              DetokenizeResponse,
14
15
                                              ErrorResponse,
                                              TokenizeChatRequest,
16
17
                                              TokenizeRequest,
                                              TokenizeResponse)
18
# yapf: enable
19
20
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                    LoRAModulePath,
21
                                                    OpenAIServing)
22
from vllm.logger import init_logger
23
from vllm.transformers_utils.tokenizer import MistralTokenizer
24
from vllm.utils import random_uuid
25

26
27
# Module-level logger; create_tokenize uses it to warn when multi-modal
# inputs in a chat request are ignored.
logger = init_logger(__name__)

28
29
30

class OpenAIServingTokenization(OpenAIServing):
    """Handles tokenize and detokenize requests.

    Tokenization accepts either a raw prompt (``TokenizeRequest``) or a list
    of chat messages (``TokenizeChatRequest``), which is first rendered
    through a chat template. Detokenization converts ``request.tokens`` back
    into prompt text via ``_tokenize_prompt_input``.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
    ):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=lora_modules,
                         prompt_adapters=None,
                         request_logger=request_logger)

        # Commonly used HF named chat templates are stored verbatim. Any
        # other value (including None) goes through load_chat_template.
        # A None template is meant to fall back to the tokenizer's own
        # default chat template.
        named_hf_templates = ('default', 'tool_use')
        if chat_template in named_hf_templates:
            self.chat_template = chat_template
        else:
            self.chat_template = load_chat_template(chat_template)

    async def create_tokenize(
        self,
        request: TokenizeRequest,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Tokenize a raw prompt or a chat conversation rendered by template.

        Returns the result of the model check when that check reports an
        error. Otherwise returns a TokenizeResponse holding the token IDs,
        their count and ``self.max_model_len``.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        request_id = f"tokn-{random_uuid()}"
        lora_request, prompt_adapter_request = self._maybe_get_adapters(
            request)
        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        prompt: Union[str, List[int]]
        if not isinstance(request, TokenizeChatRequest):
            prompt = request.prompt
        else:
            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, self.model_config, tokenizer)

            # The multi-modal payload is still resolved, but it is not used
            # for the token IDs produced below.
            if await mm_data_future:
                logger.warning(
                    "Multi-modal inputs are ignored during tokenization")

            template_kwargs = dict(
                chat_template=self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
            )
            if isinstance(tokenizer, MistralTokenizer):
                # Mistral tokenizers render from the raw request messages.
                prompt = apply_mistral_chat_template(
                    tokenizer, messages=request.messages, **template_kwargs)
            else:
                prompt = apply_hf_chat_template(
                    tokenizer, conversation=conversation, **template_kwargs)

        self._log_inputs(request_id,
                         prompt,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # A prompt adapter is deliberately ignored here: it does not affect
        # tokenization.
        token_ids = self._tokenize_prompt_input(
            request,
            tokenizer,
            prompt,
            add_special_tokens=request.add_special_tokens,
        )["prompt_token_ids"]

        return TokenizeResponse(tokens=token_ids,
                                count=len(token_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Convert ``request.tokens`` back into prompt text.

        Returns the result of the model check when that check reports an
        error. Otherwise returns a DetokenizeResponse carrying the "prompt"
        produced by ``_tokenize_prompt_input``.

        Raises:
            NotImplementedError: if the request resolves to a prompt adapter
                (checked only after the inputs have been logged).
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        request_id = f"tokn-{random_uuid()}"
        lora_request, prompt_adapter_request = self._maybe_get_adapters(
            request)
        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # Unlike create_tokenize, a prompt adapter is rejected here.
        if prompt_adapter_request is not None:
            raise NotImplementedError("Prompt adapter is not supported "
                                      "for tokenization")

        decoded = self._tokenize_prompt_input(
            request,
            tokenizer,
            request.tokens,
        )
        return DetokenizeResponse(prompt=decoded["prompt"])