"vllm/vscode:/vscode.git/clone" did not exist on "b72af8f1eded6f5838be29eb6093ab0e0e0c240c"
serving_engine.py 20.1 KB
Newer Older
1
import json
2
from concurrent.futures.thread import ThreadPoolExecutor
3
from http import HTTPStatus
4
5
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Sequence, Tuple, TypedDict, Union)
6

7
from fastapi import Request
8
from pydantic import Field
9
from starlette.datastructures import Headers
10
from typing_extensions import Annotated
11

12
from vllm.config import ModelConfig
13
from vllm.engine.protocol import EngineClient
14
15
# yapf conflicts with isort for this block
# yapf: disable
16
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
17
                                         ChatTemplateContentFormatOption,
18
19
20
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
21
22
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
23
from vllm.entrypoints.logger import RequestLogger
24
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
25
                                              CompletionRequest,
26
                                              DetokenizeRequest,
27
28
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
29
30
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
31
                                              TokenizeChatRequest,
32
33
                                              TokenizeCompletionRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
34
from vllm.entrypoints.openai.tool_parsers import ToolParser
35
# yapf: enable
36
from vllm.inputs import TokensPrompt
37
from vllm.inputs.parse import parse_and_batch_prompt
38
from vllm.logger import init_logger
39
from vllm.lora.request import LoRARequest
40
from vllm.pooling_params import PoolingParams
41
from vllm.prompt_adapter.request import PromptAdapterRequest
42
from vllm.sampling_params import BeamSearchParams, SamplingParams
43
from vllm.sequence import Logprob
44
45
46
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
47
from vllm.utils import is_list_of, make_async, random_uuid
48
49
50

# Module-level logger, named after this module.
logger = init_logger(__name__)

# Request types accepted by the completion-style preprocessing path
# (see ``OpenAIServing._preprocess_completion``).
# ``RerankRequest`` is listed because ``_validate_input`` handles it
# explicitly, so it must also be part of ``AnyRequest``.
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, RerankRequest,
                              ScoreRequest, TokenizeCompletionRequest]

# Request types accepted by the chat-style preprocessing path
# (see ``OpenAIServing._preprocess_chat``).
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

# Every request type the shared helpers of ``OpenAIServing`` accept.
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]
59
60
61
62
63
64
65


class TextTokensPrompt(TypedDict):
    """A prompt carried both as text and as the matching token ids."""
    # Prompt text (either the original string or the decoded token ids).
    prompt: str
    # Token ids of the prompt.
    prompt_token_ids: List[int]


66
67
68
# A prompt as passed around in this module: raw token ids, raw text, or a
# ``TextTokensPrompt`` holding both.
RequestPrompt = Union[List[int], str, TextTokensPrompt]


69
70
class OpenAIServing:

71
72
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        """Store the serving dependencies and build async tokenize helpers.

        Args:
            engine_client: Engine client kept on ``self.engine_client``.
            model_config: Model configuration; its ``max_model_len`` is
                copied to ``self.max_model_len``.
            models: Registry used for base-model / adapter lookups.
            request_logger: Logger used by ``_log_inputs``; ``None``
                disables input logging.
            return_tokens_as_token_ids: Stored as-is on the instance.
        """
        super().__init__()

        # Collaborators handed in by the caller.
        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger

        # Model configuration plus the context limit derived from it.
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # A single worker thread backs both async tokenization wrappers.
        tokenizer_pool = ThreadPoolExecutor(max_workers=1)
        self._tokenizer_executor = tokenizer_pool

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input,
            executor=tokenizer_pool,
        )
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=tokenizer_pool,
        )

99
100
101
102
103
104
105
106
107
    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an ``ErrorResponse`` from a message, error type and status.

        The numeric value of ``status_code`` is stored in the ``code`` field.
        """
        error_fields = {
            "message": message,
            "type": err_type,
            "code": status_code.value,
        }
        return ErrorResponse(**error_fields)

108
109
110
111
112
113
114
115
116
117
118
119
120
    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Return the error as a JSON string of the form ``{"error": ...}``.

        The payload under ``"error"`` is the ``model_dump()`` of the response
        built by :meth:`create_error_response` with the same arguments.
        """
        error = self.create_error_response(message=message,
                                           err_type=err_type,
                                           status_code=status_code)
        payload = {"error": error.model_dump()}
        return json.dumps(payload)

121
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Check that ``request.model`` refers to something we can serve.

        Returns:
            ``None`` when the name is the base model, a registered LoRA
            adapter or a registered prompt adapter; otherwise a
            ``NotFoundError`` response with HTTP status 404.
        """
        requested = request.model

        # Checked in order: base model, LoRA adapters, prompt adapters.
        is_known = (
            self._is_model_supported(requested)
            or any(requested == lora.lora_name
                   for lora in self.models.lora_requests)
            or any(requested == adapter.prompt_adapter_name
                   for adapter in self.models.prompt_adapter_requests))
        if is_known:
            return None

        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

141
142
143
144
    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
        """Resolve ``request.model`` to a ``(lora, prompt_adapter)`` pair.

        Returns:
            ``(None, None)`` for the base model, ``(lora, None)`` when the
            name matches a LoRA adapter, ``(None, prompt_adapter)`` when it
            matches a prompt adapter. LoRA adapters are searched first.

        Raises:
            ValueError: If the name matches nothing.
        """
        model_name = request.model

        if self._is_model_supported(model_name):
            return None, None

        lora_match = next((candidate
                           for candidate in self.models.lora_requests
                           if model_name == candidate.lora_name), None)
        if lora_match is not None:
            return lora_match, None

        adapter_match = next(
            (candidate
             for candidate in self.models.prompt_adapter_requests
             if model_name == candidate.prompt_adapter_name), None)
        if adapter_match is not None:
            return None, adapter_match

        # Unreachable when _check_model has already accepted the request.
        raise ValueError(f"The model `{request.model}` does not exist.")
155

156
157
158
159
160
161
162
163
    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate the result.

        The prompt is lower-cased first when the model's encoder config
        sets ``do_lower_case``. When ``truncate_prompt_tokens`` is given,
        the tokenizer is called with ``truncation=True`` and that value as
        ``max_length``. The (possibly lower-cased) text and the resulting
        ids are then passed to :meth:`_validate_input`.
        """
        encoder_config = self.model_config.encoder_config
        if encoder_config is not None and encoder_config.get(
                "do_lower_case", False):
            prompt = prompt.lower()

        # Truncation arguments are only passed when a limit was requested.
        tokenize_kwargs = {"add_special_tokens": add_special_tokens}
        if truncate_prompt_tokens is not None:
            tokenize_kwargs["truncation"] = True
            tokenize_kwargs["max_length"] = truncate_prompt_tokens

        encoded = tokenizer(prompt, **tokenize_kwargs)

        return self._validate_input(request, encoded.input_ids, prompt)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        """Decode a token-id prompt to text and validate the result.

        When ``truncate_prompt_tokens`` is given, only the last that many
        ids are kept before decoding.
        """
        kept_ids = (prompt_ids if truncate_prompt_tokens is None else
                    prompt_ids[-truncate_prompt_tokens:])

        decoded_text = tokenizer.decode(kept_ids)

        return self._validate_input(request, kept_ids, decoded_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the prompt length against the model's context limit.

        Returns:
            A ``TextTokensPrompt`` built from ``input_text`` and
            ``input_ids`` when the request fits.

        Raises:
            ValueError: If the prompt (plus the requested completion
                tokens, where applicable) exceeds ``self.max_model_len``.
        """
        prompt_len = len(input_ids)
        context_limit = self.max_model_len
        validated = TextTokensPrompt(prompt=input_text,
                                     prompt_token_ids=input_ids)

        # Embedding, score and rerank requests have no max_tokens: only the
        # prompt itself has to fit.
        # NOTE(review): rerank requests are reported as "embedding
        # generation" in the message below - confirm that is intended.
        prompt_only_types = (EmbeddingChatRequest,
                             EmbeddingCompletionRequest, ScoreRequest,
                             RerankRequest)
        if isinstance(request, prompt_only_types):
            if prompt_len <= context_limit:
                return validated

            if isinstance(request, ScoreRequest):
                operation = "score"
            else:
                operation = "embedding generation"
            raise ValueError(
                f"This model's maximum context length is "
                f"{context_limit} tokens. However, you requested "
                f"{prompt_len} tokens in the input for {operation}. "
                f"Please reduce the length of the input.")

        # Tokenize and detokenize requests have no max_tokens either and do
        # not need context-length validation.
        unchecked_types = (TokenizeCompletionRequest, TokenizeChatRequest,
                           DetokenizeRequest)
        if isinstance(request, unchecked_types):
            return validated

        # The chat completion endpoint supports max_completion_tokens.
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            completion_budget = (request.max_completion_tokens
                                 or request.max_tokens)
        else:
            completion_budget = request.max_tokens

        if completion_budget is None:
            # No explicit budget: the prompt must be shorter than the limit.
            if prompt_len >= context_limit:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{context_limit} tokens. However, you requested "
                    f"{prompt_len} tokens in the messages, "
                    f"Please reduce the length of the messages.")
            return validated

        requested_total = prompt_len + completion_budget
        if requested_total > context_limit:
            raise ValueError(
                f"This model's maximum context length is "
                f"{context_limit} tokens. However, you requested "
                f"{requested_total} tokens "
                f"({prompt_len} in the messages, "
                f"{completion_budget} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return validated

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        Single-input counterpart of :meth:`_tokenize_prompt_inputs`.

        The input is wrapped in a one-element list and the first (only)
        result of :meth:`_tokenize_prompt_inputs` is returned.
        """
        results = self._tokenize_prompt_inputs(
            request,
            tokenizer,
            [prompt_input],
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )
        return next(results)

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Lazily normalize several prompts, one result per input.

        String inputs are tokenized; anything else is treated as a list of
        token ids and decoded. ``add_special_tokens`` only applies to
        string inputs.
        """
        for prompt_input in prompt_inputs:
            if not isinstance(prompt_input, str):
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )
                continue

            yield self._normalize_prompt_text_to_input(
                request,
                tokenizer,
                prompt=prompt_input,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> List[TextTokensPrompt]:
        """
        Normalize one or many prompts given as text or as token ids.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens, and a request can
        carry one or more inputs. The raw value is split into individual
        prompts by ``parse_and_batch_prompt``; text prompts are tokenized
        and token prompts are decoded.
        """
        normalized: List[TextTokensPrompt] = []

        for parsed in parse_and_batch_prompt(input_or_inputs):
            # The explicit "is False" comparison lets Pyright narrow the
            # type of the parsed entry.
            # See: https://github.com/microsoft/pyright/issues/7672
            if parsed["is_tokens"] is False:
                normalized.append(
                    self._normalize_prompt_text_to_input(
                        request,
                        tokenizer,
                        prompt=parsed["content"],
                        truncate_prompt_tokens=truncate_prompt_tokens,
                        add_special_tokens=add_special_tokens))
            else:
                normalized.append(
                    self._normalize_prompt_tokens_to_input(
                        request,
                        tokenizer,
                        prompt_ids=parsed["content"],
                        truncate_prompt_tokens=truncate_prompt_tokens))

        return normalized
338

339
    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]:
        """Tokenize completion-style inputs and build engine prompts.

        Returns:
            A pair ``(request_prompts, engine_prompts)`` of equal length:
            the normalized prompts and, for each of them, a
            ``TokensPrompt`` carrying its token ids.
        """
        # Tokenization goes through the async wrapper set up in __init__.
        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
            request,
            tokenizer,
            input_or_inputs,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )

        engine_prompts = []
        for tokenized in request_prompts:
            token_ids = tokenized["prompt_token_ids"]
            engine_prompts.append(TokensPrompt(prompt_token_ids=token_ids))

        return request_prompts, engine_prompts

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: List[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[List[Dict[str, Any]]] = None,
        documents: Optional[List[Dict[str, str]]] = None,
        chat_template_kwargs: Optional[Dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
               List[TokensPrompt]]:
        """Render chat messages into a prompt and build the engine prompt.

        Steps: resolve the content format, parse the messages, apply the
        chat template (Mistral or HF flavour depending on the tokenizer),
        await the multi-modal data, optionally let the tool parser adjust
        the request, then tokenize.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])`` where
            ``request_prompt`` is the templated prompt (text or token ids)
            and ``engine_prompt`` carries the token ids plus any
            multi-modal data.

        Raises:
            NotImplementedError: If tool parsing applies but ``request`` is
                not a ``ChatCompletionRequest``.
        """
        content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            self.model_config,
            tokenizer,
            content_format=content_format,
        )

        # Caller-supplied kwargs override the defaults built here.
        template_kwargs: Dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tool_dicts,
            "documents": documents,
            **(chat_template_kwargs or {}),
        }

        request_prompt: Union[str, List[int]]
        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                **template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing only happens when a tool_parser is set and
        # tool_choice is not "none": with "none" we must not parse a tool
        # call hallucinated by the LLM even if a parser is configured.
        wants_tool_parsing = (tool_parser is not None
                              and hasattr(request, "tool_choice")
                              and request.tool_choice != "none")
        if wants_tool_parsing:
            if not isinstance(request, ChatCompletionRequest):
                raise NotImplementedError(
                    "Tool usage is only supported for Chat Completions API")

            adjuster = tool_parser(tokenizer)  # type: ignore
            request = adjuster.adjust_request(  # type: ignore
                request=request)

        if not isinstance(request_prompt, str):
            # For MistralTokenizer: the template already produced token ids.
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)
        else:
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        return conversation, [request_prompt], [engine_prompt]

454
455
456
    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Forward the request inputs to the request logger, if one is set.

        ``inputs`` is split into ``(prompt, prompt_token_ids)``: a string
        gives only the text, a list gives only the token ids, and a
        ``TextTokensPrompt`` gives both.
        """
        request_logger = self.request_logger
        if request_logger is None:
            return

        if isinstance(inputs, str):
            prompt, prompt_token_ids = inputs, None
        elif isinstance(inputs, list):
            prompt, prompt_token_ids = None, inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
484

485
486
487
488
489
490
491
492
493
494
495
496
497
498
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return the trace headers when the engine has tracing enabled.

        When tracing is disabled, ``None`` is returned; a warning is logged
        first if the request carried trace headers anyway.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

499
    @staticmethod
    def _base_request_id(raw_request: Optional[Request],
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request id to use from a header, if provided.

        Args:
            raw_request: Incoming request, or ``None`` when unavailable.
            default: Id to use when the ``X-Request-Id`` header is missing
                or empty; a random UUID is generated when this is falsy.

        Returns:
            The non-empty ``X-Request-Id`` header value if present,
            otherwise ``default``, otherwise a fresh random UUID.
        """
        if raw_request is not None:
            header_id = raw_request.headers.get("X-Request-Id")
            # An empty header value is treated like a missing header so
            # that the request never ends up with an empty id.
            if header_id:
                return header_id

        # The fallback id is only generated when it is actually needed.
        return default or random_uuid()
508

509
    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        """Return the text form of a token.

        With ``return_as_token_id`` the result is ``"token_id:<id>"``.
        Otherwise the token text stored on ``logprob`` is used when
        present, falling back to decoding ``token_id`` with the tokenizer.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        cached_text = logprob.decoded_token
        if cached_text is None:
            return tokenizer.decode(token_id)
        return cached_text
520

521
    def _is_model_supported(self, model_name):
522
        return self.models.is_base_model(model_name)