"vllm/collect_env.py" did not exist on "01bfb22b4112ee813185366ab26985d172661a61"
serving.py 87.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import time
7
8
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
9
from typing import Any, Final
10

11
import jinja2
12
import partial_json_parser
13
import regex as re
14
from fastapi import Request
15
from openai_harmony import Message as OpenAIMessage
16
from partial_json_parser.core.options import Allow
17

18
from vllm.engine.protocol import EngineClient
19
20
21
22
23
24
from vllm.entrypoints.chat_utils import (
    ChatTemplateContentFormatOption,
    ConversationMessage,
    get_history_tool_calls_cnt,
    make_tool_call_id,
)
25
from vllm.entrypoints.logger import RequestLogger
26
from vllm.entrypoints.openai.chat_completion.protocol import (
27
28
29
30
31
32
33
34
35
36
    ChatCompletionLogProb,
    ChatCompletionLogProbs,
    ChatCompletionLogProbsContent,
    ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
37
38
)
from vllm.entrypoints.openai.chat_completion.stream_harmony import (
39
    TokenState,
40
41
42
    extract_harmony_streaming_delta,
)
from vllm.entrypoints.openai.engine.protocol import (
43
44
45
46
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ErrorResponse,
47
    FunctionCall,
48
49
50
51
52
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    ToolCall,
    UsageInfo,
)
53
from vllm.entrypoints.openai.engine.serving import (
54
55
56
57
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
58
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
59
60
61
62
63
64
65
66
67
from vllm.entrypoints.openai.parser.harmony_utils import (
    get_developer_message,
    get_stop_tokens_for_assistant_actions,
    get_streamable_parser_for_assistant,
    get_system_message,
    parse_chat_inputs_to_harmony_messages,
    parse_chat_output,
    render_for_completion,
)
68
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
69
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
70
from vllm.inputs.data import TokensPrompt
71
from vllm.inputs.parse import get_prompt_components
72
from vllm.logger import init_logger
73
from vllm.logprobs import Logprob
74
from vllm.outputs import CompletionOutput, RequestOutput
75
from vllm.sampling_params import BeamSearchParams, SamplingParams
76
77
78
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
    MistralTokenizer,
79
80
81
82
    maybe_serialize_tool_calls,
    truncate_tool_call_ids,
    validate_request_params,
)
83
84
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
85
from vllm.tool_parsers.utils import partial_json_loads
86
from vllm.utils.collection_utils import as_list
87
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
88
89
90
91
92

logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):
93
94
95
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        response_role: str,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        return_tokens_as_token_ids: bool = False,
        reasoning_parser: str = "",
        enable_auto_tools: bool = False,
        exclude_tools_when_tool_choice_none: bool = False,
        tool_parser: str | None = None,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        enable_log_outputs: bool = False,
        enable_log_deltas: bool = True,
        log_error_stack: bool = False,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Configure the chat-completions handler.

        Args:
            engine_client: Client used to submit requests to the engine.
            models: Registry used to resolve served model names.
            response_role: Role reported for the generated turn when the
                request asks for a generation prompt.
            request_logger: Optional logger for incoming requests.
            chat_template: Server-side chat template (may be overridden per
                request when allowed).
            chat_template_content_format: How message content is rendered
                into the chat template.
            trust_request_chat_template: Passed to chat-template validation
                to decide whether request-supplied templates are accepted.
            return_tokens_as_token_ids: Forwarded to the base class.
            reasoning_parser: Name of the reasoning parser ("" for none).
            enable_auto_tools: Whether ``tool_choice="auto"`` is supported.
            exclude_tools_when_tool_choice_none: Drop ``tools`` from the
                prompt when ``tool_choice == "none"``.
            tool_parser: Name of the tool-call parser, if any.
            enable_prompt_tokens_details: Stored flag for prompt-token usage
                details.
            enable_force_include_usage: Passed to ``should_include_usage`` when
                streaming.
            enable_log_outputs: Stored flag for output logging.
            enable_log_deltas: Stored flag for streaming-delta logging.
            log_error_stack: Forwarded to the base class.
            default_chat_template_kwargs: Server-level chat template kwargs;
                ``None`` is treated as an empty dict.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        self.response_role = response_role
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template
        self.default_chat_template_kwargs = default_chat_template_kwargs or {}
        self.enable_log_outputs = enable_log_outputs
        self.enable_log_deltas = enable_log_deltas

        # set up logits processors
        self.logits_processors = self.model_config.logits_processors

        # set up reasoning parser
        self.reasoning_parser = self._get_reasoning_parser(
            reasoning_parser_name=reasoning_parser
        )

        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        self.tool_parser = self._get_tool_parser(
            tool_parser_name=tool_parser, enable_auto_tools=enable_auto_tools
        )
        self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage
        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
        if self.use_harmony:
            # Build a fresh list instead of extending in place. The existing
            # list comes from model_config and may be shared with it (assumed,
            # not verified here). The rebuild also tolerates a stored None.
            # Duplicate stop tokens are skipped.
            stop_token_ids = list(
                self.default_sampling_params.get("stop_token_ids") or []
            )
            for token_id in get_stop_tokens_for_assistant_actions():
                if token_id not in stop_token_ids:
                    stop_token_ids.append(token_id)
            self.default_sampling_params["stop_token_ids"] = stop_token_ids

        # Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
        hf_overrides = getattr(self.model_config, "hf_overrides", None)
        if self.model_config.hf_text_config.model_type == "kimi_k2" or (
            isinstance(hf_overrides, dict)
            and hf_overrides.get("model_type") == "kimi_k2"
        ):
            self.tool_call_id_type = "kimi_k2"
        else:
            self.tool_call_id_type = "random"

        # NOTE(woosuk): While OpenAI's chat completion API supports browsing
        # for some models, currently vLLM doesn't support it. Please use the
        # Responses API instead.
        self.supports_browsing = False
        self.browser_tool = None
        # NOTE(woosuk): Chat completion API does not support code interpreter.
        # Please use the Responses API instead.
        self.supports_code_interpreter = False
        self.python_tool = None

176
177
178
179
180
181
182
183
184
185
186
187
    async def warmup(self) -> None:
        """
        Warm up the chat template processing to avoid first-request latency.

        This method triggers Jinja2 template compilation and content format
        detection that would otherwise happen on the first real request,
        causing increased latency on the first request.

        Failures are logged and swallowed so that a broken warmup never
        prevents the server from starting.
        """
        logger.info("Warming up chat template processing...")
        start_time = time.perf_counter()

        try:
            renderer = self.engine_client.renderer

            # Create a minimal dummy request
            dummy_request = ChatCompletionRequest(
                messages=[{"role": "user", "content": "warmup"}],
                model=None,
                max_completion_tokens=1,
            )

            # Call _preprocess_chat to trigger template compilation
            # This forces:
            # 1. Chat template content format detection
            # 2. Jinja2 template compilation
            # 3. Tokenizer initialization for chat
            await self._preprocess_chat(
                dummy_request,
                renderer,
                dummy_request.messages,
                chat_template=self.chat_template,
                chat_template_content_format=self.chat_template_content_format,
                add_generation_prompt=True,
                continue_final_message=False,
                tool_dicts=None,
                documents=None,
                chat_template_kwargs=None,
                default_chat_template_kwargs=self.default_chat_template_kwargs,
                tool_parser=None,
                add_special_tokens=False,
            )

            elapsed = (time.perf_counter() - start_time) * 1000
            logger.info("Chat template warmup completed in %.1fms", elapsed)

        except Exception:
            # Log but don't fail server startup if warmup fails
            logger.exception("Chat template warmup failed")

225
    async def render_chat_request(
        self,
        request: ChatCompletionRequest,
    ) -> tuple[list[ConversationMessage], list[Any]] | ErrorResponse:
        """
        Render a chat request by validating and preprocessing its inputs.

        The request's model is checked first. Then the tool-choice
        configuration is validated. For non-Harmony models the chat template
        is validated and the messages are rendered into engine prompts;
        GPT-OSS models go through the Harmony path instead.

        Returns:
            A tuple of (conversation, engine_prompts) on success,
            or an ErrorResponse on failure.

        Raises:
            ``self.engine_client.dead_error`` when the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        try:
            renderer = self.engine_client.renderer
            tokenizer = renderer.tokenizer

            tool_parser = self.tool_parser

            if isinstance(tokenizer, MistralTokenizer):
                # because of issues with pydantic we need to potentially
                # re-serialize the tool_calls field of the request
                # for more info: see comment in `maybe_serialize_tool_calls`
                maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
                truncate_tool_call_ids(request)  # type: ignore[arg-type]
                validate_request_params(request)

            # Check if tool parsing is unavailable (common condition)
            tool_parsing_unavailable = (
                tool_parser is None
                and not isinstance(tokenizer, MistralTokenizer)
                and not self.use_harmony
            )

            # Validate tool_choice when tool parsing is required but unavailable
            if tool_parsing_unavailable and request.tool_choice not in (
                None,
                "none",
            ):
                if request.tool_choice == "auto" and not self.enable_auto_tools:
                    # for hf tokenizers, "auto" tools requires
                    # --enable-auto-tool-choice and --tool-call-parser
                    return self.create_error_response(
                        '"auto" tool choice requires '
                        "--enable-auto-tool-choice and --tool-call-parser to be set"
                    )
                elif request.tool_choice != "auto":
                    # "required" or named tool requires tool parser
                    return self.create_error_response(
                        f'tool_choice="{request.tool_choice}" requires '
                        "--tool-call-parser to be set"
                    )

            if request.tools is None or (
                request.tool_choice == "none"
                and self.exclude_tools_when_tool_choice_none
            ):
                tool_dicts = None
            else:
                tool_dicts = [tool.model_dump() for tool in request.tools]

            if not self.use_harmony:
                # Common case.
                error_check_ret = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret

                # NOTE: when the request carries a non-empty dict, this adds
                # `reasoning_effort` to that same dict (in place). Code that
                # later reads request.chat_template_kwargs sees the key too.
                chat_template_kwargs = request.chat_template_kwargs or {}
                chat_template_kwargs.update(reasoning_effort=request.reasoning_effort)

                conversation, engine_prompts = await self._preprocess_chat(
                    request,
                    renderer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.chat_template_content_format,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    tool_dicts=tool_dicts,
                    documents=request.documents,
                    chat_template_kwargs=chat_template_kwargs,
                    default_chat_template_kwargs=self.default_chat_template_kwargs,
                    tool_parser=tool_parser,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                # For GPT-OSS.
                should_include_tools = tool_dicts is not None
                conversation, engine_prompts = self._make_request_with_harmony(
                    request, should_include_tools
                )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(e)

        return conversation, engine_prompts

    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse:
        """
        Chat Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        Chat Completion API.

        Returns:
            The stream generator when ``request.stream`` is set. Otherwise
            the result of the full (non-streaming) generator. Validation and
            generation failures are returned as an ErrorResponse.
        """
        result = await self.render_chat_request(request)
        if isinstance(result, ErrorResponse):
            return result

        conversation, engine_prompts = result

        request_id = (
            f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
        )

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(
                request, supports_default_mm_loras=True
            )

            model_name = self.models.model_name(lora_request)
        except (ValueError, TypeError, RuntimeError) as e:
            logger.exception("Error preparing request components")
            return self.create_error_response(e)

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            # Trace headers depend only on the raw request, not on the
            # individual prompt, so they are fetched once up front.
            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            for i, engine_prompt in enumerate(engine_prompts):
                prompt_text, _, _ = get_prompt_components(engine_prompt)

                # If we are creating sub requests for multiple prompts, ensure that they
                # have unique request ids.
                sub_request_id = (
                    request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
                )

                max_tokens = get_max_tokens(
                    max_model_len=self.max_model_len,
                    request=request,
                    prompt=engine_prompt,
                    default_sampling_params=self.default_sampling_params,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params,
                    )
                    validate_logits_processors_parameters(
                        self.logits_processors,
                        sampling_params,
                    )

                self._log_inputs(
                    sub_request_id,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=sub_request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                    )
                else:
                    engine_request, tokenization_kwargs = await self._process_inputs(
                        sub_request_id,
                        engine_prompt,
                        sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        data_parallel_rank=data_parallel_rank,
                    )

                    generator = self.engine_client.generate(
                        engine_request,
                        sampling_params,
                        sub_request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        prompt_text=prompt_text,
                        tokenization_kwargs=tokenization_kwargs,
                        data_parallel_rank=data_parallel_rank,
                    )

                generators.append(generator)
        except ValueError as e:
            return self.create_error_response(e)

        assert len(generators) == 1
        (result_generator,) = generators

        # Streaming response
        # NOTE(review): this reads `self.renderer`, while render_chat_request
        # reads `self.engine_client.renderer`. They are presumably the same
        # object; confirm against OpenAIServing.
        tokenizer = self.renderer.tokenizer
        assert tokenizer is not None

        if request.stream:
            return self.chat_completion_stream_generator(
                request,
                result_generator,
                request_id,
                model_name,
                conversation,
                tokenizer,
                request_metadata,
            )

        try:
            return await self.chat_completion_full_generator(
                request,
                result_generator,
                request_id,
                model_name,
                conversation,
                tokenizer,
                request_metadata,
            )
        except GenerationError as e:
            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            return self.create_error_response(e)
489
490
491
492

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
493
        return request.messages[-1]["role"]
494

495
    @staticmethod
496
    def _bracket_level(s: str, opening="{", closing="}") -> int:
497
498
499
500
501
502
503
504
505
506
507
508
        """
        Calculate the current level of nested brackets in a given string.
        """
        level = 0
        for char in s:
            if char == opening:
                level += 1
            elif char == closing:
                level -= 1
        return level

    @staticmethod
509
    def _filter_delta_text(delta_text: str, previous_text: str) -> tuple[str, bool]:
510
511
512
513
514
515
516
517
518
        # remove last '},' of the tool definition stemming from the
        # "name"/"parameters" outer object or closing ']' of the tool list
        # count occurrences of opening and closing curly braces and
        # once level 0 is reached stop outputting text
        # if 0 is reached while parsing the delta_text we know the current
        # tool will finish in this current iteration
        bracket_level = OpenAIServingChat._bracket_level(previous_text)
        updated_delta, passed_zero = "", False
        for c in delta_text:
519
            if c == "{":
520
521
                bracket_level += 1
                passed_zero = bracket_level == 0
522
            elif c == "}":
523
524
525
526
527
528
529
                bracket_level -= 1
                passed_zero = bracket_level == 0

            if bracket_level != 0:
                updated_delta += c
            else:
                # if a comma is reached at level 0 we can stop
530
                if c == ",":
531
532
533
534
535
536
                    break
        return updated_delta, passed_zero

    def extract_tool_call_required_streaming(
        self,
        previous_text: str,
        current_text: str | None,
        delta_text: str,
        function_name_returned: bool,
        tool_call_idx: int | None = None,
    ) -> tuple[DeltaMessage | None, bool]:
        """Turn streamed JSON tool-call text into a tool-call delta.

        ``current_text`` is expected to be a (partial) JSON array of objects
        with ``"name"`` and ``"parameters"`` keys. A tool call's first delta
        carries its id, name and the arguments generated so far. Later deltas
        carry only argument text.

        Args:
            previous_text: Text generated before this step.
            current_text: All text generated so far, including ``delta_text``.
            delta_text: Text newly generated in this step.
            function_name_returned: Whether the name of the current tool call
                has already been emitted.
            tool_call_idx: Forwarded to ``make_tool_call_id``.

        Returns:
            ``(delta_message_or_None, updated_function_name_returned)``.
        """
        if current_text is None or current_text == "":
            # if the current text is empty, we cannot parse it
            return None, function_name_returned
        try:
            obj, _ = partial_json_loads(current_text, Allow.ALL)
        except (
            partial_json_parser.core.exceptions.MalformedJSON,
            json.JSONDecodeError,
        ):
            logger.debug("not enough tokens to parse into JSON yet")
            obj = None

        # Wait until the text is a non-empty array containing a (partial)
        # tool-calling object.
        if not isinstance(obj, list) or not obj:
            return None, False

        _, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
            delta_text, previous_text
        )

        # take the last tool call from the generated list
        tool_index = len(obj) - 1
        current_tool_call = obj[tool_index]

        # once parameters have been generated the name is complete as well
        if not finishes_previous_tool and (
            "name" not in current_tool_call or "parameters" not in current_tool_call
        ):
            return None, False

        if function_name_returned:
            # Name already sent: stream only the argument text of this delta.
            filtered_delta, _ = OpenAIServingChat._filter_delta_text(
                delta_text, previous_text
            )
            if filtered_delta == "":
                return None, function_name_returned
            delta_message = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        function=DeltaFunctionCall(
                            # OpenAI API returns None
                            # instead of name every time
                            name=None,
                            arguments=filtered_delta,
                        ),
                        index=tool_index,
                    )
                ]
            )
            return delta_message, function_name_returned

        # get partly generated arguments from the latest tool call
        param_match = re.search(
            r'.*"parameters":\s*(.*)', current_text, re.DOTALL
        )
        arguments = param_match.group(1) if param_match else ""
        arguments, _ = OpenAIServingChat._filter_delta_text(arguments, previous_text)

        # if this iteration finishes a previous tool call but a
        # new incomplete tool is already generated, take the
        # previous from the list -- and report it under its own index
        if finishes_previous_tool and "parameters" not in current_tool_call:
            if len(obj) < 2:
                # A lone tool object closed without "parameters": there is no
                # earlier call to fall back to (obj[-2] would raise).
                return None, False
            tool_index = len(obj) - 2
            current_tool_call = obj[tool_index]

        tool_call_id = make_tool_call_id(
            id_type=self.tool_call_id_type,
            func_name=current_tool_call["name"],
            idx=tool_call_idx,
        )
        delta_message = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    id=tool_call_id,
                    function=DeltaFunctionCall(
                        name=current_tool_call["name"], arguments=arguments
                    ),
                    index=tool_index,
                    type="function",
                )
            ]
        )
        return delta_message, True

634
    async def chat_completion_stream_generator(
635
636
637
638
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
639
        model_name: str,
640
        conversation: list[ConversationMessage],
641
        tokenizer: TokenizerLike,
642
        request_metadata: RequestResponseMetadata,
643
    ) -> AsyncGenerator[str, None]:
644
645
        from vllm.tokenizers.mistral import MistralTokenizer

646
        created_time = int(time.time())
647
        chunk_object_type: Final = "chat.completion.chunk"
648
        first_iteration = True
649
650

        # Send response for each token for each request.n (index)
651
652
653
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
654
        num_prompt_tokens = 0
655
        num_cached_tokens = None
656
657
        if self.use_harmony:
            harmony_parsers = [
658
                get_streamable_parser_for_assistant() for _ in range(num_choices)
659
            ]
660
661
            harmony_tools_streamed = [False] * num_choices
        tools_streamed = [False] * num_choices
662
663
664
665
666
667
668
669
670

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
671
672
            and self._should_stream_with_auto_tool_parsing(request)
        )
673

674
        all_previous_token_ids: list[list[int]] | None
675
        function_name_returned = [False] * num_choices
676
        if self.tool_call_id_type == "kimi_k2":
677
678
679
            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
        else:
            history_tool_call_cnt = 0
680

681
682
683
        # Always track previous_texts for comprehensive output logging
        previous_texts = [""] * num_choices

684
685
        # Only one of these will be used, thus previous_texts and
        # all_previous_token_ids will not be used twice in the same iteration.
686
        if tool_choice_auto or self.reasoning_parser:
687
688
            # These are only required in "auto" tool choice case
            all_previous_token_ids = [[]] * num_choices
689
690
691
            # For reasoning parser and tool call all enabled
            added_content_delta_arr = [False] * num_choices
            reasoning_end_arr = [False] * num_choices
692
        else:
693
            all_previous_token_ids = None
694

695
        try:
696
            if self.reasoning_parser:
697
698
699
700
701
                if tokenizer is None:
                    raise ValueError(
                        "Tokenizer not available when `skip_tokenizer_init=True`"
                    )

702
703
704
705
706
                # Pass the same chat template kwargs as used in tokenization
                chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
                    request.chat_template_kwargs,
                    self.default_chat_template_kwargs,
                )
707
708
                reasoning_parser = self.reasoning_parser(
                    tokenizer,
709
                    chat_template_kwargs=chat_template_kwargs or {},  # type: ignore[call-arg]
710
                )
711
712
713
714
715
716
        except RuntimeError as e:
            logger.exception("Error in reasoning parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return
717
718
719
        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
720
721
722
723
724
                if tokenizer is None:
                    raise ValueError(
                        "Tokenizer not available when `skip_tokenizer_init=True`"
                    )

725
                tool_parsers: list[ToolParser | None] = [
726
727
728
729
                    self.tool_parser(tokenizer)
                ] * num_choices
            else:
                tool_parsers = [None] * num_choices
730
        except Exception as e:
731
            logger.exception("Error in tool parser creation.")
732
            data = self.create_streaming_error_response(e)
733
734
735
736
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

737
        stream_options = request.stream_options
738
739
740
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )
741

742
743
        try:
            async for res in result_generator:
744
745
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
746
747
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)
748

749
750
751
752
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
753
                    num_cached_tokens = res.num_cached_tokens
754
755
                    # Send first response for each request.n (index) with
                    # the role
756
                    role = self.get_chat_request_role(request)
757
758
759

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
760
                    for i in range(num_choices):
761
762
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
763
764
765
766
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
767
                            logprobs=None,
768
769
                            finish_reason=None,
                        )
770
771

                        # return prompt_token_ids at the first chunk ever
772
773
774
775
776
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
777
                            model=model_name,
778
779
780
781
782
783
                            prompt_token_ids=(
                                res.prompt_token_ids
                                if request.return_token_ids
                                else None
                            ),
                        )
784

785
786
787
788
789
                        # if continuous usage stats are requested, add it
                        if include_continuous_usage:
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=0,
790
791
                                total_tokens=num_prompt_tokens,
                            )
792

793
794
795
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

796
797
                    # Send response to echo the input portion of the
                    # last message
798
                    if request.echo:
799
                        last_msg_content: str | list[dict[str, str]] = ""
800
801
802
803
804
                        if (
                            conversation
                            and "content" in conversation[-1]
                            and conversation[-1].get("role") == role
                        ):
805
                            last_msg_content = conversation[-1]["content"] or ""
806
807

                        if last_msg_content:
808
                            for i in range(num_choices):
809
810
811
812
813
814
                                choice_data = ChatCompletionResponseStreamChoice(
                                    index=i,
                                    delta=DeltaMessage(content=last_msg_content),
                                    logprobs=None,
                                    finish_reason=None,
                                )
815
816
817
818
819
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
820
821
                                    model=model_name,
                                )
822
823
824
825
                                if include_continuous_usage:
                                    chunk.usage = UsageInfo(
                                        prompt_tokens=num_prompt_tokens,
                                        completion_tokens=0,
826
827
                                        total_tokens=num_prompt_tokens,
                                    )
828

829
                                data = chunk.model_dump_json(exclude_unset=True)
830
831
832
833
834
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index
835
                    tool_parser = tool_parsers[i]
836
837
838
839

                    if finish_reason_sent[i]:
                        continue

840
                    if request.logprobs and request.top_logprobs is not None:
841
                        assert output.logprobs is not None, "Did not output logprobs"
842
                        logprobs = self._create_chat_logprobs(
843
844
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
845
                            tokenizer=tokenizer,
846
                            num_output_top_logprobs=request.top_logprobs,
847
                            return_as_token_id=request.return_tokens_as_token_ids,
848
849
850
851
                        )
                    else:
                        logprobs = None

852
853
                    if self.use_harmony:
                        harmony_parser = harmony_parsers[i]
854
                        prev_recipient = harmony_parser.current_recipient
855
856
857

                        # Track accumulated content per token with their state
                        token_states: list[TokenState] = []
858
859
                        for token_id in output.token_ids:
                            harmony_parser.process(token_id)
860
861
862
863
864
865
866
867
868
                            token_delta = harmony_parser.last_content_delta or ""
                            token_states.append(
                                TokenState(
                                    harmony_parser.current_channel,
                                    harmony_parser.current_recipient,
                                    token_delta,
                                )
                            )
                        delta_text = "".join(delta for _, _, delta in token_states)
869
                        cur_channel = harmony_parser.current_channel
870

871
872
873
874
875
                        # handle the case where several tokens where generated at once
                        # including the final token, leading to a delta in the text
                        # but the current channel to be empty (start state)
                        if not cur_channel and delta_text:
                            cur_channel = "final"
876
877
                    else:
                        delta_text = output.text
878

879
880
881
882
883
                    if (
                        not delta_text
                        and not output.token_ids
                        and not previous_num_tokens[i]
                    ):
884
885
886
                        # Chunked prefill case, don't return empty chunks
                        continue

887
                    delta_message: DeltaMessage | None
888

889
                    # just update previous_texts and previous_token_ids
890
                    if tool_choice_auto or self.reasoning_parser:
891
892
893
894
895
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
896
897
                        # avoid the None + list error.
                        if previous_token_ids:
898
                            current_token_ids = previous_token_ids + as_list(
899
900
                                output.token_ids
                            )
901
                        else:
902
                            current_token_ids = as_list(output.token_ids)
903

904
                    if self.use_harmony:
905
906
907
                        delta_message, tools_streamed_flag = (
                            extract_harmony_streaming_delta(
                                harmony_parser=harmony_parser,
908
                                token_states=token_states,
909
910
911
912
913
                                prev_recipient=prev_recipient,
                                include_reasoning=request.include_reasoning,
                            )
                        )
                        harmony_tools_streamed[i] |= tools_streamed_flag
914
                    # handle streaming deltas for tools with named tool_choice
915
                    elif tool_choice_function_name:
916
917
918
919
920
921
922
                        if (
                            self.reasoning_parser
                            and not reasoning_end_arr[i]
                            and not reasoning_parser.is_reasoning_end(
                                previous_token_ids
                            )
                        ):
923
924
                            assert reasoning_parser is not None
                            delta_message = (
925
                                reasoning_parser.extract_reasoning_streaming(
926
927
928
929
930
931
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output.token_ids,
932
933
                                )
                            )
934
935
936
937
                            # When encountering think end id in delta_token_ids
                            # or think end id in prompt_token_ids
                            # i.e {"enable_thinking": False},
                            # set reasoning status to end.
938
                            # Only keep 'content', remove 'reasoning'.
939
                            if reasoning_parser.is_reasoning_end(
940
941
942
943
944
945
946
                                as_list(output.token_ids)
                            ) or (
                                res.prompt_token_ids
                                and reasoning_parser.is_reasoning_end(
                                    res.prompt_token_ids
                                )
                            ):
947
                                reasoning_end_arr[i] = True
948
949
950
951
952
953
954
955
                                if delta_message and delta_message.content:
                                    # This need to be added to next `delta_text`
                                    current_text = delta_message.content
                                    delta_message.content = None
                                else:
                                    current_text = ""
                        else:
                            # Just to add remaining `content`
956
                            if self.reasoning_parser:
957
958
959
                                delta_text = previous_text + delta_text
                                current_text = ""

960
961
                            if function_name_returned[i]:
                                delta_tool_call = DeltaToolCall(
962
963
964
                                    function=DeltaFunctionCall(arguments=delta_text),
                                    index=i,
                                )
965
                            else:
966
967
968
969
970
971
972
973
974
                                # Generate ID based on tokenizer type
                                if isinstance(tokenizer, MistralTokenizer):
                                    tool_call_id = MistralToolCall.generate_random_id()
                                else:
                                    tool_call_id = make_tool_call_id(
                                        id_type=self.tool_call_id_type,
                                        func_name=tool_choice_function_name,
                                        idx=history_tool_call_cnt,
                                    )
975
                                delta_tool_call = DeltaToolCall(
976
                                    id=tool_call_id,
977
978
979
                                    type="function",
                                    function=DeltaFunctionCall(
                                        name=tool_choice_function_name,
980
981
982
983
                                        arguments=delta_text,
                                    ),
                                    index=i,
                                )
984
985
                                function_name_returned[i] = True

986
987
988
989
990
                            delta_message = DeltaMessage(
                                tool_calls=[
                                    delta_tool_call,
                                ]
                            )
991
                            tools_streamed[i] = True
992

993
994
995
996
997
                    elif request.tool_choice == "required":
                        assert previous_texts is not None
                        previous_text = previous_texts[i]
                        current_text = previous_text + delta_text
                        fn_name_returned = function_name_returned[i]
998
999
1000
1001
1002
1003
1004
1005
1006
                        output_token_ids = as_list(output.token_ids)

                        if (
                            self.reasoning_parser is not None
                            and not reasoning_end_arr[i]
                            and res.prompt_token_ids
                            and reasoning_parser.is_reasoning_end(res.prompt_token_ids)
                        ):
                            reasoning_end_arr[i] = True
1007

1008
1009
                        if self.reasoning_parser and not reasoning_end_arr[i]:
                            delta_message = (
1010
                                reasoning_parser.extract_reasoning_streaming(
1011
1012
1013
1014
1015
1016
1017
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output_token_ids,
                                )
1018
                            )
1019
1020
1021
1022
1023
1024
1025
1026
1027
                            if reasoning_parser.is_reasoning_end(output_token_ids):
                                reasoning_end_arr[i] = True
                                if delta_message and delta_message.content:
                                    current_text = delta_message.content
                                    delta_message.content = None
                                else:
                                    # reasoning ended
                                    current_text = ""

1028
                        else:
1029
                            # either finished reasoning or no reasoning at all
1030
                            content = current_text
1031
1032
1033
1034
1035
1036
1037
1038
1039

                            delta_message, function_name_returned[i] = (
                                self.extract_tool_call_required_streaming(
                                    previous_text=previous_text,
                                    current_text=content,
                                    delta_text=delta_text,
                                    function_name_returned=fn_name_returned,
                                    tool_call_idx=history_tool_call_cnt,
                                )
1040
                            )
1041
1042
1043
1044
1045
1046
1047
                            if (
                                delta_message
                                and delta_message.tool_calls
                                and delta_message.tool_calls[0].id is not None
                            ):
                                history_tool_call_cnt += 1
                                tools_streamed[i] = True
1048

1049
1050
                    # handle streaming deltas for tools with "auto" tool choice
                    # and reasoning parser
1051
                    elif tool_choice_auto and self.reasoning_parser:
1052
1053
1054
1055
                        assert tool_parser is not None
                        assert reasoning_parser is not None
                        assert added_content_delta_arr is not None
                        assert reasoning_end_arr is not None
1056
                        output_token_ids = as_list(output.token_ids)
1057
                        if not reasoning_end_arr[i]:
1058
1059
1060
                            # When encountering think end id in prompt_token_ids
                            # i.e {"enable_thinking": False},
                            # set reasoning status to end.
1061
1062
1063
1064
1065
1066
                            if (
                                res.prompt_token_ids
                                and reasoning_parser.is_reasoning_end(
                                    res.prompt_token_ids
                                )
                            ):
1067
                                reasoning_end_arr[i] = True
1068
                                current_token_ids = output_token_ids
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
                                # Don't update current_text, keep it as is from delta
                            else:
                                delta_message = (
                                    reasoning_parser.extract_reasoning_streaming(
                                        previous_text,
                                        current_text,
                                        delta_text,
                                        previous_token_ids,
                                        current_token_ids,
                                        output_token_ids,
1079
1080
                                    )
                                )
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097

                                # When encountering think end id in delta_token_ids,
                                # set reasoning status to end.
                                # Remove the text and token ids related
                                # to 'reasoning'.
                                if reasoning_parser.is_reasoning_end(output_token_ids):
                                    reasoning_end_arr[i] = True
                                    current_token_ids = (
                                        reasoning_parser.extract_content_ids(
                                            output_token_ids
                                        )
                                    )
                                    if delta_message and delta_message.content:
                                        current_text = delta_message.content
                                        delta_message.content = None
                                    else:
                                        current_text = ""
1098
1099

                        # handle tool calls only after reasoning is done,
1100
                        if reasoning_end_arr[i]:
1101
                            delta_token_ids = output_token_ids
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
                            # First time to tool call,
                            # add the remaining text and token ids
                            # to delta from previous
                            if not added_content_delta_arr[i]:
                                added_content_delta_arr[i] = True
                                previous_text = ""
                                previous_token_ids = []
                                delta_text = current_text
                                delta_token_ids = current_token_ids

1112
                            delta_message = tool_parser.extract_tool_calls_streaming(
1113
1114
                                previous_text=previous_text,
                                current_text=current_text,
1115
                                delta_text=delta_text,
1116
1117
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
                                delta_token_ids=delta_token_ids,
                                request=request,
                            )
                            if delta_message and delta_message.tool_calls:
                                tools_streamed[i] = True
                    # when only tool calls
                    elif tool_choice_auto:
                        assert tool_parser is not None
                        delta_message = tool_parser.extract_tool_calls_streaming(
                            previous_text=previous_text,
                            current_text=current_text,
                            delta_text=delta_text,
                            previous_token_ids=previous_token_ids,
                            current_token_ids=current_token_ids,
                            delta_token_ids=output.token_ids,
                            request=request,
                        )
1135
1136
                        if delta_message and delta_message.tool_calls:
                            tools_streamed[i] = True
1137

1138
                    # when only reasoning
1139
                    elif self.reasoning_parser:
1140
1141
1142
1143
1144
1145
1146
                        delta_message = reasoning_parser.extract_reasoning_streaming(
                            previous_text,
                            current_text,
                            delta_text,
                            previous_token_ids,
                            current_token_ids,
                            output.token_ids,
1147
                        )
1148
                    # handle streaming just a content delta
1149
1150
1151
                    else:
                        delta_message = DeltaMessage(content=delta_text)

1152
                    # update the previous values for the next iteration
1153
1154
1155
                    if (
                        tool_choice_auto or self.reasoning_parser
                    ) and not self.use_harmony:
1156
1157
1158
1159
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids
1160
1161
1162
1163
                    else:
                        # Update for comprehensive logging even in simple case
                        assert previous_texts is not None
                        previous_texts[i] += delta_text
1164

1165
                    # set the previous values for the next iteration
1166
                    previous_num_tokens[i] += len(output.token_ids)
1167
1168
1169
1170
1171
1172

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token, then
                    #   get the next token without streaming a chunk
                    if delta_message is None:
1173
1174
1175
1176
1177
1178
1179
                        # NOTE: If return_token_ids is enabled, we still need to
                        # send a chunk with token_ids even if delta_message is None
                        # to ensure all tokens are included in the response
                        if (
                            output.finish_reason is None
                            and not request.return_token_ids
                        ):
1180
                            continue
1181
                        delta_message = DeltaMessage()
1182

1183
1184
                    # Log streaming delta if output logging is enabled
                    if self.enable_log_outputs and self.request_logger:
1185
                        delta_content_parts = []
1186
                        if delta_message.content:
1187
1188
1189
1190
1191
1192
                            delta_content_parts.append(delta_message.content)
                        if delta_message.reasoning_content:
                            reasoning = delta_message.reasoning_content
                            delta_content_parts.append(f"[reasoning: {reasoning}]")
                        if delta_message.tool_calls:
                            tool_args = "".join(
1193
1194
                                tc.function.arguments
                                for tc in delta_message.tool_calls
1195
1196
                                if tc.function and tc.function.arguments
                            )
1197
1198
                            if tool_args:
                                delta_content_parts.append(f"[tool_calls: {tool_args}]")
1199

1200
1201
                        if delta_content_parts and self.enable_log_deltas:
                            delta_content = " ".join(delta_content_parts)
1202
1203
1204
                            self.request_logger.log_outputs(
                                request_id=request_id,
                                outputs=delta_content,
1205
                                output_token_ids=as_list(output.token_ids),
1206
1207
1208
1209
1210
                                finish_reason=output.finish_reason,
                                is_streaming=True,
                                delta=True,
                            )

1211
1212
1213
1214
                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
1215
                            delta=delta_message,
1216
                            logprobs=logprobs,
1217
                            finish_reason=None,
1218
1219
1220
1221
1222
1223
                            token_ids=(
                                as_list(output.token_ids)
                                if request.return_token_ids
                                else None
                            ),
                        )
1224
1225

                    # if the model is finished generating
1226
                    else:
1227
1228
1229
1230
                        # check for error finish reason and abort streaming
                        # finish_reason='error' indicates a retryable error
                        self._raise_if_error(output.finish_reason, request_id)

1231
1232
1233
                        # check to make sure we haven't "forgotten" to stream
                        #   any tokens that were generated but previously
                        #   matched by partial json parsing
1234
                        # only happens if we are NOT using structured outputs
1235
                        auto_tools_called = False
1236
                        if tool_parser:
1237
1238
1239
1240
1241
1242
                            auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0
                            index = (
                                len(tool_parser.prev_tool_call_arr) - 1
                                if auto_tools_called
                                else 0
                            )
1243
1244
1245
                        else:
                            index = 0

1246
1247
1248
1249
1250
1251
                        if (
                            self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output
                            )
                            and tool_parser
                        ):
1252
                            latest_delta_len = 0
1253
1254
                            if (
                                isinstance(
1255
                                    delta_message.tool_calls[0].function,
1256
1257
1258
1259
1260
                                    DeltaFunctionCall,
                                )
                            ) and isinstance(
                                delta_message.tool_calls[0].function.arguments, str
                            ):
1261
                                latest_delta_len = len(
1262
1263
                                    delta_message.tool_calls[0].function.arguments
                                )
1264

1265
1266
1267
1268
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
1269
1270
1271
1272
                                    "arguments", {}
                                ),
                                ensure_ascii=False,
                            )
1273

1274
                            # get what we've streamed so far for arguments
1275
                            # for the current tool
1276
1277
                            actual_call = tool_parser.streamed_args_for_tool[index]
                            if latest_delta_len > 0:
1278
                                actual_call = actual_call[:-latest_delta_len]
1279
1280

                            # check to see if there's anything left to stream
1281
                            remaining_call = expected_call.replace(actual_call, "", 1)
1282
                            # set that as a delta message
1283
1284
                            delta_message = self._create_remaining_args_delta(
                                delta_message, remaining_call, index
1285
                            )
1286

1287
                        # Send the finish response for each request.n only once
1288
1289
1290
1291
                        # In OpenAI's API, when a tool is called, the
                        # finish_reason is:
                        # "tool_calls" for "auto" or "required" tool calls,
                        # and "stop" for named tool calls.
1292
1293
                        if (
                            auto_tools_called
1294
                            or (tools_streamed[i] and not tool_choice_function_name)
1295
1296
                            or (self.use_harmony and harmony_tools_streamed[i])
                        ):
1297
1298
                            finish_reason_ = "tool_calls"
                        else:
1299
1300
1301
                            finish_reason_ = (
                                output.finish_reason if output.finish_reason else "stop"
                            )
1302
1303
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
1304
                            delta=delta_message,
1305
                            logprobs=logprobs,
1306
                            finish_reason=finish_reason_,
1307
                            stop_reason=output.stop_reason,
1308
1309
1310
1311
1312
1313
                            token_ids=(
                                as_list(output.token_ids)
                                if request.return_token_ids
                                else None
                            ),
                        )
1314

1315
                        finish_reason_sent[i] = True
1316

1317
                    choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
1318
1319
1320
1321
1322
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
1323
1324
                        model=model_name,
                    )
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )

1335
                    data = chunk.model_dump_json(exclude_unset=True)
1336
1337
                    yield f"data: {data}\n\n"

1338
1339
            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
1340
1341
            if include_usage:
                completion_tokens = sum(previous_num_tokens)
1342
1343
1344
1345
1346
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )
1347
1348
                if self.enable_prompt_tokens_details and num_cached_tokens:
                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
1349
1350
                        cached_tokens=num_cached_tokens
                    )
1351
1352
1353
1354
1355
1356
1357

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
1358
1359
1360
1361
1362
                    usage=final_usage,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True
                )
1363
                yield f"data: {final_usage_data}\n\n"
1364

1365
1366
1367
1368
1369
            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
1370
1371
1372
1373
1374
1375
1376
1377
1378
                total_tokens=num_prompt_tokens + num_completion_tokens,
            )

            # Log complete streaming response if output logging is enabled
            if self.enable_log_outputs and self.request_logger:
                # Log the complete response for each choice
                for i in range(num_choices):
                    full_text = (
                        previous_texts[i]
1379
1380
                        if previous_texts and i < len(previous_texts)
                        else f"<streaming_complete: {previous_num_tokens[i]} tokens>"
1381
1382
1383
1384
                    )
                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs=full_text,
1385
                        output_token_ids=None,  # Consider also logging all token IDs
1386
1387
1388
1389
                        finish_reason="streaming_complete",
                        is_streaming=True,
                        delta=False,
                    )
1390

1391
1392
        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
1393
        except Exception as e:
1394
            logger.exception("Error in chat completion stream generator.")
1395
            data = self.create_streaming_error_response(e)
1396
            yield f"data: {data}\n\n"
1397
1398
1399
1400
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: list[ConversationMessage],
        tokenizer: TokenizerLike,
        request_metadata: RequestResponseMetadata,
    ) -> ErrorResponse | ChatCompletionResponse:
        """Build a complete (non-streaming) chat completion response.

        Drains ``result_generator``, keeping only the last ``RequestOutput``
        it yields, and converts each of its outputs into a
        ``ChatCompletionResponseChoice``. Reasoning and tool calls are
        extracted when the corresponding parsers are configured. It then
        computes token usage, stores it on ``request_metadata``, optionally
        logs the outputs, and returns the assembled response.

        Args:
            request: The chat completion request being served.
            result_generator: Async iterator of engine outputs for the request.
            request_id: Identifier used in the response and in logging.
            model_name: Model name reported in the response.
            conversation: The request conversation. It is used for ``echo``
                and, for ``kimi_k2`` tool-call IDs, to continue numbering after
                the tool calls already in the history.
            tokenizer: Tokenizer used for logprobs and parser construction.
            request_metadata: Receives the final ``UsageInfo``.

        Returns:
            The ``ChatCompletionResponse``. An ``ErrorResponse`` is returned
            instead if the client disconnected, the generator raised
            ``ValueError``, or the reasoning parser raised ``RuntimeError``
            during construction.

        Raises:
            ValueError: If a tool or reasoning parser is required but
                ``tokenizer`` is None.
            GenerationError: Raised via ``_raise_if_error`` when an output
                finished with reason ``'error'``.
        """
        from vllm.tokenizers.mistral import MistralTokenizer

        created_time = int(time.time())
        final_res: RequestOutput | None = None

        # Only the last yielded RequestOutput is used to build the response.
        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(e)

        assert final_res is not None

        choices: list[ChatCompletionResponseChoice] = []
        # kimi_k2-style IDs continue numbering after the tool calls already
        # present in the conversation history.
        if self.tool_call_id_type == "kimi_k2":
            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
        else:
            history_tool_call_cnt = 0

        # The tokenizer is fixed for the whole request, so resolve the
        # tool-call class once instead of once per choice.
        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
        tool_call_class = MistralToolCall if is_mistral_tokenizer else ToolCall

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            # check for error finish reason and raise GenerationError
            # finish_reason='error' indicates a retryable request-level internal error
            self._raise_if_error(output.finish_reason, request_id)
            token_ids = output.token_ids
            out_logprobs = output.logprobs
            tool_call_info = None

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                    return_as_token_id=request.return_tokens_as_token_ids,
                )
            else:
                logprobs = None

            if self.use_harmony:
                # Harmony output is parsed from token IDs. This branch builds
                # the complete choice and skips the text-based handling below.
                reasoning, content, _ = parse_chat_output(token_ids)
                if not request.include_reasoning:
                    reasoning = None

                if self.tool_parser is not None:
                    if tokenizer is None:
                        raise ValueError(
                            "Tokenizer not available when `skip_tokenizer_init=True`"
                        )

                    tool_parser = self.tool_parser(tokenizer)
                    # NOTE: We use token_ids for openai tool parser
                    tool_call_info = tool_parser.extract_tool_calls(
                        "",
                        request=request,
                        token_ids=token_ids,  # type: ignore
                    )
                    content = tool_call_info.content
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                        tool_calls=tool_call_info.tool_calls,
                    )
                else:
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                    )

                choice_data = ChatCompletionResponseChoice(
                    index=output.index,
                    message=message,
                    logprobs=logprobs,
                    finish_reason=(
                        "tool_calls"
                        if (tool_call_info is not None and tool_call_info.tools_called)
                        else output.finish_reason
                        if output.finish_reason
                        else "stop"
                    ),
                    stop_reason=output.stop_reason,
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)
                continue

            if self.reasoning_parser:
                try:
                    if tokenizer is None:
                        raise ValueError(
                            "Tokenizer not available when `skip_tokenizer_init=True`"
                        )

                    # Pass the same chat template kwargs as used in tokenization
                    chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
                        request.chat_template_kwargs,
                        self.default_chat_template_kwargs,
                    )
                    reasoning_parser = self.reasoning_parser(
                        tokenizer,
                        chat_template_kwargs=chat_template_kwargs,  # type: ignore[call-arg]
                    )
                except RuntimeError as e:
                    logger.exception("Error in reasoning parser creation.")
                    return self.create_error_response(str(e))
                # If the reasoning parser is enabled,
                # tool calls are extracted exclusively from the content.
                reasoning, content = reasoning_parser.extract_reasoning(
                    output.text, request=request
                )
                if not request.include_reasoning:
                    reasoning = None
            else:
                reasoning = None
                content = output.text

            auto_tools_called = False
            tool_calls, content = self._parse_tool_calls_from_content(
                request=request,
                tokenizer=tokenizer,
                content=content,
                enable_auto_tools=self.enable_auto_tools,
                tool_parser_cls=self.tool_parser,
            )

            # Auto tools are not enabled, and neither a named nor a
            # "required" tool choice is in use: plain chat message.
            if (not self.enable_auto_tools or not self.tool_parser) and (
                not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
                and request.tool_choice != "required"
            ):
                message = ChatMessage(role=role, reasoning=reasoning, content=content)

            elif isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
                assert tool_calls is not None and len(tool_calls) > 0
                tool_call_items = self._build_tool_call_items(
                    tool_calls,
                    tool_call_class,
                    is_mistral_tokenizer,
                    history_tool_call_cnt,
                )
                # Every parsed call consumes one ID position.
                history_tool_call_cnt += len(tool_calls)
                message = ChatMessage(
                    role=role,
                    reasoning=reasoning,
                    content="",
                    tool_calls=tool_call_items,
                )

            elif request.tool_choice == "required":
                assert tool_calls is not None and len(tool_calls) > 0
                tool_call_items = self._build_tool_call_items(
                    tool_calls,
                    tool_call_class,
                    is_mistral_tokenizer,
                    history_tool_call_cnt,
                )
                history_tool_call_cnt += len(tool_calls)
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=tool_call_items,
                    reasoning=reasoning,
                )

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, reasoning=reasoning, content=content)

            # handle when there are tools and tool choice is auto
            elif (
                request.tools
                and (request.tool_choice == "auto" or request.tool_choice is None)
                and self.enable_auto_tools
                and self.tool_parser
            ):
                # In the OpenAI API the finish_reason is "tool_calls"
                # if the tool choice is auto and the model produced a tool
                # call. The same is not true for named function calls
                auto_tools_called = tool_calls is not None and len(tool_calls) > 0
                if tool_calls:
                    tool_call_items = self._build_tool_call_items(
                        tool_calls,
                        tool_call_class,
                        is_mistral_tokenizer,
                        history_tool_call_cnt,
                    )
                    history_tool_call_cnt += len(tool_calls)
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                        tool_calls=tool_call_items,
                    )
                else:
                    # No tool call detected. Return the content returned by
                    # _parse_tool_calls_from_content, which the tool parser
                    # may have modified.
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                    )

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion."
                )
                message = ChatMessage(role=role, reasoning=reasoning, content=content)

            # In OpenAI's API, when a tool is called, the finish_reason is:
            # "tool_calls" for "auto" or "required" tool calls,
            # and "stop" for named tool calls.
            is_finish_reason_tool_calls = auto_tools_called or (
                request.tool_choice
                and request.tool_choice == "required"
                and output.finish_reason == "stop"
            )

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls"
                if is_finish_reason_tool_calls
                else output.finish_reason
                if output.finish_reason
                else "stop",
                stop_reason=output.stop_reason,
                token_ids=(
                    as_list(output.token_ids) if request.return_token_ids else None
                ),
            )
            choice_data = maybe_filter_parallel_tool_calls(choice_data, request)

            choices.append(choice_data)

        if request.echo:
            # Prepend the last message's content when it has the same role as
            # the response (i.e. the model continued that message).
            last_msg_content: str | list[dict[str, str]] = ""
            if (
                conversation
                and "content" in conversation[-1]
                and conversation[-1].get("role") == role
            ):
                last_msg_content = conversation[-1]["content"] or ""
            if isinstance(last_msg_content, list):
                last_msg_content = "\n".join(msg["text"] for msg in last_msg_content)

            for choice in choices:
                full_message = last_msg_content + (choice.message.content or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs
        )
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
            prompt_token_ids=(
                final_res.prompt_token_ids if request.return_token_ids else None
            ),
            kv_transfer_params=final_res.kv_transfer_params,
        )

        # Log complete response if output logging is enabled
        if self.enable_log_outputs and self.request_logger:
            for choice in choices:
                output_text = ""
                if choice.message.content:
                    output_text = choice.message.content
                elif choice.message.tool_calls:
                    # For tool calls, log the function name and arguments
                    tool_call_descriptions = []
                    for tc in choice.message.tool_calls:  # type: ignore
                        function_call: FunctionCall = tc.function  # type: ignore
                        tool_call_descriptions.append(
                            f"{function_call.name}({function_call.arguments})"
                        )
                    tool_calls_str = ", ".join(tool_call_descriptions)
                    output_text = f"[tool_calls: {tool_calls_str}]"

                if output_text:
                    # Get the corresponding output token IDs
                    output_token_ids = None
                    if choice.index < len(final_res.outputs):
                        output_token_ids = final_res.outputs[choice.index].token_ids

                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs=output_text,
                        output_token_ids=output_token_ids,
                        finish_reason=choice.finish_reason,
                        is_streaming=False,
                        delta=False,
                    )

        return response

    def _build_tool_call_items(
        self,
        tool_calls: GenericSequence[Any],
        tool_call_class: type,
        is_mistral_tokenizer: bool,
        first_idx: int,
    ) -> list[Any]:
        """Wrap parsed function calls into response tool-call objects.

        Each call keeps its own ``id`` when the parser supplied one. Otherwise,
        for Mistral tokenizers, no ID is passed and ``tool_call_class`` is left
        to create it (this preserves Mistral's 9-char IDs). For all other
        tokenizers, an ID is generated with ``make_tool_call_id`` at position
        ``first_idx + offset``.

        Every call occupies one ID position, including calls that keep a
        native ID. Callers should advance their counter by
        ``len(tool_calls)`` afterwards.

        Args:
            tool_calls: Parsed calls. Each is assumed to expose ``id`` and
                ``name`` and is passed as ``function=`` to the class.
            tool_call_class: ``ToolCall`` or ``MistralToolCall``.
            is_mistral_tokenizer: Whether the request uses a Mistral tokenizer.
            first_idx: ID position assigned to the first call.

        Returns:
            One ``tool_call_class`` instance per parsed call, in order.
        """
        items: list[Any] = []
        for offset, tc in enumerate(tool_calls):
            if tc.id:
                # Use native ID if available (e.g., Kimi K2)
                items.append(tool_call_class(id=tc.id, function=tc))
            elif is_mistral_tokenizer:
                items.append(tool_call_class(function=tc))
            else:
                generated_id = make_tool_call_id(
                    id_type=self.tool_call_id_type,
                    func_name=tc.name,
                    idx=first_idx + offset,
                )
                items.append(tool_call_class(id=generated_id, function=tc))
        return items
1811
1812

    def _get_top_logprobs(
        self,
        logprobs: dict[int, Logprob],
        top_logprobs: int | None,
        tokenizer: TokenizerLike | None,
        should_return_as_token_id: bool,
    ) -> list[ChatCompletionLogProb]:
        """Turn one step's candidate logprobs into OpenAI ``top_logprobs`` entries.

        Candidates are visited in the dict's iteration order. A value of
        ``top_logprobs == -1`` keeps every candidate. A falsy value (``None``
        or ``0``) keeps none. Any other value keeps the first ``top_logprobs``
        candidates. Each logprob is floored at ``-9999.0``, and ``bytes``
        holds the UTF-8 encoding of the decoded token text.
        """
        keep_everything = top_logprobs == -1
        entries: list[ChatCompletionLogProb] = []
        for rank, (candidate_id, candidate) in enumerate(logprobs.items()):
            within_limit = top_logprobs and rank < top_logprobs
            if not (keep_everything or within_limit):
                continue
            text = self._get_decoded_token(
                candidate,
                candidate_id,
                tokenizer,
                return_as_token_id=should_return_as_token_id,
            )
            entries.append(
                ChatCompletionLogProb(
                    token=text,
                    logprob=max(candidate.logprob, -9999.0),
                    bytes=list(text.encode("utf-8", errors="replace")),
                )
            )
        return entries

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        tokenizer: TokenizerLike | None,
        num_output_top_logprobs: int | None = None,
        return_as_token_id: bool | None = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs for a sequence of sampled tokens.

        ``top_logprobs[i]`` holds the candidate logprobs for ``token_ids[i]``.
        Sampled tokens with no recorded logprob get an entry with only
        ``token`` and ``bytes``. Otherwise the entry also carries the sampled
        logprob (floored at ``-9999.0``) and the step's top candidates.

        ``return_as_token_id`` falls back to
        ``self.return_tokens_as_token_ids`` when it is None.

        Raises:
            ValueError: If a token must be decoded but ``tokenizer`` is None.
        """
        as_token_id = (
            self.return_tokens_as_token_ids
            if return_as_token_id is None
            else return_as_token_id
        )

        entries: list[ChatCompletionLogProbsContent] = []
        for position, sampled_id in enumerate(token_ids):
            candidates = top_logprobs[position]
            sampled = None if candidates is None else candidates.get(sampled_id)

            if sampled is None:
                # No logprob recorded for the sampled token: emit its text only.
                if as_token_id:
                    text = f"token_id:{sampled_id}"
                elif tokenizer is None:
                    raise ValueError(
                        "Unable to get tokenizer because `skip_tokenizer_init=True`"
                    )
                else:
                    text = tokenizer.decode(sampled_id)
                entries.append(
                    ChatCompletionLogProbsContent(
                        token=text,
                        bytes=list(text.encode("utf-8", errors="replace")),
                    )
                )
                continue

            decoded_text = sampled.decoded_token
            entries.append(
                ChatCompletionLogProbsContent(
                    token=self._get_decoded_token(
                        sampled, sampled_id, tokenizer, as_token_id
                    ),
                    logprob=max(sampled.logprob, -9999.0),
                    bytes=(
                        None
                        if decoded_text is None
                        else list(decoded_text.encode("utf-8", errors="replace"))
                    ),
                    top_logprobs=self._get_top_logprobs(
                        candidates, num_output_top_logprobs, tokenizer, as_token_id
                    ),
                )
            )

        return ChatCompletionLogProbs(content=entries)
1899

1900
    def _should_stream_with_auto_tool_parsing(
        self, request: ChatCompletionRequest
    ) -> bool:
        """Return whether streamed tokens should go through the tool parser.

        True only when the request supplies tools, a tool parser is
        configured, auto tool choice is enabled, and the request's
        ``tool_choice`` is ``"auto"`` or unset (``None``).
        """
        # Coerce to bool: a bare ``and`` chain returns whichever operand
        # short-circuited (e.g. the tools list or the parser class), not a
        # boolean.
        return bool(
            request.tools
            and self.tool_parser
            and self.enable_auto_tools
            and request.tool_choice in ("auto", None)
        )
1915
1916
1917

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: DeltaMessage | None,
        output: CompletionOutput,
    ) -> bool:
        """Decide whether to look for tool-argument tokens not yet streamed.

        Applies only when all of the following hold:
        - the output has finished (``finish_reason`` is set);
        - auto tool parsing is enabled and a tool parser is configured;
        - the delta's first tool call has a function with non-None
          arguments.
        """
        if output.finish_reason is None:
            return False
        if not (self.enable_auto_tools and self.tool_parser):
            return False
        if not delta_message or not delta_message.tool_calls:
            return False
        first_call = delta_message.tool_calls[0]
        if not first_call or not first_call.function:
            return False
        return first_call.function.arguments is not None
1939

1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
    @staticmethod
    def _create_remaining_args_delta(
        delta_message: DeltaMessage,
        remaining_call: str,
        index: int,
    ) -> DeltaMessage:
        """Build a delta that carries the not-yet-streamed tool arguments.

        The first tool call in ``delta_message`` whose ``index`` matches is
        used as the source for ``id``, ``type`` and the function ``name``.
        If no call matches, those fields are None. The returned delta
        contains a single tool call at ``index`` whose arguments are
        ``remaining_call``.
        """
        source_call = None
        for candidate in delta_message.tool_calls:
            if candidate.index == index:
                source_call = candidate
                break

        call_id = None
        call_type = None
        function_name = None
        if source_call:
            call_id = source_call.id
            call_type = source_call.type
            if source_call.function:
                function_name = source_call.function.name

        remaining_tool_call = DeltaToolCall(
            index=index,
            id=call_id,
            type=call_type,
            function=DeltaFunctionCall(
                name=function_name,
                arguments=remaining_call,
            ),
        )
        return DeltaMessage(tool_calls=[remaining_tool_call])

1969
1970
1971
    def _make_request_with_harmony(
        self,
        request: ChatCompletionRequest,
        should_include_tools: bool = True,
    ):
        """
        Render a Chat Completions request into a Harmony prompt.

        The Harmony conversation is assembled in this order:
        1. a system message;
        2. a developer message, only when ``request.tools`` is non-empty;
        3. the messages parsed from ``request.messages``.

        The conversation is rendered to prompt token ids and wrapped in a
        ``TokensPrompt``. ``request.cache_salt`` is copied onto it when set.

        Args:
            request: The incoming chat completion request. Its ``tool_calls``
                field may be re-serialized by ``maybe_serialize_tool_calls``
                before use.
            should_include_tools: When False, custom tools are not advertised
                in the system message, and the developer message (if any)
                receives ``tools=None``.

        Returns:
            A tuple ``(messages, [engine_prompt])``: the Harmony messages and a
            single-element list holding the rendered prompt.
        """
        # Pydantic issues can leave the request's tool_calls field needing to
        # be re-serialized; see the comment in `maybe_serialize_tool_calls`.
        maybe_serialize_tool_calls(request)  # type: ignore[arg-type]

        # In the Chat Completion API, browsing is enabled by default when the
        # model supports it, but that is not supported here yet (TODO), so
        # neither built-in tool may be active.
        assert not self.supports_browsing
        assert not self.supports_code_interpreter

        harmony_messages: list[OpenAIMessage] = [
            get_system_message(
                reasoning_effort=request.reasoning_effort,
                browser_description=None,
                python_description=None,
                with_custom_tools=should_include_tools,
            )
        ]

        if request.tools:
            # A developer message is added whenever tools were supplied, but
            # it only lists them when should_include_tools is set.
            developer_tools = request.tools if should_include_tools else None
            harmony_messages.append(
                get_developer_message(tools=developer_tools)  # type: ignore[arg-type]
            )

        # Convert the request's conversation history into Harmony messages.
        harmony_messages.extend(
            parse_chat_inputs_to_harmony_messages(request.messages)
        )

        engine_prompt = TokensPrompt(
            prompt_token_ids=render_for_completion(harmony_messages)
        )
        if request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return harmony_messages, [engine_prompt]