# SPDX-License-Identifier: Apache-2.0

3
import json
4
from collections.abc import Iterable, Iterator, Mapping, Sequence
5
from concurrent.futures.thread import ThreadPoolExecutor
6
from http import HTTPStatus
7
from typing import Annotated, Any, Callable, Optional, TypedDict, Union
8

9
from fastapi import Request
10
from pydantic import Field
11
from starlette.datastructures import Headers
12

13
import vllm.envs as envs
14
from vllm.config import ModelConfig
15
from vllm.engine.protocol import EngineClient
16
17
# yapf conflicts with isort for this block
# yapf: disable
18
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
19
                                         ChatTemplateContentFormatOption,
20
21
22
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
23
24
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
25
from vllm.entrypoints.logger import RequestLogger
26
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
27
                                              CompletionRequest,
28
                                              DetokenizeRequest,
29
30
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
31
32
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
33
                                              TokenizeChatRequest,
34
35
                                              TokenizeCompletionRequest,
                                              TranscriptionRequest)
36
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
37
from vllm.entrypoints.openai.tool_parsers import ToolParser
38
# yapf: enable
39
from vllm.inputs import TokensPrompt
40
from vllm.inputs.parse import parse_and_batch_prompt
41
from vllm.logger import init_logger
42
from vllm.lora.request import LoRARequest
43
from vllm.pooling_params import PoolingParams
44
from vllm.prompt_adapter.request import PromptAdapterRequest
45
from vllm.sampling_params import BeamSearchParams, SamplingParams
46
from vllm.sequence import Logprob, PromptLogprobs
47
48
49
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
50
from vllm.utils import is_list_of, make_async, random_uuid
51
52
53

# Module-level logger, created from this module's name.
logger = init_logger(__name__)

54
# Request types accepted by `OpenAIServing._preprocess_completion`, i.e.
# requests that are preprocessed from raw text / token-id inputs rather than
# from chat messages.
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, RerankRequest,
                              ScoreRequest, TokenizeCompletionRequest]
57
58
59
60

# Request types accepted by `OpenAIServing._preprocess_chat`, i.e. requests
# whose prompt is built from chat messages.
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

61
62
# Every request type the shared `OpenAIServing` helpers are annotated to
# accept: completion-like, chat-like, and transcription requests.
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
                   TranscriptionRequest]
63
64
65
66


class TextTokensPrompt(TypedDict):
    """A prompt carried as both its text and its token ids."""

    # Text form of the prompt.
    prompt: str
    # Token-id form of the same prompt.
    prompt_token_ids: list[int]
68
69


70
# A request-side prompt: token ids only, text only, or both together.
RequestPrompt = Union[list[int], str, TextTokensPrompt]
71
72


73
74
class OpenAIServing:

75
76
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        """Store the shared serving dependencies and set up tokenization.

        Args:
            engine_client: Client used to talk to the engine.
            model_config: Model configuration; its ``max_model_len`` is
                cached on the instance as ``self.max_model_len``.
            models: Registry consulted for base models, LoRA requests and
                prompt adapters.
            request_logger: Optional logger for request inputs; ``None``
                disables input logging in ``_log_inputs``.
            return_tokens_as_token_ids: Stored on the instance for use by
                the serving code.
        """
        super().__init__()

        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        # A single worker thread is shared by both awaitable tokenization
        # wrappers created below.
        tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._tokenizer_executor = tokenizer_executor

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input,
            executor=tokenizer_executor,
        )
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=tokenizer_executor,
        )

103
104
105
106
107
108
109
110
111
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` from a message, error type and status.

        Args:
            message: Error description.
            err_type: Error type name placed in the response ``type`` field.
            status_code: HTTP status; its numeric value becomes ``code``.

        Returns:
            The populated ``ErrorResponse``.
        """
        numeric_code = status_code.value
        return ErrorResponse(
            message=message,
            type=err_type,
            code=numeric_code,
        )

112
113
114
115
116
117
118
119
120
121
122
123
124
    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
        """Return an error as a JSON string for streaming responses.

        The payload is ``{"error": <dumped ErrorResponse>}``, where the
        error is built by ``create_error_response`` with the same arguments.
        """
        error = self.create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
        )
        payload = {"error": error.model_dump()}
        return json.dumps(payload)

125
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Verify that the model named by the request can be served.

        The name is accepted when it is a supported base model, a registered
        LoRA name, a LoRA that ``self.models.resolve_lora`` resolves at
        runtime (only attempted when ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is
        set), or a registered prompt adapter name.

        Returns:
            ``None`` when the model is usable. Otherwise an
            ``ErrorResponse``: the bad-request error returned by
            ``resolve_lora`` if there was one, else a "not found" error.
        """
        model_name = request.model

        if self._is_model_supported(model_name):
            return None

        if any(lora.lora_name == model_name
               for lora in self.models.lora_requests):
            return None

        # Only a BAD_REQUEST error from runtime resolution is remembered; it
        # takes precedence over the generic "not found" error below.
        resolve_error = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model_name:
            load_result = await self.models.resolve_lora(model_name)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                is_bad_request = (
                    isinstance(load_result, ErrorResponse)
                    and load_result.code == HTTPStatus.BAD_REQUEST.value)
                if is_bad_request:
                    resolve_error = load_result

        if any(adapter.prompt_adapter_name == model_name
               for adapter in self.models.prompt_adapter_requests):
            return None

        if resolve_error:
            return resolve_error

        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

156
157
    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
            None, PromptAdapterRequest]]:
        """Find the adapter, if any, that the request's model name refers to.

        Returns:
            ``(None, None)`` for a supported base model, ``(lora, None)``
            for a registered LoRA name, or ``(None, prompt_adapter)`` for a
            registered prompt adapter name.

        Raises:
            ValueError: If the name matches none of the above.
        """
        model_name = request.model

        if self._is_model_supported(model_name):
            return None, None

        # LoRA adapters are searched before prompt adapters.
        matching_lora = next(
            (candidate for candidate in self.models.lora_requests
             if model_name == candidate.lora_name),
            None,
        )
        if matching_lora is not None:
            return matching_lora, None

        matching_adapter = next(
            (candidate for candidate in self.models.prompt_adapter_requests
             if model_name == candidate.prompt_adapter_name),
            None,
        )
        if matching_adapter is not None:
            return None, matching_adapter

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")
170

171
172
173
174
175
    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate its length.

        Args:
            request: The request the prompt belongs to; forwarded to
                ``_validate_input``.
            tokenizer: Callable tokenizer used to encode ``prompt``.
            prompt: Prompt text. It is lower-cased first when the model's
                encoder config sets ``do_lower_case``.
            truncate_prompt_tokens: ``None`` disables truncation. A negative
                value truncates to the model's maximum length
                (``self.max_model_len``). Any other value is passed to the
                tokenizer as ``max_length``.
            add_special_tokens: Forwarded to the tokenizer.

        Returns:
            The text together with its token ids.

        Raises:
            ValueError: Propagated from ``_validate_input`` when the prompt
                does not fit the model's context length.
        """
        encoder_config = self.model_config.encoder_config
        if (encoder_config is not None
                and encoder_config.get("do_lower_case", False)):
            prompt = prompt.lower()

        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            # A negative value (the annotation allows -1) means "use the
            # model's own limit"; it must never reach the tokenizer as a
            # negative max_length.
            max_length = (self.max_model_len if truncate_prompt_tokens < 0
                          else truncate_prompt_tokens)
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=max_length)

        input_ids = encoded.input_ids

        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: list[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
    ) -> TextTokensPrompt:
        """Truncate a token-id prompt, decode it and validate its length.

        Args:
            request: The request the prompt belongs to; forwarded to
                ``_validate_input``.
            tokenizer: Tokenizer used to decode the kept ids into text.
            prompt_ids: Prompt token ids.
            truncate_prompt_tokens: ``None`` keeps every id. A negative
                value keeps the last ``self.max_model_len`` ids. A positive
                value keeps that many ids from the end of the prompt.

        Returns:
            The decoded text together with the kept token ids.

        Raises:
            ValueError: Propagated from ``_validate_input`` when the prompt
                does not fit the model's context length.
        """
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            # Callers forward values typed ``ge=-1``. Slicing with
            # ``[-(-1):]`` would be ``[1:]`` and silently drop the first
            # token, so negative values are mapped to the model limit.
            input_ids = prompt_ids[-self.max_model_len:]
        else:
            # NOTE(review): a value of 0 yields ``[-0:]``, i.e. the whole
            # prompt is kept, as before this fix.
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check a tokenized prompt against the model's context length.

        The rule depends on the request type:

        * Embedding, score and rerank requests: the prompt may use at most
          ``self.max_model_len`` tokens.
        * Tokenize and detokenize requests: no length check.
        * All other requests: the prompt plus the requested completion
          budget must fit in ``self.max_model_len``. With no budget given,
          the prompt must be strictly shorter than ``self.max_model_len``.

        Args:
            request: The request being validated.
            input_ids: Prompt token ids.
            input_text: Prompt text.

        Returns:
            A ``TextTokensPrompt`` built from ``input_text`` and
            ``input_ids``.

        Raises:
            ValueError: If the applicable length rule is violated.
        """
        token_num = len(input_ids)

        # Note: embedding, score and rerank requests don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: tokenize and detokenize requests don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            # Not every type in AnyRequest is guaranteed to define
            # max_tokens; a missing field is treated as "no budget given"
            # instead of raising AttributeError.
            max_tokens = getattr(request, "max_tokens", None)
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, list[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        Single-input variant of {meth}`_tokenize_prompt_input_or_inputs`.

        Wraps ``prompt_input`` in a one-element sequence, delegates to
        {meth}`_tokenize_prompt_inputs` and returns the only result.
        """
        results = self._tokenize_prompt_inputs(
            request,
            tokenizer,
            (prompt_input, ),
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )
        return next(results)

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Multi-input variant of {meth}`_tokenize_prompt_input_or_inputs`.

        Lazily yields one normalized prompt per element of
        ``prompt_inputs``: strings are tokenized, anything else is treated
        as a list of token ids.
        """
        for prompt_input in prompt_inputs:
            if not isinstance(prompt_input, str):
                # Already token ids; ``add_special_tokens`` does not apply.
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )
                continue

            yield self._normalize_prompt_text_to_input(
                request,
                tokenizer,
                prompt=prompt_input,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> list[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        normalized: list[TextTokensPrompt] = []

        for parsed in parse_and_batch_prompt(input_or_inputs):
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly
            # The explicit "is False" comparison is required for Pyright to
            # perform type narrowing
            # See: https://github.com/microsoft/pyright/issues/7672
            if parsed["is_tokens"] is False:
                item = self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=parsed["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens)
            else:
                item = self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=parsed["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens)
            normalized.append(item)

        return normalized
353

354
    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
        """Turn completion-style inputs into request and engine prompts.

        Tokenization runs through the awaitable wrapper around
        ``_tokenize_prompt_input_or_inputs``.

        Returns:
            A pair ``(request_prompts, engine_prompts)``. The engine
            prompts carry only the token ids of the matching request
            prompt, in the same order.
        """
        tokenized = await self._tokenize_prompt_input_or_inputs_async(
            request,
            tokenizer,
            input_or_inputs,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )

        for_engine = [
            TokensPrompt(prompt_token_ids=item["prompt_token_ids"])
            for item in tokenized
        ]

        return tokenized, for_engine

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
               list[TokensPrompt]]:
        """Render chat messages into a prompt and build the engine prompt.

        Steps, in order: resolve the content format, parse the messages
        (multi-modal data is awaited later), apply the chat template, let
        the tool parser adjust the request, tokenize the rendered prompt
        when it is text, and assemble the engine prompt.

        Args:
            request: The chat-like request being served.
            tokenizer: Tokenizer; a ``MistralTokenizer`` selects the Mistral
                chat template path.
            messages: Chat messages to render.
            chat_template: Optional chat template override.
            chat_template_content_format: Requested content format.
            add_generation_prompt: Forwarded to the chat template.
            continue_final_message: Forwarded to the chat template.
            tool_dicts: Tool definitions, forwarded as ``tools``.
            documents: Documents forwarded to the chat template.
            chat_template_kwargs: Extra template kwargs; these override the
                values above on key collision.
            tool_parser: Factory producing a ``ToolParser`` for the
                tokenizer.
            truncate_prompt_tokens: Truncation applied when the rendered
                prompt is text.
            add_special_tokens: Forwarded to tokenization of a text prompt.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])``.

        Raises:
            NotImplementedError: If tool parsing is requested for a request
                that is not a ``ChatCompletionRequest``.
        """
        model_config = self.model_config

        content_format = resolve_chat_template_content_format(
            model_config,
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=content_format,
        )

        # Caller-supplied kwargs come last so they win on key collision.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tool_dicts,
            "documents": documents,
            **(chat_template_kwargs or {}),
        }

        request_prompt: Union[str, list[int]]
        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                model_config,
                tokenizer,
                conversation=conversation,
                **template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM)
        wants_tool_parsing = (tool_parser is not None
                              and hasattr(request, "tool_choice")
                              and request.tool_choice != "none")

        if wants_tool_parsing:
            if not isinstance(request, ChatCompletionRequest):
                raise NotImplementedError(
                    "Tool usage is only supported for Chat Completions API")

            parser = tool_parser(tokenizer)  # type: ignore
            request = parser.adjust_request(request=request)

        if isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            decoded_prompt = tokenizer.decode(request_prompt)
            prompt_inputs = TextTokensPrompt(prompt=decoded_prompt,
                                             prompt_token_ids=request_prompt)

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])

        # Optional fields are only attached when they carry a value.
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

478
479
480
    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Forward a request's prompt to the request logger, if one is set.

        ``inputs`` is split into ``(prompt, prompt_token_ids)``: a string
        gives text only, a list gives token ids only, and anything else is
        read as a mapping with both ``"prompt"`` and ``"prompt_token_ids"``.
        """
        request_logger = self.request_logger
        if request_logger is None:
            return

        if isinstance(inputs, str):
            prompt, prompt_token_ids = inputs, None
        elif isinstance(inputs, list):
            prompt, prompt_token_ids = None, inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
508

509
510
511
512
513
514
515
516
517
518
519
520
521
522
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return the trace headers of a request when tracing is enabled.

        When tracing is disabled but the request still carries trace
        headers, the tracing-disabled warning is logged and ``None`` is
        returned.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        # Tracing is off: tell the user their trace headers are ignored.
        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

523
    @staticmethod
    def _base_request_id(raw_request: Optional[Request],
                         default: Optional[str] = None) -> Optional[str]:
        """Pick the request id: the ``X-Request-Id`` header, else a fallback.

        The fallback is ``default`` when it is truthy, otherwise a freshly
        generated random UUID. It is also returned when there is no raw
        request.
        """
        fallback = default or random_uuid()

        return (fallback if raw_request is None else
                raw_request.headers.get("X-Request-Id", fallback))
532

533
    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        """Return the string used to represent a token.

        With ``return_as_token_id`` the result is ``"token_id:<id>"``.
        Otherwise the token text already stored on ``logprob`` is used,
        falling back to decoding ``token_id`` with the tokenizer.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        cached_text = logprob.decoded_token
        if cached_text is None:
            return tokenizer.decode(token_id)
        return cached_text
544

545
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
        """Tell whether ``model_name`` is acceptable as a base model.

        A missing or empty name counts as supported; any other name is
        checked with ``self.models.is_base_model``.
        """
        return not model_name or self.models.is_base_model(model_name)
549
550
551
552
553
554

    def _get_model_name(self,
                        model_name: Optional[str] = None,
                        lora_request: Optional[LoRARequest] = None) -> str:
        """Resolve the model name to report.

        Precedence: the LoRA request's name, then the given ``model_name``,
        then the name of the first base model path.
        """
        if lora_request:
            return lora_request.lora_name
        return model_name or self.models.base_model_paths[0].name
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572


def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs,
                           None]) -> Union[PromptLogprobs, None]:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` input and ``None`` entries are left untouched. The same
    object that was passed in is returned.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float('-inf')
    for position_logprobs in prompt_logprobs:
        if position_logprobs is not None:
            for entry in position_logprobs.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0
    return prompt_logprobs