# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from http import HTTPStatus
from typing import Any, cast

from openai_harmony import Message as OpenAIMessage

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
)
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
from vllm.entrypoints.openai.parser.harmony_utils import (
    get_developer_message,
    get_system_message,
    parse_chat_inputs_to_harmony_messages,
    render_for_completion,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.serve.disagg.mm_serde import encode_mm_kwargs_item
from vllm.entrypoints.serve.disagg.protocol import (
    GenerateRequest,
    MultiModalFeatures,
    PlaceholderRangeInfo,
)
from vllm.entrypoints.utils import (
    create_error_response,
    get_max_tokens,
)
from vllm.inputs import (
    EngineInput,
    MultiModalHashes,
    MultiModalInput,
    MultiModalPlaceholders,
    PromptType,
    SingletonPrompt,
    tokens_input,
)
from vllm.logger import init_logger
from vllm.parser import ParserManager
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
    parse_model_prompt,
    prompt_to_seq,
)
from vllm.tool_parsers import ToolParser
from vllm.utils import random_uuid
from vllm.utils.mistral import is_mistral_tokenizer, is_mistral_tool_parser
from vllm.utils.mistral import mt as _mt

# Module-level logger, named after this module via vLLM's init_logger.
logger = init_logger(__name__)


class OpenAIServingRender:
    """Render OpenAI-compatible requests into tokenized engine inputs.

    Chat-completion and completion requests are preprocessed (chat template or
    Harmony rendering, tokenization, multimodal handling) and packaged as
    ``GenerateRequest`` objects without running the engine.

    The ``render_*_request`` methods additionally validate the requested model
    against ``model_registry``. The ``render_chat`` / ``render_completion``
    cores skip that check so engine-aware servers can delegate to them after
    their own checks.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        model_registry: OpenAIModelRegistry,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        enable_auto_tools: bool = False,
        exclude_tools_when_tool_choice_none: bool = False,
        tool_parser: str | None = None,
        reasoning_parser: str | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        log_error_stack: bool = False,
    ) -> None:
        self.model_config = model_config
        self.renderer = renderer
        self.model_registry = model_registry
        self.request_logger = request_logger
        self.chat_template = chat_template
        self.chat_template_content_format: ChatTemplateContentFormatOption = (
            chat_template_content_format
        )
        self.trust_request_chat_template = trust_request_chat_template
        self.enable_auto_tools = enable_auto_tools
        self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none

        # Resolve parser names to parser classes once, at construction time.
        self.tool_parser: type[ToolParser] | None = ParserManager.get_tool_parser(
            tool_parser_name=tool_parser,
            enable_auto_tools=enable_auto_tools,
            model_name=model_config.model,
        )
        self.reasoning_parser: type[ReasoningParser] | None = (
            ParserManager.get_reasoning_parser(
                reasoning_parser_name=reasoning_parser,
            )
        )

        self.default_chat_template_kwargs: dict[str, Any] = (
            default_chat_template_kwargs or {}
        )
        self.log_error_stack = log_error_stack
        # GPT-OSS models are rendered through Harmony instead of a chat template.
        self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
        # Built-in tools are not supported on this path; the Harmony renderer
        # asserts both flags are False.
        self.supports_browsing = False
        self.supports_code_interpreter = False

        self.default_sampling_params = model_config.get_diff_sampling_param()
        mc = model_config
        # `or {}` also covers an `override_generation_config` attribute that is
        # present but set to None (the bare getattr default only covers absence).
        self.override_max_tokens = (
            self.default_sampling_params.get("max_tokens")
            if mc.generation_config not in ("auto", "vllm")
            else (getattr(mc, "override_generation_config", None) or {}).get(
                "max_new_tokens"
            )
        )

    async def render_chat_request(
        self,
        request: ChatCompletionRequest,
    ) -> GenerateRequest | ErrorResponse:
        """Validate the model and preprocess a chat completion request.

        This is the authoritative implementation used directly by the
        GPU-less render server and delegated to by OpenAIServingChat.

        Returns:
            One ``GenerateRequest`` on success. An ``ErrorResponse`` if the
            model check fails, beam search is requested, rendering fails, or
            rendering does not yield exactly one tokenized prompt.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        if request.use_beam_search:
            return self.create_error_response(
                "Beam search is not supported by the render endpoint"
            )

        result = await self.render_chat(request, skip_mm_cache=True)
        if isinstance(result, ErrorResponse):
            return result

        _, engine_inputs = result
        if len(engine_inputs) != 1:
            return self.create_error_response(
                f"Expected exactly 1 engine prompt, got {len(engine_inputs)}"
            )

        # Prefer `max_completion_tokens`, falling back to `max_tokens`.
        requested_max_tokens = (
            request.max_completion_tokens
            if request.max_completion_tokens is not None
            else request.max_tokens
        )
        return self._build_generate_request(
            request,
            engine_inputs[0],
            requested_max_tokens=requested_max_tokens,
            request_id_prefix="chatcmpl",
        )

    async def render_chat(
        self,
        request: ChatCompletionRequest,
        *,
        skip_mm_cache: bool = False,
    ) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse:
        """Core preprocessing logic for chat requests (no model/engine check).

        Called directly by render_chat_request and delegated to by
        OpenAIServingChat.render_chat_request after its engine-aware checks.

        Returns:
            ``(conversation, engine_inputs)`` on success, or an
            ``ErrorResponse`` for invalid tool/template/Harmony settings.
        """
        tokenizer = self.renderer.tokenizer
        tool_parser = self.tool_parser
        is_mistral = is_mistral_tokenizer(tokenizer)

        if is_mistral:
            # because of issues with pydantic we need to potentially
            # re-serialize the tool_calls field of the request
            _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
            _mt.truncate_tool_call_ids(request)  # type: ignore[arg-type]
            _mt.validate_request_params(request)

        # Tool calls cannot be parsed without a tool parser, unless the
        # Mistral tokenizer or Harmony path handles them.
        tool_parsing_unavailable = (
            tool_parser is None and not is_mistral and not self.use_harmony
        )

        # Validate tool_choice when tool parsing is required but unavailable
        if tool_parsing_unavailable and request.tool_choice not in (
            None,
            "none",
        ):
            if request.tool_choice == "auto" and not self.enable_auto_tools:
                # for hf tokenizers, "auto" tools requires
                # --enable-auto-tool-choice and --tool-call-parser
                return self.create_error_response(
                    '"auto" tool choice requires '
                    "--enable-auto-tool-choice and --tool-call-parser to be set"
                )
            elif request.tool_choice != "auto":
                # "required" or named tool requires tool parser
                return self.create_error_response(
                    f'tool_choice="{request.tool_choice}" requires '
                    "--tool-call-parser to be set"
                )

        if request.tools is None or (
            request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
        ):
            tool_dicts = None
        else:
            tool_dicts = [tool.model_dump() for tool in request.tools]

        if not self.use_harmony:
            # Common case.
            error_check_ret = self.validate_chat_template(
                request_chat_template=request.chat_template,
                chat_template_kwargs=request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            if error_check_ret is not None:
                return error_check_ret

            return await self.preprocess_chat(
                request,
                request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=self.default_chat_template_kwargs,
                tool_dicts=tool_dicts,
                tool_parser=tool_parser,
                skip_mm_cache=skip_mm_cache,
                reasoning_parser=self.reasoning_parser,
            )

        # For GPT-OSS.
        should_include_tools = tool_dicts is not None
        try:
            conversation, engine_inputs = self._make_request_with_harmony(
                request, should_include_tools
            )
        except ValueError as e:
            # Invalid client input (e.g. reasoning_effort="none") is reported
            # like every other validation failure here, not raised as-is.
            return self.create_error_response(e)
        return conversation, engine_inputs

    async def render_completion_request(
        self,
        request: CompletionRequest,
    ) -> list[GenerateRequest] | ErrorResponse:
        """Validate the model and preprocess a completion request.

        This is the authoritative implementation used directly by the
        GPU-less render server and delegated to by OpenAIServingCompletion.

        Returns:
            One ``GenerateRequest`` per rendered prompt, or the first
            ``ErrorResponse`` encountered.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        result = await self.render_completion(request, skip_mm_cache=True)
        if isinstance(result, ErrorResponse):
            return result

        generate_requests: list[GenerateRequest] = []
        for engine_input in result:
            generate_request = self._build_generate_request(
                request,
                engine_input,
                requested_max_tokens=request.max_tokens,
                request_id_prefix="cmpl",
            )
            if isinstance(generate_request, ErrorResponse):
                return generate_request
            generate_requests.append(generate_request)

        return generate_requests

    async def render_completion(
        self,
        request: CompletionRequest,
        *,
        skip_mm_cache: bool = False,
    ) -> list[EngineInput] | ErrorResponse:
        """Core preprocessing logic for completion requests (no model/engine check).

        Called directly by render_completion_request and delegated to by
        OpenAIServingCompletion.render_completion_request after its engine-aware checks.
        """
        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and request.prompt_embeds is not None:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        return await self.preprocess_completion(
            request,
            prompt_input=request.prompt,
            prompt_embeds=request.prompt_embeds,
            skip_mm_cache=skip_mm_cache,
        )

    def _build_generate_request(
        self,
        request: ChatCompletionRequest | CompletionRequest,
        engine_input: EngineInput,
        *,
        requested_max_tokens: int | None,
        request_id_prefix: str,
    ) -> GenerateRequest | ErrorResponse:
        """Package one rendered engine input as a ``GenerateRequest``.

        Shared by the chat and completion entry points. They differ only in
        which request field caps generation length (``requested_max_tokens``)
        and in the request-id prefix.

        Returns:
            The ``GenerateRequest``, or an ``ErrorResponse`` if rendering
            produced no token ids.
        """
        prompt_components = extract_prompt_components(self.model_config, engine_input)
        token_ids = prompt_components.token_ids
        if not token_ids:
            return self.create_error_response("No token_ids rendered")

        input_length = extract_prompt_len(self.model_config, engine_input)
        max_tokens = get_max_tokens(
            self.model_config.max_model_len,
            requested_max_tokens,
            input_length,
            self.default_sampling_params,
            self.override_max_tokens,
        )
        params = request.to_sampling_params(max_tokens, self.default_sampling_params)

        return GenerateRequest(
            request_id=f"{request_id_prefix}-{random_uuid()}",
            token_ids=list(token_ids),
            features=self._extract_mm_features(engine_input),
            sampling_params=params,
            model=request.model,
            stream=bool(request.stream),
            stream_options=(request.stream_options if request.stream else None),
            cache_salt=request.cache_salt,
            priority=request.priority,
        )

    @staticmethod
    def _extract_mm_features(
        engine_input: EngineInput,
    ) -> MultiModalFeatures | None:
        """Extract multimodal metadata from a rendered engine prompt.

        Returns ``None`` for text-only prompts, i.e. any input whose ``type``
        is not ``"multimodal"``.
        """
        if engine_input.get("type") != "multimodal":
            return None

        # At this point engine_input is a MultiModalInput TypedDict.
        mm_engine_input = cast(MultiModalInput, engine_input)
        mm_hashes: MultiModalHashes = mm_engine_input["mm_hashes"]
        raw_placeholders: MultiModalPlaceholders = mm_engine_input["mm_placeholders"]

        mm_placeholders = {
            modality: [
                PlaceholderRangeInfo(offset=p.offset, length=p.length) for p in ranges
            ]
            for modality, ranges in raw_placeholders.items()
        }

        # Serialize tensor data per modality; None entries are kept as None.
        kwargs_data: dict[str, list[str | None]] | None = None
        if raw_mm_kwargs := mm_engine_input.get("mm_kwargs"):
            kwargs_data = {
                modality: [
                    encode_mm_kwargs_item(item) if item is not None else None
                    for item in items
                ]
                for modality, items in raw_mm_kwargs.items()
            }

        return MultiModalFeatures(
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
            kwargs_data=kwargs_data,
        )

    def _make_request_with_harmony(
        self,
        request: ChatCompletionRequest,
        should_include_tools: bool = True,
    ) -> tuple[list[OpenAIMessage], list[EngineInput]]:
        """Build Harmony (GPT-OSS) messages and engine prompt from a chat request.

        Raises:
            ValueError: if ``request.reasoning_effort`` is ``"none"``, which
                Harmony does not support.
        """
        messages: list[OpenAIMessage] = []

        # because of issues with pydantic we need to potentially
        # re-serialize the tool_calls field of the request
        # for more info: see comment in `maybe_serialize_tool_calls`
        _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]

        # Add system message.
        # NOTE: In Chat Completion API, browsing is enabled by default
        # if the model supports it. TODO: Support browsing.
        assert not self.supports_browsing
        assert not self.supports_code_interpreter
        if (reasoning_effort := request.reasoning_effort) == "none":
            raise ValueError(f"Harmony does not support {reasoning_effort=}")
        sys_msg = get_system_message(
            reasoning_effort=reasoning_effort,
            browser_description=None,
            python_description=None,
            with_custom_tools=should_include_tools,
        )
        messages.append(sys_msg)

        # Add developer message.
        if request.tools:
            dev_msg = get_developer_message(
                tools=request.tools if should_include_tools else None  # type: ignore[arg-type]
            )
            messages.append(dev_msg)

        # Add user message.
        messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))

        # Render prompt token ids.
        prompt_token_ids = render_for_completion(messages)
        engine_input = tokens_input(prompt_token_ids, cache_salt=request.cache_salt)

        return messages, [engine_input]

    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` (HTTP 400 ``BadRequestError`` by default)."""
        return create_error_response(message, err_type, status_code, param)

    async def _check_model(
        self,
        request: Any,
    ) -> ErrorResponse | None:
        """Check ``request.model`` against the model registry.

        Returns the registry's ``ErrorResponse`` or ``None``.
        """
        return await self.model_registry.check_model(request.model)

    def validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Copied from OpenAIServing._validate_chat_template.

        Rejects a request-supplied chat template, given directly or via
        ``chat_template_kwargs["chat_template"]``, unless
        ``trust_request_chat_template`` is set.
        """
        if not trust_request_chat_template and (
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None

    @staticmethod
    def _prompt_extras(request: Any) -> dict[str, Any]:
        """Collect the optional per-request prompt extras that are set (non-None)."""
        return {
            k: v
            for k in ("mm_processor_kwargs", "cache_salt")
            if (v := getattr(request, k, None)) is not None
        }

    async def preprocess_completion(
        self,
        request: Any,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
        *,
        skip_mm_cache: bool = False,
    ) -> list[EngineInput]:
        """Copied from OpenAIServing._preprocess_completion.

        Flattens embeds and text/token prompts into one sequence (embeds
        first) and renders them.
        """
        prompts = list[SingletonPrompt | bytes]()
        if prompt_embeds is not None:  # embeds take higher priority
            prompts.extend(prompt_to_seq(prompt_embeds))
        if prompt_input is not None:
            prompts.extend(prompt_to_seq(prompt_input))
        return await self.preprocess_cmpl(request, prompts, skip_mm_cache=skip_mm_cache)

    async def preprocess_cmpl(
        self,
        request: Any,
        prompts: Sequence[PromptType | bytes],
        *,
        skip_mm_cache: bool = False,
    ) -> list[EngineInput]:
        """Copied from OpenAIServing._preprocess_cmpl.

        ``bytes`` prompts (embeds) are passed through unparsed; all others go
        through ``parse_model_prompt`` before rendering.
        """
        renderer = self.renderer
        model_config = self.model_config

        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(model_config, prompt)
            )
            for prompt in prompts
        ]
        tok_params = request.build_tok_params(model_config)

        return await renderer.render_cmpl_async(
            parsed_prompts,
            tok_params,
            prompt_extras=self._prompt_extras(request),
            skip_mm_cache=skip_mm_cache,
        )

    async def preprocess_chat(
        self,
        request: Any,
        messages: list[Any],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: type[ToolParser] | None = None,
        reasoning_parser: type[ReasoningParser] | None = None,
        *,
        skip_mm_cache: bool = False,
    ) -> tuple[list[ConversationMessage], list[EngineInput]]:
        """Copied from OpenAIServing._preprocess_chat.

        Renders ``messages`` with the chat template. The reasoning and tool
        parsers are then given a chance to adjust ``request``.

        Raises:
            NotImplementedError: if tool parsing applies but ``request`` is
                neither a chat-completion nor a responses request.
        """
        renderer = self.renderer
        mm_config = self.model_config.multimodal_config

        default_template_kwargs = merge_kwargs(
            default_template_kwargs,
            dict(
                tools=tool_dicts,
                tokenize=is_mistral_tokenizer(renderer.tokenizer),
            ),
        )

        tok_params = request.build_tok_params(self.model_config)
        chat_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(
            default_template_kwargs,
            default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
            default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
        )

        (conversation,), (engine_input,) = await renderer.render_chat_async(
            [messages],
            chat_params,
            tok_params,
            prompt_extras=self._prompt_extras(request),
            skip_mm_cache=skip_mm_cache,
        )

        if reasoning_parser is not None:
            tokenizer = renderer.get_tokenizer()
            request = reasoning_parser(
                tokenizer,
                model_config=self.model_config,
                chat_template_kwargs=chat_params.chat_template_kwargs,
            ).adjust_request(request=request)

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the LLM
        #
        # Exception: Mistral grammar-capable tokenizers always call
        # adjust_request — even for tool_choice="none" — so that the grammar
        # factory can prevent special-token leakage.
        if tool_parser is not None:
            tool_choice = getattr(request, "tool_choice", "none")
            tokenizer = renderer.get_tokenizer()
            is_mistral_grammar_eligible = (
                is_mistral_tool_parser(tool_parser)
                and is_mistral_tokenizer(tokenizer)
                and tokenizer.supports_grammar
            )
            if tool_choice != "none" or is_mistral_grammar_eligible:
                if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                    msg = (
                        "Tool usage is only supported "
                        "for Chat Completions API or Responses API requests, "
                        f"but got {type(request).__name__}"
                    )
                    raise NotImplementedError(msg)
                request = tool_parser(tokenizer, request.tools).adjust_request(
                    request=request
                )

        return conversation, [engine_input]