# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from http import HTTPStatus
from typing import Any

from openai_harmony import Message as OpenAIMessage

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
)
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
from vllm.entrypoints.openai.parser.harmony_utils import (
    get_developer_message,
    get_system_message,
    parse_chat_inputs_to_harmony_messages,
    render_for_completion,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.serve.disagg.protocol import (
    GenerateRequest,
    MultiModalFeatures,
    PlaceholderRangeInfo,
)
from vllm.entrypoints.utils import (
    create_error_response,
    get_max_tokens,
)
from vllm.inputs import (
    EngineInput,
    MultiModalHashes,
    MultiModalPlaceholders,
    PromptType,
    SingletonPrompt,
    tokens_input,
)
from vllm.logger import init_logger
from vllm.parser import ParserManager
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
    parse_model_prompt,
    prompt_to_seq,
)
from vllm.tool_parsers import ToolParser
from vllm.utils import random_uuid
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import mt as _mt

# Module-level logger named after this module (used to report model-check errors).
logger = init_logger(__name__)


class OpenAIServingRender:
    """Render OpenAI-compatible requests into ``GenerateRequest`` objects.

    Chat-completion and completion requests are validated and turned into
    token IDs, sampling parameters and (optional) multimodal metadata via the
    configured ``BaseRenderer``; no generation is performed here.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        io_processor: Any,
        model_registry: OpenAIModelRegistry,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        enable_auto_tools: bool = False,
        exclude_tools_when_tool_choice_none: bool = False,
        tool_parser: str | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        log_error_stack: bool = False,
    ) -> None:
        """Store configuration and resolve the tool parser and token defaults.

        Args:
            model_config: Model configuration; also supplies the default
                sampling parameters and ``max_model_len``.
            renderer: Renderer used to tokenize prompts and chat messages.
            io_processor: Stored as-is; not used by this class directly.
            model_registry: Registry used to validate ``request.model``.
            request_logger: Optional request logger (stored only).
            chat_template: Server-side default chat template.
            chat_template_content_format: Default chat-template content format.
            trust_request_chat_template: Allow request-supplied chat templates.
            enable_auto_tools: Allow ``tool_choice="auto"`` for HF tokenizers.
            exclude_tools_when_tool_choice_none: Drop tool definitions from the
                prompt when ``tool_choice == "none"``.
            tool_parser: Name of the tool parser to look up.
            default_chat_template_kwargs: Extra chat-template kwargs.
            log_error_stack: Stored flag; not used by this class directly.
        """
        self.model_config = model_config
        self.renderer = renderer
        self.io_processor = io_processor
        self.model_registry = model_registry
        self.request_logger = request_logger
        self.chat_template = chat_template
        self.chat_template_content_format: ChatTemplateContentFormatOption = (
            chat_template_content_format
        )
        self.trust_request_chat_template = trust_request_chat_template
        self.enable_auto_tools = enable_auto_tools
        self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
        self.tool_parser: type[ToolParser] | None = ParserManager.get_tool_parser(
            tool_parser_name=tool_parser,
            enable_auto_tools=enable_auto_tools,
            model_name=model_config.model,
        )
        self.default_chat_template_kwargs: dict[str, Any] = (
            default_chat_template_kwargs or {}
        )
        self.log_error_stack = log_error_stack
        # GPT-OSS models are rendered with Harmony instead of a chat template.
        self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
        self.supports_browsing = False
        self.supports_code_interpreter = False

        self.default_sampling_params = model_config.get_diff_sampling_param()
        # Where the max-token override comes from depends on the
        # generation-config mode.
        if model_config.generation_config not in ("auto", "vllm"):
            self.override_max_tokens = self.default_sampling_params.get("max_tokens")
        else:
            # Guard against the attribute being missing *or* set to None.
            override_cfg = (
                getattr(model_config, "override_generation_config", None) or {}
            )
            self.override_max_tokens = override_cfg.get("max_new_tokens")

    async def render_chat_request(
        self,
        request: ChatCompletionRequest,
    ) -> GenerateRequest | ErrorResponse:
        """Validate the model and preprocess a chat completion request.

        This is the authoritative implementation used directly by the
        GPU-less render server and delegated to by OpenAIServingChat.

        Returns:
            A single ``GenerateRequest``, or an ``ErrorResponse`` when the model
            check fails, beam search is requested, ``render_chat`` rejects the
            request, rendering does not yield exactly one engine input, or the
            rendered input carries no token IDs.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        if request.use_beam_search:
            return self.create_error_response(
                "Beam search is not supported by the render endpoint"
            )

        result = await self.render_chat(request)
        if isinstance(result, ErrorResponse):
            return result

        _, engine_inputs = result
        if len(engine_inputs) != 1:
            return self.create_error_response(
                f"Expected exactly 1 engine prompt, got {len(engine_inputs)}"
            )

        # ``max_completion_tokens`` takes precedence over ``max_tokens``.
        requested_max_tokens = (
            request.max_completion_tokens
            if request.max_completion_tokens is not None
            else request.max_tokens
        )
        return self._build_generate_request(
            request,
            engine_inputs[0],
            requested_max_tokens=requested_max_tokens,
            request_id_prefix="chatcmpl",
        )

    async def render_chat(
        self,
        request: ChatCompletionRequest,
    ) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse:
        """Core preprocessing logic for chat requests (no model/engine check).

        Called directly by render_chat_request and delegated to by
        OpenAIServingChat.render_chat_request after its engine-aware checks.

        Returns:
            ``(conversation, engine_inputs)`` on success, or an
            ``ErrorResponse`` when the tool-choice configuration or the
            request chat template is rejected.

        Raises:
            ValueError: From the Harmony path when ``reasoning_effort`` is
                ``"none"``.
        """
        tokenizer = self.renderer.tokenizer
        tool_parser = self.tool_parser
        is_mistral = is_mistral_tokenizer(tokenizer)

        if is_mistral:
            # because of issues with pydantic we need to potentially
            # re-serialize the tool_calls field of the request
            _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
            _mt.truncate_tool_call_ids(request)  # type: ignore[arg-type]
            _mt.validate_request_params(request)

        # Mistral tokenizers and Harmony handle tools without a tool parser.
        tool_parsing_unavailable = (
            tool_parser is None and not is_mistral and not self.use_harmony
        )

        # Validate tool_choice when tool parsing is required but unavailable
        if tool_parsing_unavailable and request.tool_choice not in (
            None,
            "none",
        ):
            if request.tool_choice == "auto" and not self.enable_auto_tools:
                # for hf tokenizers, "auto" tools requires
                # --enable-auto-tool-choice and --tool-call-parser
                return self.create_error_response(
                    '"auto" tool choice requires '
                    "--enable-auto-tool-choice and --tool-call-parser to be set"
                )
            elif request.tool_choice != "auto":
                # "required" or named tool requires tool parser
                return self.create_error_response(
                    f'tool_choice="{request.tool_choice}" requires '
                    "--tool-call-parser to be set"
                )

        if request.tools is None or (
            request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
        ):
            tool_dicts = None
        else:
            tool_dicts = [tool.model_dump() for tool in request.tools]

        if self.use_harmony:
            # For GPT-OSS.
            conversation, engine_inputs = self._make_request_with_harmony(
                request, tool_dicts is not None
            )
            return conversation, engine_inputs

        # Common case: render through the (possibly request-supplied) template.
        error_check_ret = self.validate_chat_template(
            request_chat_template=request.chat_template,
            chat_template_kwargs=request.chat_template_kwargs,
            trust_request_chat_template=self.trust_request_chat_template,
        )
        if error_check_ret is not None:
            return error_check_ret

        conversation, engine_inputs = await self.preprocess_chat(
            request,
            request.messages,
            default_template=self.chat_template,
            default_template_content_format=self.chat_template_content_format,
            default_template_kwargs=self.default_chat_template_kwargs,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
        )
        return conversation, engine_inputs

    async def render_completion_request(
        self,
        request: CompletionRequest,
    ) -> list[GenerateRequest] | ErrorResponse:
        """Validate the model and preprocess a completion request.

        This is the authoritative implementation used directly by the
        GPU-less render server and delegated to by OpenAIServingCompletion.

        Returns:
            One ``GenerateRequest`` per rendered prompt, or an
            ``ErrorResponse`` if the model check or preprocessing fails, or if
            any rendered prompt has no token IDs (the whole batch is rejected).
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        result = await self.render_completion(request)
        if isinstance(result, ErrorResponse):
            return result

        generate_requests: list[GenerateRequest] = []
        for engine_input in result:
            generate_request = self._build_generate_request(
                request,
                engine_input,
                requested_max_tokens=request.max_tokens,
                request_id_prefix="cmpl",
            )
            # Stop at the first prompt that could not be turned into tokens.
            if isinstance(generate_request, ErrorResponse):
                return generate_request
            generate_requests.append(generate_request)

        return generate_requests

    async def render_completion(
        self,
        request: CompletionRequest,
    ) -> list[EngineInput] | ErrorResponse:
        """Core preprocessing logic for completion requests (no model/engine check).

        Called directly by render_completion_request and delegated to by
        OpenAIServingCompletion.render_completion_request after its engine-aware checks.

        Returns:
            The rendered engine inputs, or an ``ErrorResponse`` for
            unsupported option combinations (``suffix``; ``echo`` or
            ``prompt_logprobs`` together with ``prompt_embeds``).
        """
        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and request.prompt_embeds is not None:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        return await self.preprocess_completion(
            request,
            prompt_input=request.prompt,
            prompt_embeds=request.prompt_embeds,
        )

    def _build_generate_request(
        self,
        request: ChatCompletionRequest | CompletionRequest,
        engine_input: EngineInput,
        requested_max_tokens: int | None,
        request_id_prefix: str,
    ) -> GenerateRequest | ErrorResponse:
        """Turn one rendered engine input into a ``GenerateRequest``.

        Args:
            request: The originating chat or completion request; supplies
                sampling options, model name, streaming and priority fields.
            engine_input: A single rendered prompt.
            requested_max_tokens: The caller-requested generation cap (may be
                ``None``); combined with the prompt length by ``get_max_tokens``.
            request_id_prefix: Prefix for the generated request ID
                (``"chatcmpl"`` or ``"cmpl"``).

        Returns:
            The ``GenerateRequest``, or an ``ErrorResponse`` if the rendered
            prompt has no token IDs.
        """
        prompt_components = extract_prompt_components(self.model_config, engine_input)
        token_ids = prompt_components.token_ids
        if not token_ids:
            return self.create_error_response("No token_ids rendered")

        input_length = extract_prompt_len(self.model_config, engine_input)
        max_tokens = get_max_tokens(
            self.model_config.max_model_len,
            requested_max_tokens,
            input_length,
            self.default_sampling_params,
            self.override_max_tokens,
        )
        params = request.to_sampling_params(max_tokens, self.default_sampling_params)

        return GenerateRequest(
            request_id=f"{request_id_prefix}-{random_uuid()}",
            token_ids=list(token_ids),
            features=self._extract_mm_features(engine_input),
            sampling_params=params,
            model=request.model,
            stream=bool(request.stream),
            stream_options=(request.stream_options if request.stream else None),
            cache_salt=request.cache_salt,
            priority=request.priority,
        )

    @staticmethod
    def _extract_mm_features(
        engine_input: EngineInput,
    ) -> MultiModalFeatures | None:
        """Extract multimodal metadata from a rendered engine prompt.

        Returns ``None`` unless ``engine_input["type"]`` is ``"multimodal"``;
        otherwise copies its ``mm_hashes`` and converts each placeholder range
        into a ``PlaceholderRangeInfo`` (offset and length only).
        """
        if engine_input.get("type") != "multimodal":
            return None

        # At this point engine_input is a MultiModalInputs TypedDict.
        mm_hashes: MultiModalHashes = engine_input["mm_hashes"]  # type: ignore[typeddict-item]
        raw_placeholders: MultiModalPlaceholders = engine_input["mm_placeholders"]  # type: ignore[typeddict-item]

        mm_placeholders = {
            modality: [
                PlaceholderRangeInfo(offset=p.offset, length=p.length) for p in ranges
            ]
            for modality, ranges in raw_placeholders.items()
        }

        return MultiModalFeatures(
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )

    def _make_request_with_harmony(
        self,
        request: ChatCompletionRequest,
        should_include_tools: bool = True,
    ):
        """Build Harmony (GPT-OSS) messages and engine prompt from a chat request.

        Args:
            request: The chat request; ``tool_calls`` may be re-serialized in
                place.
            should_include_tools: Whether tool definitions go into the prompt.

        Returns:
            ``(messages, [engine_input])`` where ``messages`` are the Harmony
            messages and ``engine_input`` wraps the rendered token IDs.

        Raises:
            ValueError: If ``request.reasoning_effort == "none"``.
        """
        messages: list[OpenAIMessage] = []

        # because of issues with pydantic we need to potentially
        # re-serialize the tool_calls field of the request
        # for more info: see comment in `maybe_serialize_tool_calls`
        _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]

        # Add system message.
        # NOTE: In Chat Completion API, browsing is enabled by default
        # if the model supports it. TODO: Support browsing.
        assert not self.supports_browsing
        assert not self.supports_code_interpreter

        if (reasoning_effort := request.reasoning_effort) == "none":
            raise ValueError(f"Harmony does not support {reasoning_effort=}")
        sys_msg = get_system_message(
            reasoning_effort=reasoning_effort,
            browser_description=None,
            python_description=None,
            with_custom_tools=should_include_tools,
        )
        messages.append(sys_msg)

        # Add developer message.
        if request.tools:
            dev_msg = get_developer_message(
                tools=request.tools if should_include_tools else None  # type: ignore[arg-type]
            )
            messages.append(dev_msg)

        # Add user message.
        messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))

        # Render prompt token ids.
        prompt_token_ids = render_for_completion(messages)
        engine_input = tokens_input(prompt_token_ids, cache_salt=request.cache_salt)

        return messages, [engine_input]

    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` (defaults to HTTP 400 ``BadRequestError``)."""
        # Resolves to the module-level helper imported from vllm.entrypoints.utils.
        return create_error_response(message, err_type, status_code, param)

    async def _check_model(
        self,
        request: Any,
    ) -> ErrorResponse | None:
        """Ask the model registry to validate ``request.model``.

        Returns the registry's ``ErrorResponse`` on rejection, else ``None``.
        """
        return await self.model_registry.check_model(request.model)

    def validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Copied from OpenAIServing._validate_chat_template.

        Rejects a request-supplied chat template (either directly or via
        ``chat_template_kwargs["chat_template"]``) unless request templates
        are trusted. Returns ``None`` when the request is acceptable.
        """
        if not trust_request_chat_template and (
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None

    @staticmethod
    def _prompt_extras(request: Any) -> dict[str, Any]:
        """Collect optional per-request fields forwarded to the renderer.

        Only ``mm_processor_kwargs`` and ``cache_salt`` are considered, and a
        field is included only when the request has it set to a non-``None``
        value.
        """
        return {
            k: v
            for k in ("mm_processor_kwargs", "cache_salt")
            if (v := getattr(request, k, None)) is not None
        }

    async def preprocess_completion(
        self,
        request: Any,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[EngineInput]:
        """Copied from OpenAIServing._preprocess_completion.

        Prompt embeddings are placed before text/token prompts in the list
        handed to ``preprocess_cmpl``.
        """
        prompts = list[SingletonPrompt | bytes]()
        if prompt_embeds is not None:  # embeds take higher priority
            prompts.extend(prompt_to_seq(prompt_embeds))
        if prompt_input is not None:
            prompts.extend(prompt_to_seq(prompt_input))
        return await self.preprocess_cmpl(request, prompts)

    async def preprocess_cmpl(
        self,
        request: Any,
        prompts: Sequence[PromptType | bytes],
    ) -> list[EngineInput]:
        """Copied from OpenAIServing._preprocess_cmpl.

        ``bytes`` prompts (embeddings) are passed through unparsed; all other
        prompts go through ``parse_model_prompt`` before rendering.
        """
        model_config = self.model_config

        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(model_config, prompt)
            )
            for prompt in prompts
        ]
        tok_params = request.build_tok_params(model_config)

        return await self.renderer.render_cmpl_async(
            parsed_prompts,
            tok_params,
            prompt_extras=self._prompt_extras(request),
        )

    async def preprocess_chat(
        self,
        request: Any,
        messages: list[Any],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: type[ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[EngineInput]]:
        """Copied from OpenAIServing._preprocess_chat.

        Renders ``messages`` as a single conversation and returns
        ``(conversation, [engine_input])``.

        Raises:
            NotImplementedError: If a tool parser is active, tool_choice is not
                ``"none"``, and ``request`` is neither a Chat Completions nor a
                Responses request.
        """
        renderer = self.renderer
        mm_config = self.model_config.multimodal_config

        default_template_kwargs = merge_kwargs(
            default_template_kwargs,
            dict(
                tools=tool_dicts,
                tokenize=is_mistral_tokenizer(renderer.tokenizer),
            ),
        )

        tok_params = request.build_tok_params(self.model_config)
        chat_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(
            default_template_kwargs,
            default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
            default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
        )

        (conversation,), (engine_input,) = await renderer.render_chat_async(
            [messages],
            chat_params,
            tok_params,
            prompt_extras=self._prompt_extras(request),
        )

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the LLM)
        if tool_parser is not None:
            tool_choice = getattr(request, "tool_choice", "none")
            if tool_choice != "none":
                if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                    msg = (
                        "Tool usage is only supported "
                        "for Chat Completions API or Responses API requests, "
                        f"but got {type(request).__name__}"
                    )
                    raise NotImplementedError(msg)
                tokenizer = renderer.get_tokenizer()
                # NOTE(review): the return value is not used here, so this only
                # has an effect if adjust_request mutates `request` in place.
                tool_parser(tokenizer).adjust_request(request=request)  # type: ignore[arg-type]

        return conversation, [engine_input]