serving.py 21.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Sequence
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from http import HTTPStatus
from typing import Any

from openai_harmony import Message as OpenAIMessage

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
)
20
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
21
22
23
24
25
26
from vllm.entrypoints.openai.parser.harmony_utils import (
    get_developer_message,
    get_system_message,
    parse_chat_inputs_to_harmony_messages,
    render_for_completion,
)
27
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
28
29
30
31
32
33
34
35
36
from vllm.entrypoints.serve.disagg.protocol import (
    GenerateRequest,
    MultiModalFeatures,
    PlaceholderRangeInfo,
)
from vllm.entrypoints.utils import (
    create_error_response,
    get_max_tokens,
)
37
38
39
40
41
42
43
44
from vllm.inputs import (
    EngineInput,
    MultiModalHashes,
    MultiModalPlaceholders,
    PromptType,
    SingletonPrompt,
    tokens_input,
)
45
46
from vllm.logger import init_logger
from vllm.parser import ParserManager
47
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
48
from vllm.renderers import BaseRenderer, merge_kwargs
49
50
51
52
53
54
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
    parse_model_prompt,
    prompt_to_seq,
)
55
from vllm.tool_parsers import ToolParser
56
from vllm.utils import random_uuid
57
58
59
60
61
62
63
64
65
66
67
68
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import mt as _mt

logger = init_logger(__name__)


class OpenAIServingRender:
    def __init__(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        io_processor: Any,
        model_registry: OpenAIModelRegistry,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        enable_auto_tools: bool = False,
        exclude_tools_when_tool_choice_none: bool = False,
        tool_parser: str | None = None,
        reasoning_parser: str | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        log_error_stack: bool = False,
    ) -> None:
        """Capture serving configuration and resolve parser classes.

        Args:
            model_config: Model configuration; also consulted for the HF model
                type (Harmony detection) and default sampling parameters.
            renderer: Renderer used to tokenize/render prompts.
            io_processor: Stored as-is for use by callers.
            model_registry: Registry used to validate ``request.model``.
            request_logger: Optional request logger (stored only).
            chat_template: Server-side default chat template.
            chat_template_content_format: Default content format for the
                chat template.
            trust_request_chat_template: Whether per-request chat templates
                are accepted.
            enable_auto_tools: Whether ``tool_choice="auto"`` is enabled; also
                passed to the tool-parser lookup.
            exclude_tools_when_tool_choice_none: Drop tool definitions from
                the prompt when ``tool_choice == "none"``.
            tool_parser: Tool parser name, resolved to a class (or None) via
                ``ParserManager.get_tool_parser``.
            reasoning_parser: Reasoning parser name, resolved to a class (or
                None) via ``ParserManager.get_reasoning_parser``.
            default_chat_template_kwargs: Default template kwargs; ``None``
                is stored as an empty dict.
            log_error_stack: Stored flag.
        """
        self.model_config = model_config
        self.renderer = renderer
        self.io_processor = io_processor
        self.model_registry = model_registry
        self.request_logger = request_logger

        # Chat-template and tool-choice settings.
        self.chat_template = chat_template
        self.chat_template_content_format: ChatTemplateContentFormatOption = (
            chat_template_content_format
        )
        self.trust_request_chat_template = trust_request_chat_template
        self.enable_auto_tools = enable_auto_tools
        self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none

        # Resolve parser names into parser classes (either may be None).
        resolved_tool_parser = ParserManager.get_tool_parser(
            tool_parser_name=tool_parser,
            enable_auto_tools=enable_auto_tools,
            model_name=model_config.model,
        )
        self.tool_parser: type[ToolParser] | None = resolved_tool_parser
        resolved_reasoning_parser = ParserManager.get_reasoning_parser(
            reasoning_parser_name=reasoning_parser,
        )
        self.reasoning_parser: type[ReasoningParser] | None = (
            resolved_reasoning_parser
        )

        if default_chat_template_kwargs is None:
            default_chat_template_kwargs = {}
        self.default_chat_template_kwargs: dict[str, Any] = (
            default_chat_template_kwargs
        )
        self.log_error_stack = log_error_stack

        # GPT-OSS models are rendered through Harmony instead of a chat template.
        self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
        self.supports_browsing = False
        self.supports_code_interpreter = False

        self.default_sampling_params = model_config.get_diff_sampling_param()
        # Pick where a server-level max-token override comes from: with an
        # "auto"/"vllm" generation config it is override_generation_config's
        # max_new_tokens; otherwise the default sampling params' max_tokens.
        if model_config.generation_config in ("auto", "vllm"):
            override_config = getattr(model_config, "override_generation_config", {})
            self.override_max_tokens = override_config.get("max_new_tokens")
        else:
            self.override_max_tokens = self.default_sampling_params.get("max_tokens")

    async def render_chat_request(
        self,
        request: ChatCompletionRequest,
    ) -> GenerateRequest | ErrorResponse:
        """Check the model, then turn a chat completion request into a GenerateRequest.

        This is the canonical implementation: the GPU-less render server calls
        it directly, and OpenAIServingChat delegates to it.

        Returns:
            A single ``GenerateRequest`` (request id prefixed ``chatcmpl-``),
            or an ``ErrorResponse`` if the model check fails, beam search is
            requested, rendering fails, rendering does not yield exactly one
            engine input, or the rendered prompt has no token IDs.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            logger.error("Error with model %s", model_error)
            return model_error

        if request.use_beam_search:
            return self.create_error_response(
                "Beam search is not supported by the render endpoint"
            )

        rendered = await self.render_chat(request)
        if isinstance(rendered, ErrorResponse):
            return rendered
        engine_inputs = rendered[1]

        num_inputs = len(engine_inputs)
        if num_inputs != 1:
            return self.create_error_response(
                f"Expected exactly 1 engine prompt, got {num_inputs}"
            )
        (engine_input,) = engine_inputs

        components = extract_prompt_components(self.model_config, engine_input)
        rendered_ids = components.token_ids
        if not rendered_ids:
            return self.create_error_response("No token_ids rendered")

        # max_completion_tokens takes precedence over the legacy max_tokens.
        prompt_len = extract_prompt_len(self.model_config, engine_input)
        if request.max_completion_tokens is None:
            requested_max = request.max_tokens
        else:
            requested_max = request.max_completion_tokens
        max_tokens = get_max_tokens(
            self.model_config.max_model_len,
            requested_max,
            prompt_len,
            self.default_sampling_params,
            self.override_max_tokens,
        )
        sampling_params = request.to_sampling_params(
            max_tokens, self.default_sampling_params
        )

        streaming = bool(request.stream)
        return GenerateRequest(
            request_id=f"chatcmpl-{random_uuid()}",
            token_ids=list(rendered_ids),
            features=self._extract_mm_features(engine_input),
            sampling_params=sampling_params,
            model=request.model,
            stream=streaming,
            stream_options=request.stream_options if streaming else None,
            cache_salt=request.cache_salt,
            priority=request.priority,
        )

    async def render_chat(
        self,
        request: ChatCompletionRequest,
    ) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse:
        """Core preprocessing logic for chat requests (no model/engine check).

        Called directly by render_chat_request and delegated to by
        OpenAIServingChat.render_chat_request after its engine-aware checks.

        Steps:
            1. Apply Mistral-tokenizer request fix-ups.
            2. Reject ``tool_choice`` values that need a tool parser when none
               is available.
            3. Decide whether tool definitions are passed to the prompt.
            4. Render via the chat template, or via Harmony for GPT-OSS models.

        Returns:
            ``(conversation, engine_inputs)`` on success, otherwise an
            ``ErrorResponse`` describing why the request was rejected.
        """
        tokenizer = self.renderer.tokenizer
        tool_parser = self.tool_parser
        is_mistral = is_mistral_tokenizer(tokenizer)

        if is_mistral:
            # because of issues with pydantic we need to potentially
            # re-serialize the tool_calls field of the request
            _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
            _mt.truncate_tool_call_ids(request)  # type: ignore[arg-type]
            _mt.validate_request_params(request)

        # Tool parsing counts as unavailable only when no parser is configured
        # and neither the Mistral nor the Harmony path is in use.
        tool_parsing_unavailable = (
            tool_parser is None and not is_mistral and not self.use_harmony
        )

        if tool_parsing_unavailable and request.tool_choice not in (None, "none"):
            if request.tool_choice == "auto":
                # No tool parser exists in this branch, so "auto" is rejected
                # even when --enable-auto-tool-choice is set: previously it was
                # let through and tools were offered with nothing to parse the
                # resulting calls.
                return self.create_error_response(
                    '"auto" tool choice requires '
                    "--enable-auto-tool-choice and --tool-call-parser to be set"
                )
            # "required" or named tool requires tool parser
            return self.create_error_response(
                f'tool_choice="{request.tool_choice}" requires '
                "--tool-call-parser to be set"
            )

        if request.tools is None or (
            request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
        ):
            tool_dicts = None
        else:
            tool_dicts = [tool.model_dump() for tool in request.tools]

        if self.use_harmony:
            # For GPT-OSS. Harmony rendering raises ValueError for request
            # values it rejects (e.g. reasoning_effort="none"). Report those
            # through this method's ErrorResponse contract rather than
            # propagating the exception.
            should_include_tools = tool_dicts is not None
            try:
                return self._make_request_with_harmony(request, should_include_tools)
            except ValueError as e:
                return self.create_error_response(e)

        # Common case.
        error_check_ret = self.validate_chat_template(
            request_chat_template=request.chat_template,
            chat_template_kwargs=request.chat_template_kwargs,
            trust_request_chat_template=self.trust_request_chat_template,
        )
        if error_check_ret is not None:
            return error_check_ret

        return await self.preprocess_chat(
            request,
            request.messages,
            default_template=self.chat_template,
            default_template_content_format=self.chat_template_content_format,
            default_template_kwargs=self.default_chat_template_kwargs,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
            reasoning_parser=self.reasoning_parser,
        )

    async def render_completion_request(
        self,
        request: CompletionRequest,
    ) -> list[GenerateRequest] | ErrorResponse:
        """Validate the model and preprocess a completion request.

        This is the authoritative implementation used directly by the
        GPU-less render server and delegated to by OpenAIServingCompletion.

        One ``GenerateRequest`` is produced per rendered engine input. Each
        gets its own ``cmpl-`` request id and a ``max_tokens`` budget computed
        from that input's prompt length.

        Returns:
            The list of ``GenerateRequest`` objects, or an ``ErrorResponse``
            when the model check fails, beam search is requested, rendering
            fails, or any rendered input has no token IDs.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Match render_chat_request. This method only builds parameters via
        # ``to_sampling_params`` below, so a beam-search request would
        # otherwise be served without beam search. getattr keeps the check
        # safe if the request model has no such field.
        if getattr(request, "use_beam_search", False):
            return self.create_error_response(
                "Beam search is not supported by the render endpoint"
            )

        result = await self.render_completion(request)
        if isinstance(result, ErrorResponse):
            return result

        generate_requests: list[GenerateRequest] = []
        for engine_input in result:
            prompt_components = extract_prompt_components(
                self.model_config, engine_input
            )
            token_ids = prompt_components.token_ids
            if not token_ids:
                return self.create_error_response("No token_ids rendered")
            token_ids = list(token_ids)

            # Budget max_tokens per input, since prompt lengths can differ.
            input_length = extract_prompt_len(self.model_config, engine_input)
            max_tokens = get_max_tokens(
                self.model_config.max_model_len,
                request.max_tokens,
                input_length,
                self.default_sampling_params,
                self.override_max_tokens,
            )
            params = request.to_sampling_params(
                max_tokens, self.default_sampling_params
            )

            request_id = f"cmpl-{random_uuid()}"

            generate_requests.append(
                GenerateRequest(
                    request_id=request_id,
                    token_ids=token_ids,
                    features=self._extract_mm_features(engine_input),
                    sampling_params=params,
                    model=request.model,
                    stream=bool(request.stream),
                    stream_options=(request.stream_options if request.stream else None),
                    cache_salt=request.cache_salt,
                    priority=request.priority,
                )
            )

        return generate_requests

    async def render_completion(
        self,
        request: CompletionRequest,
    ) -> list[EngineInput] | ErrorResponse:
        """Core preprocessing logic for completion requests (no model/engine check).

        Called directly by render_completion_request and delegated to by
        OpenAIServingCompletion.render_completion_request after its engine-aware checks.

        Rejects these feature combinations, in this order:
            * ``suffix`` set,
            * ``echo`` together with ``prompt_embeds``,
            * ``prompt_logprobs`` together with ``prompt_embeds``.
        Otherwise renders ``prompt`` / ``prompt_embeds`` into engine inputs.
        """
        has_embeds = request.prompt_embeds is not None
        unsupported = (
            (request.suffix is not None, "suffix is not currently supported"),
            (
                bool(request.echo) and has_embeds,
                "Echo is unsupported with prompt embeds.",
            ),
            (
                request.prompt_logprobs is not None and has_embeds,
                "prompt_logprobs is not compatible with prompt embeds.",
            ),
        )
        for violated, message in unsupported:
            if violated:
                return self.create_error_response(message)

        return await self.preprocess_completion(
            request,
            prompt_input=request.prompt,
            prompt_embeds=request.prompt_embeds,
        )

    @staticmethod
    def _extract_mm_features(
        engine_input: EngineInput,
    ) -> MultiModalFeatures | None:
        """Collect multimodal hashes and placeholder ranges from an engine input.

        Returns ``None`` unless ``engine_input["type"] == "multimodal"``.
        Otherwise each placeholder range is copied into a
        ``PlaceholderRangeInfo`` (offset/length), grouped by the same keys as
        ``mm_placeholders``.
        """
        if engine_input.get("type") == "multimodal":
            # A "multimodal" engine input is a MultiModalInputs TypedDict.
            hashes: MultiModalHashes = engine_input["mm_hashes"]  # type: ignore[typeddict-item]
            placeholders: MultiModalPlaceholders = engine_input["mm_placeholders"]  # type: ignore[typeddict-item]

            ranges_by_modality = {}
            for modality, ranges in placeholders.items():
                ranges_by_modality[modality] = [
                    PlaceholderRangeInfo(offset=rng.offset, length=rng.length)
                    for rng in ranges
                ]

            return MultiModalFeatures(
                mm_hashes=hashes,
                mm_placeholders=ranges_by_modality,
            )

        return None

    def _make_request_with_harmony(
        self,
        request: ChatCompletionRequest,
        should_include_tools: bool = True,
    ):
        """Build Harmony (GPT-OSS) messages and a single token engine input.

        Message order: one system message, then a developer message if the
        request has tools, then the request's chat messages converted to
        Harmony format.

        Args:
            request: The chat completion request.
            should_include_tools: Forwarded as ``with_custom_tools`` to the
                system message. When False, the developer message (if any) is
                built with ``tools=None``.

        Returns:
            ``(messages, [engine_input])``, where ``engine_input`` wraps the
            rendered prompt token ids and the request's ``cache_salt``.

        Raises:
            ValueError: if ``request.reasoning_effort == "none"``.
        """
        # because of issues with pydantic we need to potentially
        # re-serialize the tool_calls field of the request
        # for more info: see comment in `maybe_serialize_tool_calls`
        _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]

        # NOTE: In Chat Completion API, browsing is enabled by default
        # if the model supports it. TODO: Support browsing.
        assert not self.supports_browsing
        assert not self.supports_code_interpreter

        # Keep this variable name: the f-string below embeds it in the message.
        reasoning_effort = request.reasoning_effort
        if reasoning_effort == "none":
            raise ValueError(f"Harmony does not support {reasoning_effort=}")

        harmony_messages: list[OpenAIMessage] = [
            get_system_message(
                reasoning_effort=reasoning_effort,
                browser_description=None,
                python_description=None,
                with_custom_tools=should_include_tools,
            )
        ]

        if request.tools:
            developer_tools = request.tools if should_include_tools else None
            harmony_messages.append(
                get_developer_message(tools=developer_tools)  # type: ignore[arg-type]
            )

        # Convert the conversation itself (all roles, not only user turns).
        harmony_messages += parse_chat_inputs_to_harmony_messages(request.messages)

        token_ids = render_for_completion(harmony_messages)
        single_input = tokens_input(token_ids, cache_salt=request.cache_salt)
        return harmony_messages, [single_input]

    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` via the shared module-level helper.

        All arguments are forwarded positionally, unchanged, to
        ``vllm.entrypoints.utils.create_error_response``. The defaults
        describe a 400 ``BadRequestError``.
        """
        response = create_error_response(message, err_type, status_code, param)
        return response

    async def _check_model(
        self,
        request: Any,
    ) -> ErrorResponse | None:
        """Validate ``request.model`` against the model registry.

        Returns whatever ``model_registry.check_model`` returns: an
        ``ErrorResponse`` when the model is rejected, otherwise ``None``.
        """
        requested_model = request.model
        return await self.model_registry.check_model(requested_model)

    def validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Copied from OpenAIServing._validate_chat_template.

        A request-supplied chat template counts when either
        ``request_chat_template`` is not None or ``chat_template_kwargs``
        carries a non-None ``"chat_template"`` entry.

        Returns:
            ``None`` if request templates are trusted or none was supplied,
            otherwise an ``ErrorResponse`` refusing the request.
        """
        if trust_request_chat_template:
            return None

        kwargs_template = (chat_template_kwargs or {}).get("chat_template")
        supplies_template = (
            request_chat_template is not None or kwargs_template is not None
        )
        if not supplies_template:
            return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

    async def preprocess_completion(
        self,
        request: Any,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
        *,
        skip_mm_cache: bool = False,
    ) -> list[EngineInput]:
        """Copied from OpenAIServing._preprocess_completion.

        Flattens ``prompt_embeds`` and ``prompt_input`` into one prompt list and
        renders it with ``preprocess_cmpl``. Embeds are placed before text/token
        prompts. ``None`` inputs contribute nothing, and both are kept when
        both are given.
        """
        embed_prompts = [] if prompt_embeds is None else prompt_to_seq(prompt_embeds)
        other_prompts = [] if prompt_input is None else prompt_to_seq(prompt_input)
        prompts: list[SingletonPrompt | bytes] = [*embed_prompts, *other_prompts]
        return await self.preprocess_cmpl(request, prompts, skip_mm_cache=skip_mm_cache)

    async def preprocess_cmpl(
        self,
        request: Any,
        prompts: Sequence[PromptType | bytes],
        *,
        skip_mm_cache: bool = False,
    ) -> list[EngineInput]:
        """Copied from OpenAIServing._preprocess_cmpl.

        Non-``bytes`` prompts go through ``parse_model_prompt``; ``bytes``
        prompts are passed through as-is. Tokenization params come from
        ``request.build_tok_params``. ``mm_processor_kwargs`` and
        ``cache_salt`` are forwarded as prompt extras only when the request
        has them set to a non-None value.
        """
        model_config = self.model_config

        parsed_prompts = []
        for prompt in prompts:
            if isinstance(prompt, bytes):
                parsed_prompts.append(prompt)
            else:
                parsed_prompts.append(parse_model_prompt(model_config, prompt))

        tok_params = request.build_tok_params(model_config)

        extras = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value

        return await self.renderer.render_cmpl_async(
            parsed_prompts,
            tok_params,
            prompt_extras=extras,
            skip_mm_cache=skip_mm_cache,
        )

    async def preprocess_chat(
        self,
        request: Any,
        messages: list[Any],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: type[ToolParser] | None = None,
        reasoning_parser: type[ReasoningParser] | None = None,
        *,
        skip_mm_cache: bool = False,
    ) -> tuple[list[ConversationMessage], list[EngineInput]]:
        """Copied from OpenAIServing._preprocess_chat.

        Renders a single conversation through ``render_chat_async``. After
        rendering, the request is passed through the reasoning parser's
        ``adjust_request`` and then, when applicable, the tool parser's
        ``adjust_request``.

        Returns:
            ``(conversation, [engine_input])`` for the one rendered conversation.

        Raises:
            NotImplementedError: a tool parser is set, ``tool_choice`` is not
                ``"none"``, and the request is neither a ChatCompletionRequest
                nor a ResponsesRequest.
        """
        renderer = self.renderer
        mm_config = self.model_config.multimodal_config

        # Server defaults plus the tool list and Mistral's tokenize flag.
        template_kwargs = merge_kwargs(
            default_template_kwargs,
            {
                "tools": tool_dicts,
                "tokenize": is_mistral_tokenizer(renderer.tokenizer),
            },
        )

        tok_params = request.build_tok_params(self.model_config)
        base_chat_params = request.build_chat_params(
            default_template, default_template_content_format
        )
        media_io_kwargs = mm_config.media_io_kwargs if mm_config else None
        chat_params = base_chat_params.with_defaults(
            template_kwargs,
            default_media_io_kwargs=media_io_kwargs,
            default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
        )

        extras = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value

        conversations, engine_inputs = await renderer.render_chat_async(
            [messages],
            chat_params,
            tok_params,
            prompt_extras=extras,
            skip_mm_cache=skip_mm_cache,
        )
        # Exactly one conversation was submitted, so expect exactly one back.
        (conversation,) = conversations
        (engine_input,) = engine_inputs

        if reasoning_parser is not None:
            tokenizer = renderer.get_tokenizer()
            request = reasoning_parser(tokenizer).adjust_request(request=request)

        # Tool handling runs only with a tool parser and a tool_choice other
        # than "none"; with "none" we avoid parsing tool calls the model may
        # hallucinate.
        if tool_parser is not None and getattr(request, "tool_choice", "none") != "none":
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                raise NotImplementedError(
                    "Tool usage is only supported "
                    "for Chat Completions API or Responses API requests, "
                    f"but got {type(request).__name__}"
                )
            tokenizer = renderer.get_tokenizer()
            request = tool_parser(tokenizer, request.tools).adjust_request(
                request=request
            )

        return conversation, [engine_input]