serving_engine.py 21.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import json
4
from collections.abc import Iterable, Iterator, Mapping, Sequence
5
from concurrent.futures.thread import ThreadPoolExecutor
6
from http import HTTPStatus
7
from typing import Annotated, Any, Callable, Optional, TypedDict, Union
8

9
from fastapi import Request
10
from pydantic import Field
11
from starlette.datastructures import Headers
12

13
import vllm.envs as envs
14
from vllm.config import ModelConfig
15
from vllm.engine.protocol import EngineClient
16
17
# yapf conflicts with isort for this block
# yapf: disable
18
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
19
                                         ChatTemplateContentFormatOption,
20
21
22
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
23
24
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
25
from vllm.entrypoints.logger import RequestLogger
26
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
27
                                              CompletionRequest,
28
                                              DetokenizeRequest,
29
30
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
31
32
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
33
                                              TokenizeChatRequest,
34
35
                                              TokenizeCompletionRequest,
                                              TranscriptionRequest)
36
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
37
from vllm.entrypoints.openai.tool_parsers import ToolParser
38
# yapf: enable
39
from vllm.inputs import TokensPrompt
40
from vllm.inputs.parse import parse_and_batch_prompt
41
from vllm.logger import init_logger
42
from vllm.lora.request import LoRARequest
43
from vllm.pooling_params import PoolingParams
44
from vllm.prompt_adapter.request import PromptAdapterRequest
45
from vllm.sampling_params import BeamSearchParams, SamplingParams
46
from vllm.sequence import Logprob, PromptLogprobs
47
48
49
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
50
from vllm.utils import is_list_of, make_async, random_uuid
51
52
53

# Logger for this module, created via `init_logger(__name__)`.
logger = init_logger(__name__)

# Request types accepted by `OpenAIServing._preprocess_completion`.
CompletionLikeRequest = Union[
    CompletionRequest,
    DetokenizeRequest,
    EmbeddingCompletionRequest,
    RerankRequest,
    ScoreRequest,
    TokenizeCompletionRequest,
]

# Request types accepted by `OpenAIServing._preprocess_chat`.
ChatLikeRequest = Union[
    ChatCompletionRequest,
    EmbeddingChatRequest,
    TokenizeChatRequest,
]

# Any request type the shared `OpenAIServing` helpers are annotated to take.
AnyRequest = Union[
    CompletionLikeRequest,
    ChatLikeRequest,
    TranscriptionRequest,
]
63
64
65
66


class TextTokensPrompt(TypedDict):
    """A single prompt carried both as text and as token ids."""

    # The prompt in text form.
    prompt: str
    # The same prompt in token-id form.
    prompt_token_ids: list[int]
68
69


70
# A prompt in any of the three shapes `OpenAIServing._log_inputs` handles:
# a list of token ids, a plain string, or a `TextTokensPrompt` holding both.
RequestPrompt = Union[list[int], str, TextTokensPrompt]
71
72


73
74
class OpenAIServing:

75
76
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        """Store the shared serving dependencies and set up async tokenizers.

        Args:
            engine_client: Engine client kept on the instance.
            model_config: Model configuration; its ``max_model_len`` is also
                cached as ``self.max_model_len``.
            models: Model registry consulted for base models, LoRA adapters
                and prompt adapters.
            request_logger: Logger used by ``_log_inputs``; ``None`` disables
                input logging.
            return_tokens_as_token_ids: Flag stored on the instance as-is.
        """
        super().__init__()

        # Collaborators handed in by the caller.
        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger

        # Model configuration and the context limit read from it.
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # A single-worker thread pool shared by both async tokenization
        # wrappers created below.
        tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._tokenizer_executor = tokenizer_executor

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input,
            executor=tokenizer_executor,
        )
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=tokenizer_executor,
        )

103
104
105
106
107
108
109
110
111
    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an ``ErrorResponse`` from a message, type and HTTP status.

        The numeric value of ``status_code`` is stored as the response
        ``code``; ``err_type`` is stored as its ``type``.
        """
        numeric_code = status_code.value
        return ErrorResponse(
            message=message,
            type=err_type,
            code=numeric_code,
        )

112
113
114
115
116
117
118
119
120
121
122
123
124
    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Return the error as a JSON string of the form ``{"error": ...}``.

        The payload under ``"error"`` is the ``model_dump()`` of the response
        built by :meth:`create_error_response` from the same arguments.
        """
        error = self.create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
        )
        payload = {"error": error.model_dump()}
        return json.dumps(payload)

125
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Check that ``request.model`` names something this server can serve.

        Returns:
            ``None`` when the model is the base model (or unset), a registered
            LoRA adapter, a LoRA adapter resolved at runtime, or a registered
            prompt adapter. Otherwise an ``ErrorResponse``: the bad-request
            error reported by runtime LoRA resolution if there was one, else
            a 404 "does not exist" error.
        """
        model_name = request.model

        if self._is_model_supported(model_name):
            return None

        if any(lora.lora_name == model_name
               for lora in self.models.lora_requests):
            return None

        # Runtime LoRA resolution is attempted only when enabled through the
        # environment flag. A bad-request error from it is remembered but not
        # returned yet, since the name may still match a prompt adapter.
        lora_load_error = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model_name:
            load_result = await self.models.resolve_lora(model_name)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                if (isinstance(load_result, ErrorResponse) and
                        load_result.code == HTTPStatus.BAD_REQUEST.value):
                    lora_load_error = load_result

        if any(adapter.prompt_adapter_name == model_name
               for adapter in self.models.prompt_adapter_requests):
            return None

        if lora_load_error:
            return lora_load_error

        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

156
157
    def _maybe_get_adapters(
        self, request: AnyRequest
158
    ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
159
            None, PromptAdapterRequest]]:
160
        if self._is_model_supported(request.model):
161
            return None, None
162
        for lora in self.models.lora_requests:
163
            if request.model == lora.lora_name:
164
                return lora, None
165
        for prompt_adapter in self.models.prompt_adapter_requests:
166
            if request.model == prompt_adapter.prompt_adapter_name:
167
                return None, prompt_adapter
168
        # if _check_model has been called earlier, this will be unreachable
169
        raise ValueError(f"The model `{request.model}` does not exist.")
170

171
172
173
174
175
176
177
178
    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate its length.

        The prompt is lower-cased first when the model's encoder config sets
        ``do_lower_case``. When ``truncate_prompt_tokens`` is given, the
        tokenizer is called with ``truncation=True`` and that value as
        ``max_length``. The result is passed to :meth:`_validate_input`,
        whose return value (or ``ValueError``) is propagated.
        """
        encoder_config = self.model_config.encoder_config
        if (encoder_config is not None
                and encoder_config.get("do_lower_case", False)):
            prompt = prompt.lower()

        # Truncation arguments are passed only when a limit was requested.
        tokenizer_kwargs: dict[str, Any] = {
            "add_special_tokens": add_special_tokens,
        }
        if truncate_prompt_tokens is not None:
            tokenizer_kwargs["truncation"] = True
            tokenizer_kwargs["max_length"] = truncate_prompt_tokens

        encoded = tokenizer(prompt, **tokenizer_kwargs)

        return self._validate_input(request, encoded.input_ids, prompt)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: list[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        """Detokenize a token-id prompt and validate its length.

        When ``truncate_prompt_tokens`` is given, only the last
        ``truncate_prompt_tokens`` ids are kept. The kept ids are decoded to
        text and both are passed to :meth:`_validate_input`.
        """
        input_ids = (prompt_ids if truncate_prompt_tokens is None else
                     prompt_ids[-truncate_prompt_tokens:])
        decoded_text = tokenizer.decode(input_ids)
        return self._validate_input(request, input_ids, decoded_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the prompt length against ``self.max_model_len``.

        Three request families are handled differently:

        * Embedding, score and rerank requests: the prompt alone may use up
          to ``max_model_len`` tokens.
        * Tokenize and detokenize requests: no length check.
        * All other requests: the prompt plus the requested ``max_tokens``
          must fit in ``max_model_len``; when ``max_tokens`` is unset the
          prompt must be strictly shorter than ``max_model_len``.

        Returns:
            A ``TextTokensPrompt`` built from ``input_text`` and
            ``input_ids``.

        Raises:
            ValueError: If the applicable length limit is exceeded.
        """
        token_num = len(input_ids)
        validated = TextTokensPrompt(prompt=input_text,
                                     prompt_token_ids=input_ids)

        # Note: EmbeddingRequest and ScoreRequest don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):
            if token_num <= self.max_model_len:
                return validated

            if isinstance(request, ScoreRequest):
                operation = "score"
            else:
                operation = "embedding generation"
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{token_num} tokens in the input for {operation}. "
                f"Please reduce the length of the input.")

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return validated

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens

        if max_tokens is None:
            # No completion budget was given: the prompt itself must be
            # strictly below the context length.
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
            return validated

        if token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return validated

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, list[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        Single-input counterpart of :meth:`_tokenize_prompt_input_or_inputs`.

        Wraps ``prompt_input`` in a one-element list, hands it to
        :meth:`_tokenize_prompt_inputs` and returns the first (only) result.
        """
        results = self._tokenize_prompt_inputs(
            request,
            tokenizer,
            [prompt_input],
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )
        return next(results)

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Multi-input counterpart of :meth:`_tokenize_prompt_input_or_inputs`.

        Lazily yields one ``TextTokensPrompt`` per element of
        ``prompt_inputs``: strings are tokenized, anything else is treated
        as a list of token ids and detokenized.
        """
        for prompt_input in prompt_inputs:
            if not isinstance(prompt_input, str):
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )
                continue

            yield self._normalize_prompt_text_to_input(
                request,
                tokenizer,
                prompt=prompt_input,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> list[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        normalized: list[TextTokensPrompt] = []

        for parsed in parse_and_batch_prompt(input_or_inputs):
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly
            # The explicit "is False" comparison is what lets Pyright narrow
            # the type of the parsed prompt.
            # See: https://github.com/microsoft/pyright/issues/7672
            if parsed["is_tokens"] is False:
                item = self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=parsed["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens)
            else:
                item = self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=parsed["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens)
            normalized.append(item)

        return normalized
353

354
    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
        """Tokenize completion-style inputs and build engine prompts.

        Returns:
            A pair ``(request_prompts, engine_prompts)`` of equal length:
            the validated ``TextTokensPrompt`` for every input, and a
            ``TokensPrompt`` holding only the token ids of each.
        """
        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
            request,
            tokenizer,
            input_or_inputs,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )

        engine_prompts: list[TokensPrompt] = []
        for tokenized in request_prompts:
            token_ids = tokenized["prompt_token_ids"]
            engine_prompts.append(TokensPrompt(prompt_token_ids=token_ids))

        return request_prompts, engine_prompts

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
               list[TokensPrompt]]:
        """Render chat messages into a prompt and build the engine prompt.

        Steps, in order: resolve the content format, parse the messages,
        apply the chat template (Mistral or HF, depending on the tokenizer
        type), await the multi-modal data, let the tool parser adjust the
        request when tool parsing applies, then tokenize the rendered
        prompt if it is a string.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])`` where
            ``request_prompt`` is the rendered template output (a string or
            a list of token ids) and ``engine_prompt`` carries the token ids
            plus any multi-modal data / ``mm_processor_kwargs``.

        Raises:
            NotImplementedError: If tool parsing applies but ``request`` is
                not a ``ChatCompletionRequest``.
        """
        model_config = self.model_config
        trust_remote_code = model_config.trust_remote_code

        content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            trust_remote_code=trust_remote_code,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=content_format,
        )

        # Caller-supplied template kwargs override the defaults built here.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tool_dicts,
            "documents": documents,
        }
        if chat_template_kwargs:
            template_kwargs.update(chat_template_kwargs)

        request_prompt: Union[str, list[int]]
        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer,
                trust_remote_code=trust_remote_code,
                conversation=conversation,
                **template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a
        # tool_parser is set, we want to prevent parsing a tool_call
        # hallucinated by the LLM)
        wants_tool_parsing = (tool_parser is not None
                              and hasattr(request, "tool_choice")
                              and request.tool_choice != "none")

        if wants_tool_parsing:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            parser = tool_parser(tokenizer)  # type: ignore
            request = parser.adjust_request(request=request)

        if not isinstance(request_prompt, str):
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)
        else:
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data
        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        return conversation, [request_prompt], [engine_prompt]

475
476
477
    def _log_inputs(
        self,
        request_id: str,
478
        inputs: RequestPrompt,
479
480
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
481
482
483
484
485
486
487
488
489
490
491
492
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
493
        else:
494
495
496
497
498
499
500
501
502
503
504
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
505

506
507
508
509
510
511
512
513
514
515
516
517
518
519
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return the trace headers when the engine has tracing enabled.

        When tracing is disabled, ``None`` is returned; if the request
        nevertheless carries trace headers, a tracing-disabled warning is
        logged first.
        """
        tracing_enabled = await self.engine_client.is_tracing_enabled()

        if not tracing_enabled:
            if contains_trace_headers(headers):
                log_tracing_disabled_warning()
            return None

        return extract_trace_headers(headers)

520
    @staticmethod
521
    def _base_request_id(raw_request: Optional[Request],
522
523
524
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
525
526
527
528
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)
529

530
    @staticmethod
531
532
533
534
535
536
537
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        if return_as_token_id:
            return f"token_id:{token_id}"

538
539
        if logprob.decoded_token is not None:
            return logprob.decoded_token
540
        return tokenizer.decode(token_id)
541

542
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
543
544
        if not model_name:
            return True
545
        return self.models.is_base_model(model_name)
546
547
548
549
550
551

    def _get_model_name(self,
                        model_name: Optional[str] = None,
                        lora_request: Optional[LoRARequest] = None) -> str:
        if lora_request:
            return lora_request.lora_name
552
        if not model_name:
553
554
            return self.models.base_model_paths[0].name
        return model_name
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569


def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs,
                           None]) -> Union[PromptLogprobs, None]:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` input and ``None`` entries in the sequence are left untouched.
    The same object that was passed in is returned.
    """
    # NOTE(review): presumably -inf is clamped so the values survive JSON
    # serialization of the response -- confirm against the callers.
    if prompt_logprobs is not None:
        negative_infinity = float('-inf')
        entries = (entry for per_position in prompt_logprobs
                   if per_position is not None
                   for entry in per_position.values())
        for entry in entries:
            if entry.logprob == negative_infinity:
                entry.logprob = -9999.0
    return prompt_logprobs