serving_engine.py 20.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import json
4
from concurrent.futures.thread import ThreadPoolExecutor
5
from http import HTTPStatus
6
7
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Sequence, Tuple, TypedDict, Union)
8

9
from fastapi import Request
10
from pydantic import Field
11
from starlette.datastructures import Headers
12
from typing_extensions import Annotated
13

14
from vllm.config import ModelConfig
15
from vllm.engine.protocol import EngineClient
16
17
# yapf conflicts with isort for this block
# yapf: disable
18
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
19
                                         ChatTemplateContentFormatOption,
20
21
22
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
23
24
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
25
from vllm.entrypoints.logger import RequestLogger
26
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
27
                                              CompletionRequest,
28
                                              DetokenizeRequest,
29
30
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
31
32
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
33
                                              TokenizeChatRequest,
34
35
                                              TokenizeCompletionRequest,
                                              TranscriptionRequest)
36
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
37
from vllm.entrypoints.openai.tool_parsers import ToolParser
38
# yapf: enable
39
from vllm.inputs import TokensPrompt
40
from vllm.inputs.parse import parse_and_batch_prompt
41
from vllm.logger import init_logger
42
from vllm.lora.request import LoRARequest
43
from vllm.pooling_params import PoolingParams
44
from vllm.prompt_adapter.request import PromptAdapterRequest
45
from vllm.sampling_params import BeamSearchParams, SamplingParams
46
from vllm.sequence import Logprob
47
48
49
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
50
from vllm.utils import is_list_of, make_async, random_uuid
51
52
53

logger = init_logger(__name__)

# Request types accepted by the completion-style preprocessing path
# (see ``OpenAIServing._preprocess_completion``).
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, ScoreRequest,
                              TokenizeCompletionRequest]

# Request types accepted by the chat-style preprocessing path
# (see ``OpenAIServing._preprocess_chat``).
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

# Every request type the shared serving helpers are annotated to accept.
# NOTE(review): RerankRequest is checked in ``_validate_input`` but is not a
# member of this union -- confirm whether that is intentional.
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
                   TranscriptionRequest]
63
64
65
66
67
68
69


class TextTokensPrompt(TypedDict):
    """A prompt carried as both its text and its token ids."""
    # Text form of the prompt.
    prompt: str
    # Token ids corresponding to the prompt.
    prompt_token_ids: List[int]


# A prompt as seen by request-level helpers: raw token ids, raw text, or the
# combined text-plus-tokens form.
RequestPrompt = Union[List[int], str, TextTokensPrompt]


73
74
class OpenAIServing:

75
76
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        """Store the serving dependencies and set up async tokenization.

        Args:
            engine_client: Client used to talk to the engine.
            model_config: Model configuration; its ``max_model_len`` is
                cached on the instance for input validation.
            models: Registry queried for the base model, LoRA adapters and
                prompt adapters.
            request_logger: Optional logger for request inputs; ``None``
                disables input logging.
            return_tokens_as_token_ids: Stored on the instance for use when
                rendering tokens.
        """
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # All tokenization work is funnelled through one single-worker pool.
        # NOTE(review): presumably this keeps tokenization off the event loop
        # while serializing tokenizer access -- confirm.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        # Awaitable counterparts of the synchronous tokenization helpers,
        # bound to the executor above.
        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input, executor=self._tokenizer_executor)
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=self._tokenizer_executor)

103
104
105
106
107
108
109
110
111
    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an :class:`ErrorResponse` from a message, type and status.

        Args:
            message: Description of the failure.
            err_type: Value stored in the response's ``type`` field.
            status_code: HTTP status whose integer value is stored in the
                response's ``code`` field.

        Returns:
            The populated :class:`ErrorResponse`.
        """
        error_code = status_code.value
        return ErrorResponse(message=message, type=err_type, code=error_code)

112
113
114
115
116
117
118
119
120
121
122
123
124
    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Serialize an error as a JSON string of the form ``{"error": ...}``.

        Takes the same arguments as :meth:`create_error_response`; the
        resulting response is dumped to a dict and wrapped under the
        ``"error"`` key.
        """
        error = self.create_error_response(message=message,
                                           err_type=err_type,
                                           status_code=status_code)
        payload = {"error": error.model_dump()}
        return json.dumps(payload)

125
    async def _check_model(
126
127
        self,
        request: AnyRequest,
128
    ) -> Optional[ErrorResponse]:
129
        if self._is_model_supported(request.model):
130
            return None
131
132
133
        if request.model in [
                lora.lora_name for lora in self.models.lora_requests
        ]:
134
            return None
135
136
        if request.model in [
                prompt_adapter.prompt_adapter_name
137
                for prompt_adapter in self.models.prompt_adapter_requests
138
139
        ]:
            return None
140
141
142
143
144
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

145
146
147
148
    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
149
        if self._is_model_supported(request.model):
150
            return None, None
151
        for lora in self.models.lora_requests:
152
            if request.model == lora.lora_name:
153
                return lora, None
154
        for prompt_adapter in self.models.prompt_adapter_requests:
155
            if request.model == prompt_adapter.prompt_adapter_name:
156
                return None, prompt_adapter
157
        # if _check_model has been called earlier, this will be unreachable
158
        raise ValueError(f"The model `{request.model}` does not exist.")
159

160
161
162
163
164
165
166
167
    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate its length.

        Args:
            request: Request the prompt belongs to; forwarded to
                :meth:`_validate_input`.
            tokenizer: Callable tokenizer whose result exposes ``input_ids``.
            prompt: Prompt text. It is lower-cased first when the model's
                encoder config sets ``do_lower_case``.
            truncate_prompt_tokens: When given, the tokenizer is called with
                ``truncation=True`` and this value as ``max_length``.
            add_special_tokens: Forwarded to the tokenizer.

        Returns:
            The (possibly lower-cased) prompt text together with its token
            ids, as returned by :meth:`_validate_input`.
        """
        encoder_config = self.model_config.encoder_config
        if (encoder_config is not None
                and encoder_config.get("do_lower_case", False)):
            prompt = prompt.lower()

        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        return self._validate_input(request, encoded.input_ids, prompt)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        """Detokenize a token-id prompt and validate its length.

        Args:
            request: Request the prompt belongs to; forwarded to
                :meth:`_validate_input`.
            tokenizer: Tokenizer used to decode the ids back into text.
            prompt_ids: Prompt as token ids.
            truncate_prompt_tokens: When given, only the last
                ``truncate_prompt_tokens`` ids are kept.

        Returns:
            The decoded text together with the (possibly truncated) ids, as
            returned by :meth:`_validate_input`.
        """
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            # Negative slice: keep the tail of the prompt.
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the prompt length against the model's context length.

        The rule depends on the request type:

        * embedding / score / rerank requests: the prompt alone must not
          exceed ``max_model_len``;
        * tokenize / detokenize requests: no length check;
        * all other requests: the prompt plus the requested completion
          tokens must fit; when no completion budget is given the prompt
          must be strictly shorter than ``max_model_len``.

        Args:
            request: Request being validated.
            input_ids: Prompt token ids.
            input_text: Prompt text.

        Returns:
            A :class:`TextTokensPrompt` holding ``input_text`` and
            ``input_ids`` unchanged.

        Raises:
            ValueError: If the applicable length limit is exceeded.
        """
        token_num = len(input_ids)

        # Note: embedding, score and rerank requests don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: tokenize and detokenize requests don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens

        if max_tokens is None:
            # No explicit completion budget: a prompt that already fills the
            # whole context is rejected (note the ``>=``).
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        Single-input variant of :meth:`_tokenize_prompt_input_or_inputs`.

        Wraps ``prompt_input`` in a one-element list, delegates to
        :meth:`_tokenize_prompt_inputs` and returns the only result.
        """
        results = self._tokenize_prompt_inputs(
            request,
            tokenizer,
            [prompt_input],
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )
        return next(results)

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Multi-input variant of :meth:`_tokenize_prompt_input_or_inputs`.

        Lazily yields one normalized prompt per element of
        ``prompt_inputs``: strings are tokenized, anything else is treated
        as a list of token ids and detokenized.
        """
        for prompt_input in prompt_inputs:
            if not isinstance(prompt_input, str):
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )
                continue

            yield self._normalize_prompt_text_to_input(
                request,
                tokenizer,
                prompt=prompt_input,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> List[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to the `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_,
        each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.

        The input is split into individual prompts by
        ``parse_and_batch_prompt``; text prompts are tokenized and token-id
        prompts are detokenized, and each one is length-validated.

        Returns:
            One :class:`TextTokensPrompt` per parsed prompt, in order.
        """
        # Although our type checking is based on mypy,
        # VSCode Pyright extension should still work properly
        # "is False" is required for Pyright to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672
        return [
            self._normalize_prompt_text_to_input(
                request,
                tokenizer,
                prompt=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens)
            if prompt_input["is_tokens"] is False else
            self._normalize_prompt_tokens_to_input(
                request,
                tokenizer,
                prompt_ids=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens)
            for prompt_input in parse_and_batch_prompt(input_or_inputs)
        ]
342

343
    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]:
        """Turn completion-style input into request and engine prompts.

        Tokenization is awaited through the async wrapper created in
        ``__init__``.

        Returns:
            A pair ``(request_prompts, engine_prompts)`` of equal length:
            the normalized text-plus-token prompts, and the matching
            :class:`TokensPrompt` objects built from their token ids.
        """
        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
            request,
            tokenizer,
            input_or_inputs,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )

        engine_prompts = [
            TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
            for request_prompt in request_prompts
        ]

        return request_prompts, engine_prompts

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: List[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[List[Dict[str, Any]]] = None,
        documents: Optional[List[Dict[str, str]]] = None,
        chat_template_kwargs: Optional[Dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
               List[TokensPrompt]]:
        """Render chat messages into a prompt and build the engine input.

        Steps: resolve the content format, parse the messages (multi-modal
        data is produced as an awaitable), apply the chat template, let the
        tool parser adjust the request when tool parsing applies, then
        tokenize the rendered prompt.

        Args:
            request: Chat-style request being served.
            tokenizer: Tokenizer; a ``MistralTokenizer`` selects the Mistral
                chat-template path, anything else the HF path.
            messages: Chat messages to render.
            chat_template: Optional chat template forwarded to the template
                function.
            chat_template_content_format: Requested content format, resolved
                against the template and tokenizer.
            add_generation_prompt: Forwarded to the chat template.
            continue_final_message: Forwarded to the chat template.
            tool_dicts: Forwarded to the chat template as ``tools``.
            documents: Forwarded to the chat template.
            chat_template_kwargs: Extra template arguments; these override
                the defaults assembled from the parameters above.
            tool_parser: Factory for a tool parser; used only when the
                request has a ``tool_choice`` other than ``"none"``.
            truncate_prompt_tokens: Token limit applied when the rendered
                prompt is a string. It is not applied when the template
                already returned token ids.
            add_special_tokens: Forwarded to tokenization of a string prompt.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])`` where
            ``request_prompt`` is the raw template output (text or token
            ids) and ``engine_prompt`` carries the token ids plus any
            multi-modal data.

        Raises:
            NotImplementedError: If tool parsing applies but ``request`` is
                not a ``ChatCompletionRequest``.
        """
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            self.model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        # Caller-supplied kwargs are applied last so they take precedence.
        _chat_template_kwargs: Dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, List[int]]
        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM)
        should_parse_tools = tool_parser is not None and (hasattr(
            request, "tool_choice") and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            # NOTE(review): this branch builds the prompt directly, so
            # neither truncation nor _validate_input runs on it.
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        return conversation, [request_prompt], [engine_prompt]

457
458
459
    def _log_inputs(
        self,
        request_id: str,
460
        inputs: RequestPrompt,
461
462
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
463
464
465
466
467
468
469
470
471
472
473
474
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
475
        else:
476
477
478
479
480
481
482
483
484
485
486
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
487

488
489
490
491
492
493
494
495
496
497
498
499
500
501
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return the trace headers to propagate, if tracing is enabled.

        When the engine reports tracing as disabled, nothing is returned;
        a warning is logged if the request nevertheless carried trace
        headers.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        # Tracing is off: warn when the client sent headers we will drop.
        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

502
    @staticmethod
503
    def _base_request_id(raw_request: Optional[Request],
504
505
506
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
507
508
509
510
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)
511

512
    @staticmethod
513
514
515
516
517
518
519
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        if return_as_token_id:
            return f"token_id:{token_id}"

520
521
        if logprob.decoded_token is not None:
            return logprob.decoded_token
522
        return tokenizer.decode(token_id)
523

524
    def _is_model_supported(self, model_name):
525
        return self.models.is_base_model(model_name)