serving_engine.py 21.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import json
4
from collections.abc import Iterable, Iterator, Mapping, Sequence
5
from concurrent.futures.thread import ThreadPoolExecutor
6
from http import HTTPStatus
7
from typing import Annotated, Any, Callable, Optional, TypedDict, Union
8

9
from fastapi import Request
10
from pydantic import Field
11
from starlette.datastructures import Headers
12

13
from vllm.config import ModelConfig
14
from vllm.engine.protocol import EngineClient
15
16
# yapf conflicts with isort for this block
# yapf: disable
17
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
18
                                         ChatTemplateContentFormatOption,
19
20
21
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
22
23
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
24
from vllm.entrypoints.logger import RequestLogger
25
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
26
                                              CompletionRequest,
27
                                              DetokenizeRequest,
28
29
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
30
31
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
32
                                              TokenizeChatRequest,
33
34
                                              TokenizeCompletionRequest,
                                              TranscriptionRequest)
35
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
36
from vllm.entrypoints.openai.tool_parsers import ToolParser
37
# yapf: enable
38
from vllm.inputs import TokensPrompt
39
from vllm.inputs.parse import parse_and_batch_prompt
40
from vllm.logger import init_logger
41
from vllm.lora.request import LoRARequest
42
from vllm.pooling_params import PoolingParams
43
from vllm.prompt_adapter.request import PromptAdapterRequest
44
from vllm.sampling_params import BeamSearchParams, SamplingParams
45
from vllm.sequence import Logprob, PromptLogprobs
46
47
48
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
49
from vllm.utils import is_list_of, make_async, random_uuid
50
51
52

# Module-level logger, named after this module.
logger = init_logger(__name__)

# Request types handled by the completion-style preprocessing path
# (see ``OpenAIServing._preprocess_completion``).
CompletionLikeRequest = Union[
    CompletionRequest,
    DetokenizeRequest,
    EmbeddingCompletionRequest,
    RerankRequest,
    ScoreRequest,
    TokenizeCompletionRequest,
]

# Request types handled by the chat-style preprocessing path
# (see ``OpenAIServing._preprocess_chat``).
ChatLikeRequest = Union[
    ChatCompletionRequest,
    EmbeddingChatRequest,
    TokenizeChatRequest,
]

# Any request type accepted by the helpers of ``OpenAIServing``.
AnyRequest = Union[
    CompletionLikeRequest,
    ChatLikeRequest,
    TranscriptionRequest,
]
62
63
64
65


class TextTokensPrompt(TypedDict):
    """A prompt held in both of its forms: text and token ids."""

    # Text form of the prompt.
    prompt: str
    # Token ids of the prompt.
    prompt_token_ids: list[int]
67
68


69
# A request prompt is one of: a list of token ids, a plain string, or a
# TextTokensPrompt carrying both the text and its token ids.
RequestPrompt = Union[list[int], str, TextTokensPrompt]
70
71


72
73
class OpenAIServing:

74
75
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        """
        Store the shared serving dependencies and set up async tokenization.

        Args:
            engine_client: Client used to talk to the engine.
            model_config: Model configuration; ``max_model_len`` is cached
                from it.
            models: Registry of the served base models and adapters.
            request_logger: Logger for request inputs, or ``None`` to
                disable input logging (see :meth:`_log_inputs`).
            return_tokens_as_token_ids: Stored as-is on the instance.
        """
        super().__init__()

        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger

        self.model_config = model_config
        self.max_model_len = model_config.max_model_len
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # Both async tokenization wrappers below share this one-worker pool.
        # NOTE(review): presumably make_async runs the wrapped callable on
        # the given executor - confirm against vllm.utils.make_async.
        tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._tokenizer_executor = tokenizer_executor

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input,
            executor=tokenizer_executor,
        )
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=tokenizer_executor,
        )

102
103
104
105
106
107
108
109
110
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """
        Build an :class:`ErrorResponse` from a message, an error type name
        and an HTTP status.
        """
        # The response stores the numeric status code, not the enum member.
        numeric_code = status_code.value
        return ErrorResponse(message=message,
                             type=err_type,
                             code=numeric_code)

111
112
113
114
115
116
117
118
119
120
121
122
123
    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

124
    async def _check_model(
125
126
        self,
        request: AnyRequest,
127
    ) -> Optional[ErrorResponse]:
128
        if self._is_model_supported(request.model):
129
            return None
130
131
132
        if request.model in [
                lora.lora_name for lora in self.models.lora_requests
        ]:
133
            return None
134
135
        if request.model in [
                prompt_adapter.prompt_adapter_name
136
                for prompt_adapter in self.models.prompt_adapter_requests
137
138
        ]:
            return None
139
140
141
142
143
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

144
145
    def _maybe_get_adapters(
        self, request: AnyRequest
146
    ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
147
            None, PromptAdapterRequest]]:
148
        if self._is_model_supported(request.model):
149
            return None, None
150
        for lora in self.models.lora_requests:
151
            if request.model == lora.lora_name:
152
                return lora, None
153
        for prompt_adapter in self.models.prompt_adapter_requests:
154
            if request.model == prompt_adapter.prompt_adapter_name:
155
                return None, prompt_adapter
156
        # if _check_model has been called earlier, this will be unreachable
157
        raise ValueError(f"The model `{request.model}` does not exist.")
158

159
160
161
162
163
164
165
166
    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
167
168
169
170
171
        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

172
173
        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
174
        else:
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids

        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
190
        prompt_ids: list[int],
191
192
193
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
194
            input_ids = prompt_ids
195
196
197
198
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)
199

200
201
202
203
204
        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """
        Check the prompt length against ``self.max_model_len`` and package
        the prompt as a :class:`TextTokensPrompt`.

        The check depends on the request type:

        - Embedding, score and rerank requests: the prompt may use up to
          ``max_model_len`` tokens.
        - Tokenize and detokenize requests: no length check.
        - Any other request: the prompt plus the requested ``max_tokens``
          must fit in ``max_model_len``. For chat completions a non-zero
          ``max_completion_tokens`` takes precedence over ``max_tokens``.
          When no limit was requested, the prompt must be strictly shorter
          than ``max_model_len``.

        Raises:
            ValueError: If the prompt (plus any requested completion
                tokens) does not fit in the model's context length.
        """
        token_num = len(input_ids)

        # Note: Embedding, Score and Rerank requests don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):
            if token_num > self.max_model_len:
                # Name the operation the caller actually asked for, so a
                # rerank request is not reported as "embedding generation".
                if isinstance(request, ScoreRequest):
                    operation = "score"
                elif isinstance(request, RerankRequest):
                    operation = "rerank"
                else:
                    operation = "embedding generation"
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: Tokenize and Detokenize requests don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens

        if max_tokens is None:
            # No explicit limit: the prompt must leave room for at least
            # one more token.
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
261
        prompt_input: Union[str, list[int]],
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
282
        prompt_inputs: Iterable[Union[str, list[int]]],
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> list[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.

        Returns:
            One :class:`TextTokensPrompt` per parsed input, in the order
            produced by ``parse_and_batch_prompt``.
        """
        normalized: list[TextTokensPrompt] = []

        for parsed in parse_and_batch_prompt(input_or_inputs):
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly
            # "is False" is required for Pyright to perform type narrowing
            # See: https://github.com/microsoft/pyright/issues/7672
            if parsed["is_tokens"] is False:
                normalized.append(
                    self._normalize_prompt_text_to_input(
                        request,
                        tokenizer,
                        prompt=parsed["content"],
                        truncate_prompt_tokens=truncate_prompt_tokens,
                        add_special_tokens=add_special_tokens))
            else:
                normalized.append(
                    self._normalize_prompt_tokens_to_input(
                        request,
                        tokenizer,
                        prompt_ids=parsed["content"],
                        truncate_prompt_tokens=truncate_prompt_tokens))

        return normalized
341

342
    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
        """
        Tokenize the inputs of a completion-style request.

        Returns:
            A pair ``(request_prompts, engine_prompts)`` of equal length:
            the validated text/token prompts, and one ``TokensPrompt`` built
            from the token ids of each of them.
        """
        tokenize = self._tokenize_prompt_input_or_inputs_async
        request_prompts = await tokenize(
            request,
            tokenizer,
            input_or_inputs,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )

        # Only the token ids are handed on in the engine prompts.
        return request_prompts, [
            TokensPrompt(prompt_token_ids=tokenized["prompt_token_ids"])
            for tokenized in request_prompts
        ]

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
               list[TokensPrompt]]:
        """
        Render chat messages into a prompt and build the engine prompt.

        Steps, in order: resolve the content format, parse the messages,
        apply the chat template (Mistral or HF, depending on the tokenizer),
        await the multi-modal data, optionally let the tool parser adjust
        the request, then tokenize the rendered prompt.

        Entries of ``chat_template_kwargs`` override the template arguments
        derived from the other parameters.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])`` where
            ``request_prompt`` is the rendered prompt (text, or token ids
            for the Mistral tokenizer).

        Raises:
            NotImplementedError: If tool parsing applies but the request is
                not a :class:`ChatCompletionRequest`.
        """
        content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            self.model_config,
            tokenizer,
            content_format=content_format,
        )

        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tool_dicts,
            "documents": documents,
        }
        if chat_template_kwargs:
            # Caller-supplied entries take precedence over the ones above.
            template_kwargs.update(chat_template_kwargs)

        request_prompt: Union[str, list[int]]
        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                **template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM)
        should_parse_tools = (tool_parser is not None
                              and hasattr(request, "tool_choice")
                              and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            parser = tool_parser(tokenizer)  # type: ignore
            request = parser.adjust_request(request=request)

        if isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        # Optional fields are attached only when they carry a value.
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data
        mm_processor_kwargs = request.mm_processor_kwargs
        if mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs

        return conversation, [request_prompt], [engine_prompt]

458
459
460
    def _log_inputs(
        self,
        request_id: str,
461
        inputs: RequestPrompt,
462
463
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
464
465
466
467
468
469
470
471
472
473
474
475
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
476
        else:
477
478
479
480
481
482
483
484
485
486
487
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
488

489
490
491
492
493
494
495
496
497
498
499
500
501
502
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """
        Return the trace headers of a request when tracing is enabled.

        When the engine reports tracing as disabled, ``None`` is returned;
        if the request nevertheless carries trace headers, the
        tracing-disabled warning is logged.
        """
        tracing_enabled = await self.engine_client.is_tracing_enabled()

        if not tracing_enabled:
            if contains_trace_headers(headers):
                log_tracing_disabled_warning()
            return None

        return extract_trace_headers(headers)

503
    @staticmethod
504
    def _base_request_id(raw_request: Optional[Request],
505
506
507
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
508
509
510
511
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)
512

513
    @staticmethod
514
515
516
517
518
519
520
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        if return_as_token_id:
            return f"token_id:{token_id}"

521
522
        if logprob.decoded_token is not None:
            return logprob.decoded_token
523
        return tokenizer.decode(token_id)
524

525
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
526
527
        if not model_name:
            return True
528
        return self.models.is_base_model(model_name)
529
530
531
532
533
534
535
536
537

    def _get_model_name(self,
                        model_name: Optional[str] = None,
                        lora_request: Optional[LoRARequest] = None) -> str:
        if lora_request:
            return lora_request.lora_name
        if model_name is None:
            return self.models.base_model_paths[0].name
        return model_name
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552


def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs,
                           None]) -> Union[PromptLogprobs, None]:
    """
    Replace every ``-inf`` log-probability with ``-9999.0``.

    The entries are modified in place and the same object is returned.
    ``None`` is passed through, and ``None`` entries inside the sequence
    are skipped.
    """
    if prompt_logprobs is None:
        return None

    negative_infinity = float('-inf')
    for position_logprobs in prompt_logprobs:
        if position_logprobs is not None:
            for entry in position_logprobs.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0

    return prompt_logprobs