serving.py 86.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import time
7
8
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
9
from typing import Any, Final
10

11
import jinja2
12
import partial_json_parser
13
import regex as re
14
from fastapi import Request
15
from openai_harmony import Message as OpenAIMessage
16
from partial_json_parser.core.options import Allow
17

18
from vllm.engine.protocol import EngineClient
19
20
21
22
23
24
from vllm.entrypoints.chat_utils import (
    ChatTemplateContentFormatOption,
    ConversationMessage,
    get_history_tool_calls_cnt,
    make_tool_call_id,
)
25
from vllm.entrypoints.logger import RequestLogger
26
from vllm.entrypoints.openai.chat_completion.protocol import (
27
28
29
30
31
32
33
34
35
36
    ChatCompletionLogProb,
    ChatCompletionLogProbs,
    ChatCompletionLogProbsContent,
    ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
37
38
)
from vllm.entrypoints.openai.chat_completion.stream_harmony import (
39
    TokenState,
40
41
42
    extract_harmony_streaming_delta,
)
from vllm.entrypoints.openai.engine.protocol import (
43
44
45
46
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ErrorResponse,
47
    FunctionCall,
48
49
50
51
52
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    ToolCall,
    UsageInfo,
)
53
from vllm.entrypoints.openai.engine.serving import (
54
55
56
57
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
58
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
59
60
61
62
63
64
65
66
67
from vllm.entrypoints.openai.parser.harmony_utils import (
    get_developer_message,
    get_stop_tokens_for_assistant_actions,
    get_streamable_parser_for_assistant,
    get_system_message,
    parse_chat_inputs_to_harmony_messages,
    parse_chat_output,
    render_for_completion,
)
68
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
69
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
70
from vllm.inputs.data import ProcessorInputs, TokensPrompt
71
from vllm.logger import init_logger
72
from vllm.logprobs import Logprob
73
from vllm.outputs import CompletionOutput, RequestOutput
74
from vllm.parser import ParserManager
75
from vllm.reasoning import ReasoningParser
76
from vllm.sampling_params import BeamSearchParams, SamplingParams
77
from vllm.tokenizers import TokenizerLike
78
79
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
80
from vllm.tool_parsers.utils import partial_json_loads
81
from vllm.utils.collection_utils import as_list
82
83
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import mt as _mt
84
85
86
87
88

# Module-level logger; the methods below use it for info, debug, error and
# exception reporting.
logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):
89
90
91
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        response_role: str,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        return_tokens_as_token_ids: bool = False,
        reasoning_parser: str = "",
        enable_auto_tools: bool = False,
        exclude_tools_when_tool_choice_none: bool = False,
        tool_parser: str | None = None,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        enable_log_outputs: bool = False,
        enable_log_deltas: bool = True,
        log_error_stack: bool = False,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Configure the chat-completion serving layer.

        The engine client, model registry, request logger and the
        token-id / error-stack flags are forwarded to the base class; the
        remaining arguments are stored on the instance. The reasoning
        parser class and the tool parser are resolved through
        ``ParserManager`` from ``reasoning_parser`` and ``tool_parser``.

        ``default_chat_template_kwargs`` falls back to an empty dict when
        it is ``None`` (or otherwise falsy).
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )
        model_config = self.model_config

        # Chat-template related settings.
        self.response_role = response_role
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template
        self.default_chat_template_kwargs = default_chat_template_kwargs or {}

        # Output logging switches.
        self.enable_log_outputs = enable_log_outputs
        self.enable_log_deltas = enable_log_deltas

        # Reasoning parser class; instantiated per request later on.
        self.reasoning_parser_cls = ParserManager.get_reasoning_parser(
            reasoning_parser_name=reasoning_parser
        )

        # Tool-use configuration.
        self.enable_auto_tools: bool = enable_auto_tools
        self.tool_parser = ParserManager.get_tool_parser(
            tool_parser_name=tool_parser,
            enable_auto_tools=enable_auto_tools,
            model_name=model_config.model,
        )
        self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none

        # Usage reporting switches.
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage

        # Sampling defaults derived from the model config.
        self.default_sampling_params = model_config.get_diff_sampling_param()
        if model_config.generation_config not in ("auto", "vllm"):
            self.override_max_tokens = self.default_sampling_params.get("max_tokens")
        else:
            generation_overrides = getattr(
                model_config, "override_generation_config", {}
            )
            self.override_max_tokens = generation_overrides.get("max_new_tokens")

        # GPT-OSS models go through the harmony code path and need the
        # assistant-action stop tokens appended to the default stop list.
        self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
        if self.use_harmony:
            self.default_sampling_params.setdefault("stop_token_ids", []).extend(
                get_stop_tokens_for_assistant_actions()
            )

        # Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
        hf_overrides = getattr(model_config, "hf_overrides", None)
        is_kimi_k2 = model_config.hf_text_config.model_type == "kimi_k2" or (
            isinstance(hf_overrides, dict)
            and hf_overrides.get("model_type") == "kimi_k2"
        )
        self.tool_call_id_type = "kimi_k2" if is_kimi_k2 else "random"

        # NOTE(woosuk): While OpenAI's chat completion API supports browsing
        # for some models, currently vLLM doesn't support it. Please use the
        # Responses API instead.
        self.supports_browsing = False
        self.browser_tool = None
        # NOTE(woosuk): Chat completion API does not support code interpreter.
        # Please use the Responses API instead.
        self.supports_code_interpreter = False
        self.python_tool = None

177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
    async def warmup(self) -> None:
        """
        Run one throwaway request through chat preprocessing at startup.

        A minimal single-message request is pushed through
        ``_preprocess_chat`` so that the work normally paid by the first
        real request (chat template content format detection, Jinja2
        template compilation, tokenizer initialization for chat) happens
        ahead of time.

        Any exception is logged and swallowed: a failed warmup must not
        prevent the server from starting.
        """
        logger.info("Warming up chat template processing...")
        warmup_began = time.perf_counter()

        try:
            # Smallest request that still exercises the template path.
            probe_request = ChatCompletionRequest(
                messages=[{"role": "user", "content": "warmup"}],
                model=None,
                max_completion_tokens=1,
            )
            template_defaults = {
                "default_template": self.chat_template,
                "default_template_content_format": self.chat_template_content_format,
                "default_template_kwargs": self.default_chat_template_kwargs,
            }

            await self._preprocess_chat(
                probe_request,
                probe_request.messages,
                **template_defaults,
            )

            elapsed_ms = 1000 * (time.perf_counter() - warmup_began)
            logger.info("Chat template warmup completed in %.1fms", elapsed_ms)
        except Exception:
            # Log but don't fail server startup if warmup fails
            logger.exception("Chat template warmup failed")

216
    async def render_chat_request(
        self,
        request: ChatCompletionRequest,
    ) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
        """
        Validate a chat request and turn it into engine prompts.

        Returns:
            A tuple of (conversation, engine_prompts) on success,
            or an ErrorResponse on failure.

        Raises:
            The engine client's ``dead_error`` when the engine has errored.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            logger.error("Error with model %s", model_error)
            return model_error

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        try:
            tokenizer = self.renderer.tokenizer
            tool_parser = self.tool_parser
            uses_mistral_tokenizer = is_mistral_tokenizer(tokenizer)

            if uses_mistral_tokenizer:
                # because of issues with pydantic we need to potentially
                # re-serialize the tool_calls field of the request
                # for more info: see comment in `maybe_serialize_tool_calls`
                _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
                _mt.truncate_tool_call_ids(request)  # type: ignore[arg-type]
                _mt.validate_request_params(request)

            tool_choice = request.tool_choice

            # Tool parsing is unavailable when no parser is configured and
            # neither the Mistral tokenizer nor the harmony path is in use.
            tool_parsing_unavailable = (
                tool_parser is None
                and not uses_mistral_tokenizer
                and not self.use_harmony
            )

            # Reject tool choices that need a parser we do not have.
            if tool_parsing_unavailable and tool_choice not in (None, "none"):
                if tool_choice != "auto":
                    # "required" or named tool requires tool parser
                    return self.create_error_response(
                        f'tool_choice="{tool_choice}" requires '
                        "--tool-call-parser to be set"
                    )
                if not self.enable_auto_tools:
                    # for hf tokenizers, "auto" tools requires
                    # --enable-auto-tool-choice and --tool-call-parser
                    return self.create_error_response(
                        '"auto" tool choice requires '
                        "--enable-auto-tool-choice and --tool-call-parser to be set"
                    )

            # Tools are dropped when none were sent, or when the caller chose
            # "none" and the server is configured to exclude them in that case.
            drop_tools = request.tools is None or (
                tool_choice == "none" and self.exclude_tools_when_tool_choice_none
            )
            tool_dicts = (
                None if drop_tools else [tool.model_dump() for tool in request.tools]
            )

            if self.use_harmony:
                # For GPT-OSS.
                conversation, engine_prompts = self._make_request_with_harmony(
                    request, tool_dicts is not None
                )
            else:
                # Common case.
                template_error = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if template_error is not None:
                    return template_error

                conversation, engine_prompts = await self._preprocess_chat(
                    request,
                    request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=self.default_chat_template_kwargs,
                    tool_dicts=tool_dicts,
                    tool_parser=tool_parser,
                )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(e)

        return conversation, engine_prompts

    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse:
        """
        Chat Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        Chat Completion API.

        Returns:
            The stream generator when ``request.stream`` is set, otherwise
            the awaited result of the full (non-streaming) generator. An
            ErrorResponse is returned when reasoning-parser creation,
            request rendering, adapter lookup, scheduling or full
            generation fails with one of the handled exception types.
        """
        tokenizer = self.renderer.tokenizer
        assert tokenizer is not None

        # Instantiate the reasoning parser for this request, if configured.
        reasoning_parser: ReasoningParser | None = None
        try:
            if self.reasoning_parser_cls:
                # Pass the same chat template kwargs as used in tokenization
                parser_template_kwargs = self._prepare_extra_chat_template_kwargs(
                    request.chat_template_kwargs,
                    self.default_chat_template_kwargs,
                )
                reasoning_parser = self.reasoning_parser_cls(
                    tokenizer,
                    chat_template_kwargs=parser_template_kwargs,  # type: ignore[call-arg]
                )
        except RuntimeError as e:
            logger.exception("Error in reasoning parser creation.")
            return self.create_error_response(str(e))

        rendered = await self.render_chat_request(request)
        if isinstance(rendered, ErrorResponse):
            return rendered
        conversation, engine_prompts = rendered

        base_request_id = self._base_request_id(raw_request, request.request_id)
        request_id = f"chatcmpl-{base_request_id}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(
                request, supports_default_mm_loras=True
            )
            model_name = self.models.model_name(lora_request)
        except (ValueError, TypeError, RuntimeError) as e:
            logger.exception("Error preparing request components")
            return self.create_error_response(e)

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        max_model_len = self.model_config.max_model_len
        single_prompt = len(engine_prompts) == 1
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            for prompt_index, engine_prompt in enumerate(engine_prompts):
                prompt_components = self._extract_prompt_components(engine_prompt)
                prompt_token_ids = prompt_components.token_ids

                # If we are creating sub requests for multiple prompts, ensure that they
                # have unique request ids.
                if single_prompt:
                    sub_request_id = request_id
                else:
                    sub_request_id = f"{request_id}_{prompt_index}"

                # `max_completion_tokens` wins over `max_tokens` when given.
                if request.max_completion_tokens is not None:
                    requested_max_tokens = request.max_completion_tokens
                else:
                    requested_max_tokens = request.max_tokens

                max_tokens = get_max_tokens(
                    max_model_len,
                    requested_max_tokens,
                    self._extract_prompt_len(engine_prompt),
                    self.default_sampling_params,
                    self.override_max_tokens,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.default_sampling_params,
                    )

                self._log_inputs(
                    sub_request_id,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                if raw_request is None:
                    trace_headers = None
                else:
                    trace_headers = await self._get_trace_headers(raw_request.headers)

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=sub_request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                    )
                else:
                    # Ask the reasoning parser (when present) whether the
                    # prompt tokens already contain the end of reasoning.
                    if reasoning_parser:
                        reasoning_ended = reasoning_parser.is_reasoning_end(
                            prompt_token_ids or []
                        )
                    else:
                        reasoning_ended = None

                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        sub_request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        data_parallel_rank=data_parallel_rank,
                        reasoning_ended=reasoning_ended,
                    )

                generators.append(generator)
        except ValueError as e:
            return self.create_error_response(e)

        assert len(generators) == 1
        (result_generator,) = generators

        # Both response paths take the same positional arguments.
        generator_args = (
            request,
            result_generator,
            request_id,
            model_name,
            conversation,
            tokenizer,
            request_metadata,
            reasoning_parser,
        )

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(*generator_args)

        # Non-streaming response
        try:
            return await self.chat_completion_full_generator(*generator_args)
        except GenerationError as e:
            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            return self.create_error_response(e)
482
483
484
485

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
486
        return request.messages[-1]["role"]
487

488
    @staticmethod
489
    def _bracket_level(s: str, opening="{", closing="}") -> int:
490
491
492
493
494
495
496
497
498
499
500
501
        """
        Calculate the current level of nested brackets in a given string.
        """
        level = 0
        for char in s:
            if char == opening:
                level += 1
            elif char == closing:
                level -= 1
        return level

    @staticmethod
502
    def _filter_delta_text(delta_text: str, previous_text: str) -> tuple[str, bool]:
503
504
505
506
507
508
509
510
511
        # remove last '},' of the tool definition stemming from the
        # "name"/"parameters" outer object or closing ']' of the tool list
        # count occurrences of opening and closing curly braces and
        # once level 0 is reached stop outputting text
        # if 0 is reached while parsing the delta_text we know the current
        # tool will finish in this current iteration
        bracket_level = OpenAIServingChat._bracket_level(previous_text)
        updated_delta, passed_zero = "", False
        for c in delta_text:
512
            if c == "{":
513
514
                bracket_level += 1
                passed_zero = bracket_level == 0
515
            elif c == "}":
516
517
518
519
520
521
522
                bracket_level -= 1
                passed_zero = bracket_level == 0

            if bracket_level != 0:
                updated_delta += c
            else:
                # if a comma is reached at level 0 we can stop
523
                if c == ",":
524
525
526
527
528
529
                    break
        return updated_delta, passed_zero

    def extract_tool_call_required_streaming(
        self,
        previous_text: str,
530
        current_text: str | None,
531
532
        delta_text: str,
        function_name_returned: bool,
533
534
        tool_call_idx: int | None = None,
    ) -> tuple[DeltaMessage | None, bool]:
535
536
537
        if current_text is None or current_text == "":
            # if the current text is empty, we cannot parse it
            return None, function_name_returned
538
        try:
539
540
541
542
543
544
            flags = Allow.ALL
            obj, _ = partial_json_loads(current_text, flags)
        except (
            partial_json_parser.core.exceptions.MalformedJSON,
            json.JSONDecodeError,
        ):
545
            logger.debug("not enough tokens to parse into JSON yet")
546
547
548
549
550
551
552
553
554
555
            obj = None

        # check if the current text is a valid array
        # containing a partial tool calling object
        # if not repeat
        if obj is None or not isinstance(obj, list) or not len(obj) > 0:
            function_name_returned = False
            delta_message = None
        else:
            _, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
556
557
                delta_text, previous_text
            )
558
559
560
561
            # take the last tool call from the generated list
            current_tool_call = obj[-1]

            # once parameters have been generated the name is complete as well
562
563
564
            if not finishes_previous_tool and (
                "name" not in current_tool_call or "parameters" not in current_tool_call
            ):
565
566
567
568
569
                function_name_returned = False
                delta_message = None
            else:
                if not function_name_returned:
                    # get partly generated arguments from the latest tool call
570
571
572
                    param_match = re.search(
                        r'.*"parameters":\s*(.*)', current_text, re.DOTALL
                    )
573
574
                    arguments = param_match.group(1) if param_match else ""
                    arguments, _ = OpenAIServingChat._filter_delta_text(
575
576
                        arguments, previous_text
                    )
577
578
579
580

                    # if this iteration finishes a previous tool call but a
                    # new incomplete tool is already generated, take the
                    # previous from the list
581
                    if finishes_previous_tool and "parameters" not in current_tool_call:
582
583
584
                        current_tool_call = obj[-2]

                    function_name_returned = True
585
586
587
                    tool_call_id = make_tool_call_id(
                        id_type=self.tool_call_id_type,
                        func_name=current_tool_call["name"],
588
589
590
591
592
593
594
595
596
597
598
599
600
601
                        idx=tool_call_idx,
                    )
                    delta_message = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                id=tool_call_id,
                                function=DeltaFunctionCall(
                                    name=current_tool_call["name"], arguments=arguments
                                ),
                                index=len(obj) - 1,
                                type="function",
                            )
                        ]
                    )
602
603
604

                else:
                    delta_text, _ = OpenAIServingChat._filter_delta_text(
605
606
                        delta_text, previous_text
                    )
607
608

                    if delta_text != "":
609
610
611
612
613
614
615
616
617
618
619
620
621
                        delta_message = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    function=DeltaFunctionCall(
                                        # OpenAI API returns None
                                        # instead of name every time
                                        name=None,
                                        arguments=delta_text,
                                    ),
                                    index=len(obj) - 1,
                                )
                            ]
                        )
622
623
624
625
626
                    else:
                        delta_message = None

        return delta_message, function_name_returned

627
    async def chat_completion_stream_generator(
628
629
630
631
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
632
        model_name: str,
633
        conversation: list[ConversationMessage],
634
        tokenizer: TokenizerLike,
635
        request_metadata: RequestResponseMetadata,
636
        reasoning_parser: ReasoningParser | None = None,
637
    ) -> AsyncGenerator[str, None]:
638
        created_time = int(time.time())
639
        chunk_object_type: Final = "chat.completion.chunk"
640
        first_iteration = True
641
642

        # Send response for each token for each request.n (index)
643
644
645
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
646
        num_prompt_tokens = 0
647
        num_cached_tokens = None
648
649
        if self.use_harmony:
            harmony_parsers = [
650
                get_streamable_parser_for_assistant() for _ in range(num_choices)
651
            ]
652
653
            harmony_tools_streamed = [False] * num_choices
        tools_streamed = [False] * num_choices
654
655
656
657
658
659
660
661
662

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
663
664
            and self._should_stream_with_auto_tool_parsing(request)
        )
665

666
        all_previous_token_ids: list[list[int]] | None
667
        function_name_returned = [False] * num_choices
668
        if self.tool_call_id_type == "kimi_k2":
669
670
671
            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
        else:
            history_tool_call_cnt = 0
672

673
674
675
        # Always track previous_texts for comprehensive output logging
        previous_texts = [""] * num_choices

676
677
        # Only one of these will be used, thus previous_texts and
        # all_previous_token_ids will not be used twice in the same iteration.
678
        if tool_choice_auto or reasoning_parser:
679
680
            # These are only required in "auto" tool choice case
            all_previous_token_ids = [[]] * num_choices
681
682
683
            # For reasoning parser and tool call all enabled
            added_content_delta_arr = [False] * num_choices
            reasoning_end_arr = [False] * num_choices
684
            prompt_is_reasoning_end_arr: list[bool | None] = [None] * num_choices
685
        else:
686
            all_previous_token_ids = None
687

688
689
690
        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
691
692
693
694
695
                if tokenizer is None:
                    raise ValueError(
                        "Tokenizer not available when `skip_tokenizer_init=True`"
                    )

696
                tool_parsers: list[ToolParser | None] = [
697
698
699
700
                    self.tool_parser(tokenizer)
                ] * num_choices
            else:
                tool_parsers = [None] * num_choices
701
        except Exception as e:
702
            logger.exception("Error in tool parser creation.")
703
            data = self.create_streaming_error_response(e)
704
705
706
707
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

708
        stream_options = request.stream_options
709
710
711
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )
712

713
714
        try:
            async for res in result_generator:
715
716
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
717
718
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)
719

720
721
722
723
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
724
                    num_cached_tokens = res.num_cached_tokens
725
726
                    # Send first response for each request.n (index) with
                    # the role
727
                    role = self.get_chat_request_role(request)
728
729
730

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
731
                    for i in range(num_choices):
732
733
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
734
735
736
737
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
738
                            logprobs=None,
739
740
                            finish_reason=None,
                        )
741
742

                        # return prompt_token_ids at the first chunk ever
743
744
745
746
747
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
748
                            model=model_name,
749
750
751
752
753
754
                            prompt_token_ids=(
                                res.prompt_token_ids
                                if request.return_token_ids
                                else None
                            ),
                        )
755

756
757
758
759
760
                        # if continuous usage stats are requested, add it
                        if include_continuous_usage:
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=0,
761
762
                                total_tokens=num_prompt_tokens,
                            )
763

764
765
766
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

767
768
                    # Send response to echo the input portion of the
                    # last message
769
                    if request.echo:
770
                        last_msg_content: str | list[dict[str, str]] = ""
771
772
773
774
775
                        if (
                            conversation
                            and "content" in conversation[-1]
                            and conversation[-1].get("role") == role
                        ):
776
                            last_msg_content = conversation[-1]["content"] or ""
777
778

                        if last_msg_content:
779
                            for i in range(num_choices):
780
781
782
783
784
785
                                choice_data = ChatCompletionResponseStreamChoice(
                                    index=i,
                                    delta=DeltaMessage(content=last_msg_content),
                                    logprobs=None,
                                    finish_reason=None,
                                )
786
787
788
789
790
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
791
792
                                    model=model_name,
                                )
793
794
795
796
                                if include_continuous_usage:
                                    chunk.usage = UsageInfo(
                                        prompt_tokens=num_prompt_tokens,
                                        completion_tokens=0,
797
798
                                        total_tokens=num_prompt_tokens,
                                    )
799

800
                                data = chunk.model_dump_json(exclude_unset=True)
801
802
803
804
805
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index
806
                    tool_parser = tool_parsers[i]
807

808
                    if (
809
                        reasoning_parser
810
811
812
813
814
815
816
817
                        and res.prompt_token_ids
                        and prompt_is_reasoning_end_arr[i] is None
                    ):
                        # only check once per choice, because prompt_token_ids
                        # are the same for all deltas in that choice
                        prompt_is_reasoning_end_arr[i] = (
                            reasoning_parser.is_reasoning_end(res.prompt_token_ids)
                        )
818
819
820
                    if finish_reason_sent[i]:
                        continue

821
                    if request.logprobs and request.top_logprobs is not None:
822
                        assert output.logprobs is not None, "Did not output logprobs"
823
                        logprobs = self._create_chat_logprobs(
824
825
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
826
                            tokenizer=tokenizer,
827
                            num_output_top_logprobs=request.top_logprobs,
828
                            return_as_token_id=request.return_tokens_as_token_ids,
829
830
831
832
                        )
                    else:
                        logprobs = None

833
834
                    if self.use_harmony:
                        harmony_parser = harmony_parsers[i]
835
                        prev_recipient = harmony_parser.current_recipient
836
837
838

                        # Track accumulated content per token with their state
                        token_states: list[TokenState] = []
839
840
                        for token_id in output.token_ids:
                            harmony_parser.process(token_id)
841
842
843
844
845
846
847
848
849
                            token_delta = harmony_parser.last_content_delta or ""
                            token_states.append(
                                TokenState(
                                    harmony_parser.current_channel,
                                    harmony_parser.current_recipient,
                                    token_delta,
                                )
                            )
                        delta_text = "".join(delta for _, _, delta in token_states)
850
                        cur_channel = harmony_parser.current_channel
851

852
853
854
855
856
                        # handle the case where several tokens were generated at once
                        # including the final token, leading to a delta in the text
                        # but the current channel to be empty (start state)
                        if not cur_channel and delta_text:
                            cur_channel = "final"
857
858
                    else:
                        delta_text = output.text
859

860
861
862
863
864
                    if (
                        not delta_text
                        and not output.token_ids
                        and not previous_num_tokens[i]
                    ):
865
866
867
                        # Chunked prefill case, don't return empty chunks
                        continue

868
                    delta_message: DeltaMessage | None
869

870
                    # just update previous_texts and previous_token_ids
871
                    if tool_choice_auto or reasoning_parser:
872
873
874
875
876
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
877
878
                        # avoid the None + list error.
                        if previous_token_ids:
879
                            current_token_ids = previous_token_ids + as_list(
880
881
                                output.token_ids
                            )
882
                        else:
883
                            current_token_ids = as_list(output.token_ids)
884

885
                    if self.use_harmony:
886
887
888
                        delta_message, tools_streamed_flag = (
                            extract_harmony_streaming_delta(
                                harmony_parser=harmony_parser,
889
                                token_states=token_states,
890
891
892
893
894
                                prev_recipient=prev_recipient,
                                include_reasoning=request.include_reasoning,
                            )
                        )
                        harmony_tools_streamed[i] |= tools_streamed_flag
895
                    # handle streaming deltas for tools with named tool_choice
896
                    elif tool_choice_function_name:
897
898
899
900
901
902
903
904
905
906
907
                        # When encountering think end id in prompt_token_ids
                        # i.e {"enable_thinking": False},
                        # check BEFORE calling the parser to avoid a spurious
                        # reasoning delta on the first chunk.
                        if (
                            reasoning_parser
                            and not reasoning_end_arr[i]
                            and prompt_is_reasoning_end_arr[i]
                        ):
                            reasoning_end_arr[i] = True

908
                        if (
909
                            reasoning_parser
910
911
912
913
914
                            and not reasoning_end_arr[i]
                            and not reasoning_parser.is_reasoning_end(
                                previous_token_ids
                            )
                        ):
915
916
                            assert reasoning_parser is not None
                            delta_message = (
917
                                reasoning_parser.extract_reasoning_streaming(
918
919
920
921
922
923
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output.token_ids,
924
925
                                )
                            )
926
                            # When encountering think end id in delta_token_ids,
927
                            # set reasoning status to end.
928
                            # Only keep 'content', remove 'reasoning'.
929
930
                            if reasoning_parser.is_reasoning_end(
                                as_list(output.token_ids)
931
                            ):
932
                                reasoning_end_arr[i] = True
933
934
935
936
937
938
939
940
                                if delta_message and delta_message.content:
                                    # This needs to be added to the next `delta_text`
                                    current_text = delta_message.content
                                    delta_message.content = None
                                else:
                                    current_text = ""
                        else:
                            # Just to add remaining `content`
941
                            if reasoning_parser:
942
943
944
                                delta_text = previous_text + delta_text
                                current_text = ""

945
946
                            if function_name_returned[i]:
                                delta_tool_call = DeltaToolCall(
947
948
949
                                    function=DeltaFunctionCall(arguments=delta_text),
                                    index=i,
                                )
950
                            else:
951
                                # Generate ID based on tokenizer type
952
                                if is_mistral_tokenizer(tokenizer):
953
954
955
956
957
958
959
                                    tool_call_id = MistralToolCall.generate_random_id()
                                else:
                                    tool_call_id = make_tool_call_id(
                                        id_type=self.tool_call_id_type,
                                        func_name=tool_choice_function_name,
                                        idx=history_tool_call_cnt,
                                    )
960
                                delta_tool_call = DeltaToolCall(
961
                                    id=tool_call_id,
962
963
964
                                    type="function",
                                    function=DeltaFunctionCall(
                                        name=tool_choice_function_name,
965
966
967
968
                                        arguments=delta_text,
                                    ),
                                    index=i,
                                )
969
                                function_name_returned[i] = True
970
                                history_tool_call_cnt += 1
971

972
973
974
975
976
                            delta_message = DeltaMessage(
                                tool_calls=[
                                    delta_tool_call,
                                ]
                            )
977
                            tools_streamed[i] = True
978

979
980
981
982
983
                    elif request.tool_choice == "required":
                        assert previous_texts is not None
                        previous_text = previous_texts[i]
                        current_text = previous_text + delta_text
                        fn_name_returned = function_name_returned[i]
984
985
986
                        output_token_ids = as_list(output.token_ids)

                        if (
987
                            reasoning_parser is not None
988
                            and not reasoning_end_arr[i]
989
                            and prompt_is_reasoning_end_arr[i]
990
991
                        ):
                            reasoning_end_arr[i] = True
992

993
                        if reasoning_parser and not reasoning_end_arr[i]:
994
                            delta_message = (
995
                                reasoning_parser.extract_reasoning_streaming(
996
997
998
999
1000
1001
1002
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output_token_ids,
                                )
1003
                            )
1004
1005
1006
1007
1008
1009
1010
1011
1012
                            if reasoning_parser.is_reasoning_end(output_token_ids):
                                reasoning_end_arr[i] = True
                                if delta_message and delta_message.content:
                                    current_text = delta_message.content
                                    delta_message.content = None
                                else:
                                    # reasoning ended
                                    current_text = ""

1013
                        else:
1014
                            # either finished reasoning or no reasoning at all
1015
                            content = current_text
1016
1017
1018
1019
1020
1021
1022
1023
1024

                            delta_message, function_name_returned[i] = (
                                self.extract_tool_call_required_streaming(
                                    previous_text=previous_text,
                                    current_text=content,
                                    delta_text=delta_text,
                                    function_name_returned=fn_name_returned,
                                    tool_call_idx=history_tool_call_cnt,
                                )
1025
                            )
1026
1027
1028
1029
1030
1031
1032
                            if (
                                delta_message
                                and delta_message.tool_calls
                                and delta_message.tool_calls[0].id is not None
                            ):
                                history_tool_call_cnt += 1
                                tools_streamed[i] = True
1033

1034
1035
                    # handle streaming deltas for tools with "auto" tool choice
                    # and reasoning parser
1036
                    elif tool_choice_auto and reasoning_parser:
1037
1038
1039
                        assert tool_parser is not None
                        assert added_content_delta_arr is not None
                        assert reasoning_end_arr is not None
1040
                        output_token_ids = as_list(output.token_ids)
1041
                        if not reasoning_end_arr[i]:
1042
1043
1044
                            # When encountering think end id in prompt_token_ids
                            # i.e {"enable_thinking": False},
                            # set reasoning status to end.
1045
                            if prompt_is_reasoning_end_arr[i]:
1046
                                reasoning_end_arr[i] = True
1047
                                current_token_ids = output_token_ids
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
                                # Don't update current_text, keep it as is from delta
                            else:
                                delta_message = (
                                    reasoning_parser.extract_reasoning_streaming(
                                        previous_text,
                                        current_text,
                                        delta_text,
                                        previous_token_ids,
                                        current_token_ids,
                                        output_token_ids,
1058
1059
                                    )
                                )
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076

                                # When encountering think end id in delta_token_ids,
                                # set reasoning status to end.
                                # Remove the text and token ids related
                                # to 'reasoning'.
                                if reasoning_parser.is_reasoning_end(output_token_ids):
                                    reasoning_end_arr[i] = True
                                    current_token_ids = (
                                        reasoning_parser.extract_content_ids(
                                            output_token_ids
                                        )
                                    )
                                    if delta_message and delta_message.content:
                                        current_text = delta_message.content
                                        delta_message.content = None
                                    else:
                                        current_text = ""
1077
1078

                        # handle tool calls only after reasoning is done,
1079
                        if reasoning_end_arr[i]:
1080
                            delta_token_ids = output_token_ids
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
                            # First time to tool call,
                            # add the remaining text and token ids
                            # to delta from previous
                            if not added_content_delta_arr[i]:
                                added_content_delta_arr[i] = True
                                previous_text = ""
                                previous_token_ids = []
                                delta_text = current_text
                                delta_token_ids = current_token_ids

1091
                            delta_message = tool_parser.extract_tool_calls_streaming(
1092
1093
                                previous_text=previous_text,
                                current_text=current_text,
1094
                                delta_text=delta_text,
1095
1096
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
                                delta_token_ids=delta_token_ids,
                                request=request,
                            )
                            if delta_message and delta_message.tool_calls:
                                tools_streamed[i] = True
                    # when only tool calls
                    elif tool_choice_auto:
                        assert tool_parser is not None
                        delta_message = tool_parser.extract_tool_calls_streaming(
                            previous_text=previous_text,
                            current_text=current_text,
                            delta_text=delta_text,
                            previous_token_ids=previous_token_ids,
                            current_token_ids=current_token_ids,
                            delta_token_ids=output.token_ids,
                            request=request,
                        )
1114
1115
                        if delta_message and delta_message.tool_calls:
                            tools_streamed[i] = True
1116

1117
                    # when only reasoning
1118
                    elif reasoning_parser:
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
                        # When encountering think end id in prompt_token_ids
                        # i.e {"enable_thinking": False},
                        # set reasoning status to end.
                        # Route all generated tokens as content directly.
                        if prompt_is_reasoning_end_arr[i]:
                            delta_message = DeltaMessage(content=delta_text)
                        else:
                            delta_message = (
                                reasoning_parser.extract_reasoning_streaming(
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output.token_ids,
                                )
                            )
1136
                    # handle streaming just a content delta
1137
1138
1139
                    else:
                        delta_message = DeltaMessage(content=delta_text)

1140
                    # update the previous values for the next iteration
1141
                    if (tool_choice_auto or reasoning_parser) and not self.use_harmony:
1142
1143
1144
1145
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids
1146
1147
1148
1149
                    else:
                        # Update for comprehensive logging even in simple case
                        assert previous_texts is not None
                        previous_texts[i] += delta_text
1150

1151
                    # set the previous values for the next iteration
1152
                    previous_num_tokens[i] += len(output.token_ids)
1153
1154
1155
1156
1157
1158

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token), then
                    #   get the next token without streaming a chunk
                    if delta_message is None:
1159
1160
1161
1162
1163
1164
1165
                        # NOTE: If return_token_ids is enabled, we still need to
                        # send a chunk with token_ids even if delta_message is None
                        # to ensure all tokens are included in the response
                        if (
                            output.finish_reason is None
                            and not request.return_token_ids
                        ):
1166
                            continue
1167
                        delta_message = DeltaMessage()
1168

1169
1170
                    # Log streaming delta if output logging is enabled
                    if self.enable_log_outputs and self.request_logger:
1171
                        delta_content_parts = []
1172
                        if delta_message.content:
1173
                            delta_content_parts.append(delta_message.content)
1174
1175
                        if delta_message.reasoning:
                            reasoning = delta_message.reasoning
1176
1177
1178
                            delta_content_parts.append(f"[reasoning: {reasoning}]")
                        if delta_message.tool_calls:
                            tool_args = "".join(
1179
1180
                                tc.function.arguments
                                for tc in delta_message.tool_calls
1181
1182
                                if tc.function and tc.function.arguments
                            )
1183
1184
                            if tool_args:
                                delta_content_parts.append(f"[tool_calls: {tool_args}]")
1185

1186
1187
                        if delta_content_parts and self.enable_log_deltas:
                            delta_content = " ".join(delta_content_parts)
1188
1189
1190
                            self.request_logger.log_outputs(
                                request_id=request_id,
                                outputs=delta_content,
1191
                                output_token_ids=as_list(output.token_ids),
1192
1193
1194
1195
1196
                                finish_reason=output.finish_reason,
                                is_streaming=True,
                                delta=True,
                            )

1197
1198
1199
1200
                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
1201
                            delta=delta_message,
1202
                            logprobs=logprobs,
1203
                            finish_reason=None,
1204
1205
1206
1207
1208
1209
                            token_ids=(
                                as_list(output.token_ids)
                                if request.return_token_ids
                                else None
                            ),
                        )
1210
1211

                    # if the model is finished generating
1212
                    else:
1213
1214
1215
1216
                        # check for error finish reason and abort streaming
                        # finish_reason='error' indicates a retryable error
                        self._raise_if_error(output.finish_reason, request_id)

1217
1218
1219
                        # check to make sure we haven't "forgotten" to stream
                        #   any tokens that were generated but previously
                        #   matched by partial json parsing
1220
                        # only happens if we are NOT using structured outputs
1221
                        auto_tools_called = False
1222
                        if tool_parser:
1223
1224
1225
1226
1227
1228
                            auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0
                            index = (
                                len(tool_parser.prev_tool_call_arr) - 1
                                if auto_tools_called
                                else 0
                            )
1229
1230
1231
                        else:
                            index = 0

1232
1233
1234
1235
1236
1237
                        if (
                            self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output
                            )
                            and tool_parser
                        ):
1238
                            latest_delta_len = 0
1239
1240
                            if (
                                isinstance(
1241
                                    delta_message.tool_calls[0].function,
1242
1243
1244
1245
1246
                                    DeltaFunctionCall,
                                )
                            ) and isinstance(
                                delta_message.tool_calls[0].function.arguments, str
                            ):
1247
                                latest_delta_len = len(
1248
1249
                                    delta_message.tool_calls[0].function.arguments
                                )
1250

1251
1252
1253
1254
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
1255
1256
1257
1258
                                    "arguments", {}
                                ),
                                ensure_ascii=False,
                            )
1259

1260
                            # get what we've streamed so far for arguments
1261
                            # for the current tool
1262
1263
                            actual_call = tool_parser.streamed_args_for_tool[index]
                            if latest_delta_len > 0:
1264
                                actual_call = actual_call[:-latest_delta_len]
1265
1266

                            # check to see if there's anything left to stream
1267
                            remaining_call = expected_call.replace(actual_call, "", 1)
1268
                            # set that as a delta message
1269
1270
                            delta_message = self._create_remaining_args_delta(
                                delta_message, remaining_call, index
1271
                            )
1272

1273
                        # Send the finish response for each request.n only once
1274
1275
1276
1277
                        # In OpenAI's API, when a tool is called, the
                        # finish_reason is:
                        # "tool_calls" for "auto" or "required" tool calls,
                        # and "stop" for named tool calls.
1278
1279
                        if (
                            auto_tools_called
1280
                            or (tools_streamed[i] and not tool_choice_function_name)
1281
1282
                            or (self.use_harmony and harmony_tools_streamed[i])
                        ):
1283
1284
                            finish_reason_ = "tool_calls"
                        else:
1285
1286
1287
                            finish_reason_ = (
                                output.finish_reason if output.finish_reason else "stop"
                            )
1288
1289
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
1290
                            delta=delta_message,
1291
                            logprobs=logprobs,
1292
                            finish_reason=finish_reason_,
1293
                            stop_reason=output.stop_reason,
1294
1295
1296
1297
1298
1299
                            token_ids=(
                                as_list(output.token_ids)
                                if request.return_token_ids
                                else None
                            ),
                        )
1300

1301
                        finish_reason_sent[i] = True
1302

1303
                    choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
1304
1305
1306
1307
1308
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
1309
1310
                        model=model_name,
                    )
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )

1321
                    data = chunk.model_dump_json(exclude_unset=True)
1322
1323
                    yield f"data: {data}\n\n"

1324
1325
            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
1326
1327
            if include_usage:
                completion_tokens = sum(previous_num_tokens)
1328
1329
1330
1331
1332
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )
1333
1334
                if self.enable_prompt_tokens_details and num_cached_tokens:
                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
1335
1336
                        cached_tokens=num_cached_tokens
                    )
1337
1338
1339
1340
1341
1342
1343

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
1344
1345
1346
1347
1348
                    usage=final_usage,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True
                )
1349
                yield f"data: {final_usage_data}\n\n"
1350

1351
1352
1353
1354
1355
            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
1356
1357
1358
1359
1360
1361
1362
1363
1364
                total_tokens=num_prompt_tokens + num_completion_tokens,
            )

            # Log complete streaming response if output logging is enabled
            if self.enable_log_outputs and self.request_logger:
                # Log the complete response for each choice
                for i in range(num_choices):
                    full_text = (
                        previous_texts[i]
1365
1366
                        if previous_texts and i < len(previous_texts)
                        else f"<streaming_complete: {previous_num_tokens[i]} tokens>"
1367
1368
1369
1370
                    )
                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs=full_text,
1371
                        output_token_ids=None,  # Consider also logging all token IDs
1372
1373
1374
1375
                        finish_reason="streaming_complete",
                        is_streaming=True,
                        delta=False,
                    )
1376

1377
1378
        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
1379
        except Exception as e:
1380
            logger.exception("Error in chat completion stream generator.")
1381
            data = self.create_streaming_error_response(e)
1382
            yield f"data: {data}\n\n"
1383
1384
1385
1386
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: list[ConversationMessage],
        tokenizer: TokenizerLike,
        request_metadata: RequestResponseMetadata,
        reasoning_parser: ReasoningParser | None = None,
    ) -> ErrorResponse | ChatCompletionResponse:
        """Drain the engine output and build a non-streaming chat response.

        Args:
            request: The chat completion request being served.
            result_generator: Async iterator of engine outputs; only the last
                item yielded is used to build the response.
            request_id: Identifier placed on the response and in output logs.
            model_name: Model name reported in the response.
            conversation: Parsed conversation; used to count prior tool calls
                for ``kimi_k2`` tool-call ids and to find the echoed message.
            tokenizer: Tokenizer used for logprob decoding and tool parsing.
            request_metadata: Receives the final usage info.
            reasoning_parser: Optional parser that splits reasoning from
                content for non-harmony models.

        Returns:
            The assembled ``ChatCompletionResponse``, or an ``ErrorResponse``
            when the client disconnects or the engine raises ``ValueError``
            while generating.

        Raises:
            ValueError: If a harmony tool parser is configured but no
                tokenizer is available.
        """
        from vllm.tokenizers.mistral import MistralTokenizer

        created_time = int(time.time())
        final_res: RequestOutput | None = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(e)

        assert final_res is not None

        choices: list[ChatCompletionResponseChoice] = []
        if self.tool_call_id_type == "kimi_k2":
            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
        else:
            history_tool_call_cnt = 0

        # Mistral tool-call classes assign their own 9-char ids, so no id is
        # generated here for calls that lack a native one.
        defer_id_to_class = isinstance(tokenizer, MistralTokenizer)

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            # check for error finish reason and raise GenerationError
            # finish_reason='error' indicates a retryable request-level
            # internal error
            self._raise_if_error(output.finish_reason, request_id)
            token_ids = output.token_ids
            out_logprobs = output.logprobs
            tool_call_info = None

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                    return_as_token_id=request.return_tokens_as_token_ids,
                )
            else:
                logprobs = None

            if self.use_harmony:
                reasoning, content, _ = parse_chat_output(token_ids)
                if not request.include_reasoning:
                    reasoning = None

                if self.tool_parser is not None:
                    if tokenizer is None:
                        raise ValueError(
                            "Tokenizer not available when `skip_tokenizer_init=True`"
                        )

                    tool_parser = self.tool_parser(tokenizer)
                    # NOTE: We use token_ids for openai tool parser
                    tool_call_info = tool_parser.extract_tool_calls(
                        "",
                        request=request,
                        token_ids=token_ids,  # type: ignore
                    )
                    content = tool_call_info.content
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                        tool_calls=tool_call_info.tool_calls,
                    )
                else:
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                    )

                harmony_tools_called = (
                    tool_call_info is not None and tool_call_info.tools_called
                )
                choice_data = ChatCompletionResponseChoice(
                    index=output.index,
                    message=message,
                    logprobs=logprobs,
                    finish_reason=(
                        "tool_calls"
                        if harmony_tools_called
                        else (output.finish_reason or "stop")
                    ),
                    stop_reason=output.stop_reason,
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)
                # Harmony outputs are complete here; everything below only
                # applies to non-harmony models.
                continue

            if reasoning_parser:
                # If the reasoning parser is enabled,
                # tool calls are extracted exclusively from the content.
                reasoning, content = reasoning_parser.extract_reasoning(
                    output.text, request=request
                )
                if not request.include_reasoning:
                    reasoning = None
            else:
                reasoning = None
                content = output.text

            auto_tools_called = False
            tool_calls, content = self._parse_tool_calls_from_content(
                request=request,
                tokenizer=tokenizer,
                content=content,
                enable_auto_tools=self.enable_auto_tools,
                tool_parser_cls=self.tool_parser,
            )
            tool_call_class = (
                MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall
            )

            # if auto tools are not enabled, and a named tool choice using
            #   outlines is not being used
            if (not self.enable_auto_tools or not self.tool_parser) and (
                not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
                and request.tool_choice != "required"
            ):
                message = ChatMessage(role=role, reasoning=reasoning, content=content)

            # a specific function was named by the request
            elif isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
                assert tool_calls is not None and len(tool_calls) > 0
                tool_call_class_items, history_tool_call_cnt = (
                    self._wrap_parsed_tool_calls(
                        tool_calls,
                        tool_call_class,
                        history_tool_call_cnt,
                        defer_id_to_class,
                    )
                )
                message = ChatMessage(
                    role=role,
                    reasoning=reasoning,
                    content="",
                    tool_calls=tool_call_class_items,
                )

            elif request.tool_choice == "required":
                assert tool_calls is not None and len(tool_calls) > 0
                tool_call_class_items, history_tool_call_cnt = (
                    self._wrap_parsed_tool_calls(
                        tool_calls,
                        tool_call_class,
                        history_tool_call_cnt,
                        defer_id_to_class,
                    )
                )
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=tool_call_class_items,
                    reasoning=reasoning,
                )

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, reasoning=reasoning, content=content)

            # handle when there are tools and tool choice is auto
            elif (
                request.tools
                and (request.tool_choice == "auto" or request.tool_choice is None)
                and self.enable_auto_tools
                and self.tool_parser
            ):
                # In the OpenAI API the finish_reason is "tools_called"
                # if the tool choice is auto and the model produced a tool
                # call. The same is not true for named function calls
                auto_tools_called = bool(tool_calls)
                if auto_tools_called:
                    tool_call_items, history_tool_call_cnt = (
                        self._wrap_parsed_tool_calls(
                            tool_calls,
                            tool_call_class,
                            history_tool_call_cnt,
                            defer_id_to_class,
                        )
                    )
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                        tool_calls=tool_call_items,
                    )
                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later. The content is whatever the
                    # tool parser returned, since it may have modified it.
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                    )

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion."
                )
                message = ChatMessage(role=role, reasoning=reasoning, content=content)

            # In OpenAI's API, when a tool is called, the finish_reason is:
            # "tool_calls" for "auto" or "required" tool calls,
            # and "stop" for named tool calls.
            is_finish_reason_tool_calls = auto_tools_called or (
                request.tool_choice == "required" and output.finish_reason == "stop"
            )

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason=(
                    "tool_calls"
                    if is_finish_reason_tool_calls
                    else (output.finish_reason or "stop")
                ),
                stop_reason=output.stop_reason,
                token_ids=(
                    as_list(output.token_ids) if request.return_token_ids else None
                ),
            )
            choice_data = maybe_filter_parallel_tool_calls(choice_data, request)

            choices.append(choice_data)

        if request.echo:
            # Prefix every choice with the last conversation message when that
            # message has the same role as the response.
            last_msg_content: str | list[dict[str, str]] = ""
            if (
                conversation
                and "content" in conversation[-1]
                and conversation[-1].get("role") == role
            ):
                last_msg_content = conversation[-1]["content"] or ""
            if isinstance(last_msg_content, list):
                last_msg_content = "\n".join(msg["text"] for msg in last_msg_content)

            for choice in choices:
                full_message = last_msg_content + (choice.message.content or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs
        )
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
            prompt_token_ids=(
                final_res.prompt_token_ids if request.return_token_ids else None
            ),
            kv_transfer_params=final_res.kv_transfer_params,
        )

        # Log complete response if output logging is enabled
        if self.enable_log_outputs and self.request_logger:
            for choice in choices:
                output_text = ""
                if choice.message.content:
                    output_text = choice.message.content
                elif choice.message.tool_calls:
                    # For tool calls, log the function name and arguments
                    tool_call_descriptions = []
                    for tc in choice.message.tool_calls:  # type: ignore
                        function_call: FunctionCall = tc.function  # type: ignore
                        tool_call_descriptions.append(
                            f"{function_call.name}({function_call.arguments})"
                        )
                    tool_calls_str = ", ".join(tool_call_descriptions)
                    output_text = f"[tool_calls: {tool_calls_str}]"

                if output_text:
                    # Get the corresponding output token IDs
                    output_token_ids = None
                    if choice.index < len(final_res.outputs):
                        output_token_ids = final_res.outputs[choice.index].token_ids

                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs=output_text,
                        output_token_ids=output_token_ids,
                        finish_reason=choice.finish_reason,
                        is_streaming=False,
                        delta=False,
                    )

        return response

    def _wrap_parsed_tool_calls(
        self,
        tool_calls: GenericSequence[Any],
        tool_call_class: type,
        history_tool_call_cnt: int,
        defer_id_to_class: bool,
    ) -> tuple[list[Any], int]:
        """Wrap parsed function calls in ``tool_call_class`` and assign ids.

        Args:
            tool_calls: Parsed function calls; each exposes ``id`` and ``name``.
            tool_call_class: Class used to wrap each call.
            history_tool_call_cnt: Running tool-call count, used as the index
                when an id has to be generated.
            defer_id_to_class: When true, calls without a native id are
                wrapped without an explicit id so the class supplies its own.

        Returns:
            The wrapped calls and the running count advanced by one per call.
        """
        wrapped: list[Any] = []
        for call in tool_calls:
            if call.id:
                # Use native ID if available (e.g., Kimi K2)
                wrapped.append(tool_call_class(id=call.id, function=call))
            elif defer_id_to_class:
                wrapped.append(tool_call_class(function=call))
            else:
                # Generate ID using the correct format (kimi_k2 or random)
                generated_id = make_tool_call_id(
                    id_type=self.tool_call_id_type,
                    func_name=call.name,
                    idx=history_tool_call_cnt,
                )
                wrapped.append(tool_call_class(id=generated_id, function=call))
            history_tool_call_cnt += 1
        return wrapped, history_tool_call_cnt
1780
1781

    def _get_top_logprobs(
1782
1783
        self,
        logprobs: dict[int, Logprob],
1784
        top_logprobs: int | None,
1785
        tokenizer: TokenizerLike | None,
1786
1787
        should_return_as_token_id: bool,
    ) -> list[ChatCompletionLogProb]:
1788
        return [
1789
            ChatCompletionLogProb(
1790
1791
1792
1793
1794
1795
1796
1797
                token=(
                    token := self._get_decoded_token(
                        p[1],
                        p[0],
                        tokenizer,
                        return_as_token_id=should_return_as_token_id,
                    )
                ),
1798
1799
                logprob=max(p[1].logprob, -9999.0),
                bytes=list(token.encode("utf-8", errors="replace")),
1800
1801
            )
            for i, p in enumerate(logprobs.items())
1802
            if (top_logprobs and i < top_logprobs or top_logprobs == -1)
1803
1804
1805
1806
1807
        ]

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        tokenizer: TokenizerLike | None,
        num_output_top_logprobs: int | None = None,
        return_as_token_id: bool | None = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        Args:
            token_ids: Generated token ids, one per step.
            top_logprobs: Per-step candidate logprobs, indexed like
                ``token_ids``; an entry may be ``None``.
            tokenizer: Used to decode a token when no logprob entry exists
                for it and tokens are not returned as ids.
            num_output_top_logprobs: Number of alternatives to report per step.
            return_as_token_id: Overrides ``self.return_tokens_as_token_ids``
                when not ``None``.

        Raises:
            ValueError: If a token must be decoded but ``tokenizer`` is None.
        """
        if return_as_token_id is None:
            should_return_as_token_id = self.return_tokens_as_token_ids
        else:
            should_return_as_token_id = return_as_token_id

        entries: list[ChatCompletionLogProbsContent] = []
        for position, token_id in enumerate(token_ids):
            candidates = top_logprobs[position]
            sampled = None if candidates is None else candidates.get(token_id)

            if sampled is None:
                # No logprob was recorded for the sampled token: report only
                # its text and bytes.
                if should_return_as_token_id:
                    token = f"token_id:{token_id}"
                elif tokenizer is None:
                    raise ValueError(
                        "Unable to get tokenizer because `skip_tokenizer_init=True`"
                    )
                else:
                    token = tokenizer.decode(token_id)

                entries.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    )
                )
                continue

            decoded_text = sampled.decoded_token
            token_text = self._get_decoded_token(
                sampled,
                token_id,
                tokenizer,
                should_return_as_token_id,
            )
            # Bytes come from the decoded token text, independent of whether
            # the token is reported as an id.
            if decoded_text is None:
                token_bytes = None
            else:
                token_bytes = list(decoded_text.encode("utf-8", errors="replace"))

            entries.append(
                ChatCompletionLogProbsContent(
                    token=token_text,
                    logprob=max(sampled.logprob, -9999.0),
                    bytes=token_bytes,
                    top_logprobs=self._get_top_logprobs(
                        candidates,
                        num_output_top_logprobs,
                        tokenizer,
                        should_return_as_token_id,
                    ),
                )
            )

        return ChatCompletionLogProbs(content=entries)
1868

1869
    def _should_stream_with_auto_tool_parsing(self, request: ChatCompletionRequest):
1870
1871
1872
1873
1874
1875
1876
1877
        """
        Utility function to check if streamed tokens should go through the tool
        call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used.
        """
1878
1879
1880
1881
1882
1883
        return (
            request.tools
            and self.tool_parser
            and self.enable_auto_tools
            and request.tool_choice in ["auto", None]
        )
1884
1885
1886

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
1887
        delta_message: DeltaMessage | None,
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
        output: CompletionOutput,
    ) -> bool:
        """
        Check to see if we should check for unstreamed tool arguments tokens.
        This is only applicable when auto tool parsing is enabled, the delta
        is a tool call with arguments.
        """

        return bool(
            # if there is a delta message that includes tool calls which
            # include a function that has arguments
1899
            output.finish_reason is not None
1900
1901
1902
1903
1904
            and self.enable_auto_tools
            and self.tool_parser
            and delta_message
            and delta_message.tool_calls
            and delta_message.tool_calls[0]
1905
1906
1907
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None
        )
1908

1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
    @staticmethod
    def _create_remaining_args_delta(
        delta_message: DeltaMessage,
        remaining_call: str,
        index: int,
    ) -> DeltaMessage:
        """
        Build a delta carrying the remaining tool arguments for one tool call.

        The id, type and function name are copied from the first tool call in
        ``delta_message`` whose index equals ``index``; when there is no such
        call they are left as ``None``.
        """
        matched_call = None
        for candidate in delta_message.tool_calls:
            if candidate.index == index:
                matched_call = candidate
                break

        if matched_call:
            matched_fn = matched_call.function
            call_id = matched_call.id
            call_type = matched_call.type
        else:
            matched_fn = None
            call_id = None
            call_type = None
        fn_name = matched_fn.name if matched_fn else None

        remaining_fn = DeltaFunctionCall(name=fn_name, arguments=remaining_call)
        remaining_tool_call = DeltaToolCall(
            index=index,
            id=call_id,
            type=call_type,
            function=remaining_fn,
        )
        return DeltaMessage(tool_calls=[remaining_tool_call])

1938
1939
1940
    def _make_request_with_harmony(
        self,
        request: ChatCompletionRequest,
        should_include_tools: bool = True,
    ):
        """Render a chat request into harmony messages and a token prompt.

        Args:
            request: The chat completion request to render.
            should_include_tools: Passed to the system message as
                ``with_custom_tools``; when false the developer message is
                built with ``tools=None``.

        Returns:
            A tuple of the harmony message list and a one-element list
            holding the engine token prompt.
        """
        # because of issues with pydantic we need to potentially
        # re-serialize the tool_calls field of the request
        # for more info: see comment in `maybe_serialize_tool_calls`
        _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]

        # NOTE: In Chat Completion API, browsing is enabled by default
        # if the model supports it. TODO: Support browsing.
        assert not self.supports_browsing
        assert not self.supports_code_interpreter

        # The system message always comes first.
        messages: list[OpenAIMessage] = [
            get_system_message(
                reasoning_effort=request.reasoning_effort,
                browser_description=None,
                python_description=None,
                with_custom_tools=should_include_tools,
            )
        ]

        # A developer message is added whenever the request declares tools.
        if request.tools:
            developer_tools = request.tools if should_include_tools else None
            messages.append(
                get_developer_message(
                    tools=developer_tools  # type: ignore[arg-type]
                )
            )

        # The conversation messages follow.
        messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))

        # Render prompt token ids.
        rendered_token_ids = render_for_completion(messages)
        engine_prompt = TokensPrompt(prompt_token_ids=rendered_token_ids)

        # Add cache_salt if provided in the request
        salt = request.cache_salt
        if salt is not None:
            engine_prompt["cache_salt"] = salt

        return messages, [engine_prompt]