"vllm/vscode:/vscode.git/clone" did not exist on "2c9b4cf5bf844de0471f77e6579e16c7bc3ee0d0"
serving_tokenization.py 4.79 KB
Newer Older
1
from typing import List, Optional, Union
2
3
4

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
5
6
# yapf conflicts with isort for this block
# yapf: disable
7
8
9
from vllm.entrypoints.chat_utils import (ConversationMessage,
                                         load_chat_template,
                                         parse_chat_message_content)
10
from vllm.entrypoints.logger import RequestLogger
11
12
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              DetokenizeResponse,
13
14
                                              ErrorResponse,
                                              TokenizeChatRequest,
15
16
                                              TokenizeRequest,
                                              TokenizeResponse)
17
# yapf: enable
18
19
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
20
from vllm.utils import random_uuid
21
22
23
24


class OpenAIServingTokenization(OpenAIServing):
    """Serve tokenize / detokenize requests using the engine's tokenizer.

    Tokenization accepts either a plain prompt or a chat conversation; chat
    messages are first rendered into a single prompt string with the
    configured chat template (or the tokenizer's default one) and then go
    through the same tokenization path as plain prompts.
    """

    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
    ):
        """Create the tokenization handler.

        Args:
            engine: Engine that provides the tokenizer for each request.
            model_config: Model configuration; forwarded to the base class
                and also used when parsing chat messages.
            served_model_names: Served model names, forwarded to the base
                class.
            lora_modules: Optional LoRA modules, forwarded to the base class.
            request_logger: Optional request logger, forwarded to the base
                class.
            chat_template: Chat template handed to ``load_chat_template``;
                ``None`` means the tokenizer's default chat template is used.
        """
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         # This handler is never configured with prompt
                         # adapters.
                         prompt_adapters=None,
                         request_logger=request_logger)

        # If this is None we use the tokenizer's default chat template
        self.chat_template = load_chat_template(chat_template)

    async def create_tokenize(
        self,
        request: TokenizeRequest,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Tokenize a plain prompt or a chat conversation.

        Args:
            request: The tokenization request. A ``TokenizeChatRequest`` has
                its messages rendered with the chat template first; any other
                request has ``request.prompt`` tokenized directly.

        Returns:
            A ``TokenizeResponse`` carrying the token IDs, their count and
            the model's ``max_model_len``, or the ``ErrorResponse`` returned
            by the model check.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        # Fetch the tokenizer associated with the (optional) LoRA request.
        tokenizer = await self.engine.get_tokenizer(lora_request)

        if isinstance(request, TokenizeChatRequest):
            conversation: List[ConversationMessage] = []
            for message in request.messages:
                result = parse_chat_message_content(message,
                                                    self.model_config,
                                                    tokenizer)
                # NOTE(review): only ``result.messages`` is consumed; any
                # other parsed content the parser may return is dropped here
                # — confirm that is intended for tokenization.
                conversation.extend(result.messages)

            # Render to text only (tokenize=False) so that chat and plain
            # prompts share the tokenization call below.
            prompt = tokenizer.apply_chat_template(
                add_generation_prompt=request.add_generation_prompt,
                conversation=conversation,
                tokenize=False,
                chat_template=self.chat_template)
            # With tokenize=False a string is expected; the assert also
            # narrows the type for static checkers.
            assert isinstance(prompt, str)
        else:
            prompt = request.prompt

        self._log_inputs(request_id,
                         prompt,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # Silently ignore prompt adapter since it does not affect tokenization

        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            prompt,
            add_special_tokens=request.add_special_tokens,
        )
        input_ids = prompt_input["prompt_token_ids"]

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Convert ``request.tokens`` back into text.

        Args:
            request: The detokenization request holding the token IDs.

        Returns:
            A ``DetokenizeResponse`` with the resulting prompt text, or the
            ``ErrorResponse`` returned by the model check.

        Raises:
            NotImplementedError: If the request resolves to a prompt adapter.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        # Fetch the tokenizer associated with the (optional) LoRA request.
        tokenizer = await self.engine.get_tokenizer(lora_request)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # NOTE(review): create_tokenize silently ignores prompt adapters,
        # whereas this path raises (after logging) instead of returning an
        # ErrorResponse — consider aligning the two behaviours.
        if prompt_adapter_request is not None:
            raise NotImplementedError("Prompt adapter is not supported "
                                      "for tokenization")

        # Presumably, when given token IDs, ``_tokenize_prompt_input``
        # decodes them and reports the text under "prompt" — confirm
        # against the base class.
        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)