# serving_engine.py
import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import Field
from typing_extensions import Annotated

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingRequest, ErrorResponse,
                                              ModelCard, ModelList,
                                              ModelPermission, TokenizeRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import get_tokenizer

# Module-level logger for this file.
logger = init_logger(__name__)


@dataclass
class PromptAdapterPath:
    """A prompt adapter to expose, given by client-facing name and directory.

    The server reads ``adapter_config.json`` from ``local_path`` when it
    builds its prompt-adapter requests.
    """

    # Model name that clients put in ``request.model`` to select the adapter.
    name: str
    # Directory that holds the adapter's files.
    local_path: str


@dataclass
class LoRAModulePath:
    """A LoRA adapter to expose, given by client-facing name and location.

    ``name`` becomes the adapter's ``lora_name`` (the model name clients
    request) and ``local_path`` is passed through as ``lora_local_path``.
    """

    # Model name that clients put in ``request.model`` to select the adapter.
    name: str
    # Location of the adapter weights.
    local_path: str


class OpenAIServing:
    """Common base for the OpenAI-compatible request handlers.

    Owns the engine handle, the base model config, a separate tokenizer and
    the LoRA / prompt-adapter requests that clients can address by model
    name. Provides model listing, model-name checks, adapter lookup, prompt
    tokenization with context-length validation and error-response builders.
    """

    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
    ):
        """Build the tokenizer and the adapter requests for this server.

        Args:
            engine: Engine used to serve requests.
            model_config: Base model config; provides ``max_model_len`` and
                the tokenizer settings.
            served_model_names: Names the base model is exposed under. The
                first one is reported as the ``root`` of every model card.
            lora_modules: LoRA adapters to expose, or ``None``.
            prompt_adapters: Prompt adapters to expose, or ``None``. Each
                ``local_path`` must contain an ``adapter_config.json`` with a
                ``num_virtual_tokens`` entry.

        Raises:
            OSError: If a prompt adapter's ``adapter_config.json`` cannot be
                opened.
            ValueError: If that file is not valid UTF-8 JSON.
            KeyError: If it has no ``num_virtual_tokens`` entry.
        """
        super().__init__()

        self.engine = engine
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        # A separate tokenizer to map token IDs to strings. Left-side
        # truncation keeps the tail of an over-long text prompt, matching the
        # ``prompt_ids[-n:]`` slicing applied to token-ID prompts below.
        self.tokenizer = get_tokenizer(
            model_config.tokenizer,
            tokenizer_mode=model_config.tokenizer_mode,
            tokenizer_revision=model_config.tokenizer_revision,
            trust_remote_code=model_config.trust_remote_code,
            truncation_side="left")

        self.served_model_names = served_model_names

        # Adapter IDs are assigned sequentially, starting at 1.
        self.lora_requests: List[LoRARequest] = [
            LoRARequest(
                lora_name=lora.name,
                lora_int_id=i,
                lora_local_path=lora.local_path,
            ) for i, lora in enumerate(lora_modules or [], start=1)
        ]

        self.prompt_adapter_requests: List[PromptAdapterRequest] = []
        for i, prompt_adapter in enumerate(prompt_adapters or [], start=1):
            config_path = pathlib.Path(prompt_adapter.local_path,
                                       "adapter_config.json")
            # JSON is UTF-8 by specification; do not depend on the locale's
            # default text encoding when reading it.
            with config_path.open(encoding="utf-8") as f:
                adapter_config = json.load(f)
            num_virtual_tokens = adapter_config["num_virtual_tokens"]
            self.prompt_adapter_requests.append(
                PromptAdapterRequest(
                    prompt_adapter_name=prompt_adapter.name,
                    prompt_adapter_id=i,
                    prompt_adapter_local_path=prompt_adapter.local_path,
                    prompt_adapter_num_virtual_tokens=num_virtual_tokens))

    async def show_available_models(self) -> ModelList:
        """List every model name this server accepts.

        Returns one card per served base-model name, then one per LoRA
        adapter, then one per prompt adapter. Every card reports the first
        served model name as its ``root``; only base-model cards carry
        ``max_model_len``.
        """
        model_cards = [
            ModelCard(id=served_model_name,
                      max_model_len=self.max_model_len,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for served_model_name in self.served_model_names
        ]
        model_cards.extend(
            ModelCard(id=lora.lora_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for lora in self.lora_requests)
        model_cards.extend(
            ModelCard(id=prompt_adapter.prompt_adapter_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for prompt_adapter in self.prompt_adapter_requests)
        return ModelList(data=model_cards)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an ``ErrorResponse`` whose ``code`` is the numeric status."""
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Serialize an error as a JSON string of the form
        ``{"error": <ErrorResponse fields>}`` for use in a stream."""
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

    async def _check_model(
        self, request: Union[ChatCompletionRequest, CompletionRequest,
                             DetokenizeRequest, EmbeddingRequest,
                             TokenizeRequest]
    ) -> Optional[ErrorResponse]:
        """Return ``None`` if ``request.model`` names a served model, LoRA
        adapter or prompt adapter; otherwise a 404 ``NotFoundError``."""
        if request.model in self.served_model_names:
            return None
        if any(request.model == lora.lora_name
               for lora in self.lora_requests):
            return None
        if any(request.model == prompt_adapter.prompt_adapter_name
               for prompt_adapter in self.prompt_adapter_requests):
            return None
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapter(
        self, request: Union[CompletionRequest, ChatCompletionRequest,
                             EmbeddingRequest]
    ) -> Tuple[Optional[str], Optional[Union[LoRARequest,
                                             PromptAdapterRequest]]]:
        """Resolve the adapter addressed by ``request.model``.

        Returns:
            ``(None, None)`` for a base-model name, ``('LoRA', lora_request)``
            for a LoRA adapter, or ``('PromptAdapter', adapter_request)`` for
            a prompt adapter.

        Raises:
            ValueError: If the name is unknown. This cannot happen when
                ``_check_model`` has already accepted the request.
        """
        if request.model in self.served_model_names:
            return None, None
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return 'LoRA', lora
        for prompt_adapter in self.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return 'PromptAdapter', prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _validate_prompt_and_tokenize(
            self,
            request: Union[ChatCompletionRequest, CompletionRequest,
                           DetokenizeRequest, EmbeddingRequest,
                           TokenizeRequest],
            prompt: Optional[str] = None,
            prompt_ids: Optional[List[int]] = None,
            truncate_prompt_tokens: Optional[Annotated[int,
                                                       Field(ge=1)]] = None,
            add_special_tokens: Optional[bool] = True
    ) -> Tuple[List[int], str]:
        """Turn a text or token-ID prompt into token IDs and validate length.

        Exactly one of ``prompt`` and ``prompt_ids`` must be non-empty.

        Args:
            request: The request being served; its type selects which
                context-length check applies.
            prompt: Prompt text to tokenize.
            prompt_ids: Already-tokenized prompt.
            truncate_prompt_tokens: If set, keep only the last this-many
                tokens of the prompt.
            add_special_tokens: Passed to the tokenizer for text prompts.
                Callers whose text already contains special tokens (e.g. a
                prompt rendered by a chat template that inserts BOS) should
                pass ``False`` so they are not added twice.

        Returns:
            ``(input_ids, input_text)``. For a text prompt ``input_text`` is
            the prompt exactly as given; for a token-ID prompt it is the
            decoding of the (possibly truncated) IDs that are returned.

        Raises:
            ValueError: If neither or both prompts are given, or if the
                prompt (plus ``max_tokens`` for generation requests) does
                not fit in ``max_model_len``.

        Note:
            For chat/completion requests with ``max_tokens=None`` this sets
            ``request.max_tokens`` to the remaining context budget.
        """
        if not (prompt or prompt_ids):
            raise ValueError("Either prompt or prompt_ids should be provided.")
        if (prompt and prompt_ids):
            raise ValueError(
                "Only one of prompt or prompt_ids should be provided.")

        # After the checks above exactly one input is non-empty. Select on
        # that (not on ``is None``) so that e.g. prompt="abc" with
        # prompt_ids=[] tokenizes the text instead of using the empty list.
        if prompt:
            tokenizer_kwargs: Dict[str, Any] = {
                "add_special_tokens": add_special_tokens
            }
            if truncate_prompt_tokens is not None:
                tokenizer_kwargs.update({
                    "truncation": True,
                    "max_length": truncate_prompt_tokens,
                })
            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
            input_text = prompt
        else:
            if truncate_prompt_tokens is not None:
                input_ids = prompt_ids[-truncate_prompt_tokens:]
            else:
                input_ids = prompt_ids
            # Decode the IDs actually kept so text and IDs stay consistent
            # after truncation.
            input_text = self.tokenizer.decode(input_ids)

        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, EmbeddingRequest):
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for embedding "
                    f"generation. Please reduce the length of the input.", )
            return input_ids, input_text

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
            return input_ids, input_text

        if request.max_tokens is None:
            # At least one token must remain free for generation.
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.", )
            request.max_tokens = self.max_model_len - token_num

        if token_num + request.max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.", )
        return input_ids, input_text

    def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
        """Return the text for ``token_id``, reusing
        ``logprob.decoded_token`` when it is already set."""
        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return self.tokenizer.decode(token_id)