serving.py 86.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import time
7
8
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
9
from typing import Any, Final
10

11
import jinja2
12
import partial_json_parser
13
import regex as re
14
from fastapi import Request
15
from openai_harmony import Message as OpenAIMessage
16
from partial_json_parser.core.options import Allow
17

18
from vllm.engine.protocol import EngineClient
19
20
21
22
23
24
from vllm.entrypoints.chat_utils import (
    ChatTemplateContentFormatOption,
    ConversationMessage,
    get_history_tool_calls_cnt,
    make_tool_call_id,
)
25
from vllm.entrypoints.logger import RequestLogger
26
from vllm.entrypoints.openai.chat_completion.protocol import (
27
28
29
30
31
32
33
34
35
36
    ChatCompletionLogProb,
    ChatCompletionLogProbs,
    ChatCompletionLogProbsContent,
    ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
37
38
)
from vllm.entrypoints.openai.chat_completion.stream_harmony import (
39
    TokenState,
40
41
42
    extract_harmony_streaming_delta,
)
from vllm.entrypoints.openai.engine.protocol import (
43
44
45
46
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ErrorResponse,
47
    FunctionCall,
48
49
50
51
52
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    ToolCall,
    UsageInfo,
)
53
from vllm.entrypoints.openai.engine.serving import (
54
55
56
57
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
58
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
59
60
61
62
63
64
65
66
67
from vllm.entrypoints.openai.parser.harmony_utils import (
    get_developer_message,
    get_stop_tokens_for_assistant_actions,
    get_streamable_parser_for_assistant,
    get_system_message,
    parse_chat_inputs_to_harmony_messages,
    parse_chat_output,
    render_for_completion,
)
68
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
69
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
70
from vllm.inputs.data import TokensPrompt
71
from vllm.logger import init_logger
72
from vllm.logprobs import Logprob
73
from vllm.outputs import CompletionOutput, RequestOutput
74
from vllm.parser import ParserManager
75
from vllm.reasoning import ReasoningParser
76
from vllm.renderers.inputs import TokPrompt
77
from vllm.sampling_params import BeamSearchParams, SamplingParams
78
79
80
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
    MistralTokenizer,
81
82
83
84
    maybe_serialize_tool_calls,
    truncate_tool_call_ids,
    validate_request_params,
)
85
86
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
87
from vllm.tool_parsers.utils import partial_json_loads
88
from vllm.utils.collection_utils import as_list
89
90
91
92
93

# Module-level logger used by the chat-completion serving code in this file.
logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):
94
95
96
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        response_role: str,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        return_tokens_as_token_ids: bool = False,
        reasoning_parser: str = "",
        enable_auto_tools: bool = False,
        exclude_tools_when_tool_choice_none: bool = False,
        tool_parser: str | None = None,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        enable_log_outputs: bool = False,
        enable_log_deltas: bool = True,
        log_error_stack: bool = False,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Configure the chat-completion endpoint handler.

        Args:
            engine_client: Client used to submit requests to the engine.
            models: Served-model registry (forwarded to the base class).
            response_role: Role reported for generated messages when the
                request sets ``add_generation_prompt``.
            request_logger: Optional logger for incoming requests.
            chat_template: Server-level chat template used by default.
            chat_template_content_format: Default content format passed to
                chat-template preprocessing.
            trust_request_chat_template: Whether a request-supplied chat
                template may be used.
            return_tokens_as_token_ids: Forwarded to the base class.
            reasoning_parser: Reasoning parser name, resolved through
                ``ParserManager.get_reasoning_parser``.
            enable_auto_tools: Whether ``tool_choice="auto"`` is allowed.
            exclude_tools_when_tool_choice_none: Drop tool definitions from
                the prompt when ``tool_choice="none"``.
            tool_parser: Tool-call parser name, if any.
            enable_prompt_tokens_details: Stored on the instance (presumably
                for prompt-token usage details — confirm at use site).
            enable_force_include_usage: Stored; used when deciding whether to
                include usage in streamed responses.
            enable_log_outputs: Stored flag for output logging.
            enable_log_deltas: Stored flag for per-delta logging.
            log_error_stack: Forwarded to the base class.
            default_chat_template_kwargs: Server-level extra chat-template
                kwargs; ``None`` is replaced with an empty dict.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        self.response_role = response_role
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template
        self.default_chat_template_kwargs = default_chat_template_kwargs or {}
        self.enable_log_outputs = enable_log_outputs
        self.enable_log_deltas = enable_log_deltas

        # set up reasoning parser
        self.reasoning_parser_cls = ParserManager.get_reasoning_parser(
            reasoning_parser_name=reasoning_parser
        )
        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        self.tool_parser = ParserManager.get_tool_parser(
            tool_parser_name=tool_parser,
            enable_auto_tools=enable_auto_tools,
            model_name=self.model_config.model,
        )
        self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage
        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
        if self.use_harmony:
            # Build a fresh list instead of extending in place. The existing
            # list comes from ``get_diff_sampling_param`` and may be shared with
            # the model config (NOTE(review): confirm). In-place mutation would
            # leak harmony stop tokens into other consumers and duplicate them
            # each time a handler is constructed. A ``None`` value is treated
            # as an empty list.
            stop_token_ids = list(
                self.default_sampling_params.get("stop_token_ids") or []
            )
            for token_id in get_stop_tokens_for_assistant_actions():
                if token_id not in stop_token_ids:
                    stop_token_ids.append(token_id)
            self.default_sampling_params["stop_token_ids"] = stop_token_ids

        # Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
        hf_overrides = getattr(self.model_config, "hf_overrides", None)
        if self.model_config.hf_text_config.model_type == "kimi_k2" or (
            isinstance(hf_overrides, dict)
            and hf_overrides.get("model_type") == "kimi_k2"
        ):
            self.tool_call_id_type = "kimi_k2"
        else:
            self.tool_call_id_type = "random"

        # NOTE(woosuk): While OpenAI's chat completion API supports browsing
        # for some models, currently vLLM doesn't support it. Please use the
        # Responses API instead.
        self.supports_browsing = False
        self.browser_tool = None
        # NOTE(woosuk): Chat completion API does not support code interpreter.
        # Please use the Responses API instead.
        self.supports_code_interpreter = False
        self.python_tool = None

    async def warmup(self) -> None:
        """
        Pre-render a throwaway chat request so the first real one is not slow.

        Pushing a single-message request through ``_preprocess_chat`` forces
        chat-template content-format detection, Jinja2 template compilation
        and chat tokenizer initialization to happen now rather than on the
        first user request. Any failure is logged and swallowed so warmup can
        never block server startup.
        """
        logger.info("Warming up chat template processing...")
        started = time.perf_counter()

        try:
            probe_request = ChatCompletionRequest(
                messages=[{"role": "user", "content": "warmup"}],
                model=None,
                max_completion_tokens=1,
            )
            await self._preprocess_chat(
                probe_request,
                probe_request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=self.default_chat_template_kwargs,
            )
            elapsed_ms = (time.perf_counter() - started) * 1000
            logger.info("Chat template warmup completed in %.1fms", elapsed_ms)
        except Exception:
            # Best effort only: a failed warmup must not abort startup.
            logger.exception("Chat template warmup failed")

    async def render_chat_request(
        self,
        request: ChatCompletionRequest,
    ) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse:
        """
        Validate a chat request and render it into engine prompts.

        The requested model is checked first. Then ``tool_choice`` values the
        server cannot honour are rejected. Finally the conversation is
        rendered, through the chat template or, for GPT-OSS (harmony) models,
        through ``_make_request_with_harmony``.

        Returns:
            ``(conversation, engine_prompts)`` on success, otherwise an
            ``ErrorResponse``.

        Raises:
            ``self.engine_client.dead_error`` when the engine has errored.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            logger.error("Error with model %s", model_error)
            return model_error

        # A dead engine must surface its own error here: for streaming
        # requests a success status is sent before generation actually starts.
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        try:
            tokenizer = self.renderer.tokenizer
            tool_parser = self.tool_parser
            is_mistral = isinstance(tokenizer, MistralTokenizer)

            if is_mistral:
                # pydantic can mangle the tool_calls field, so re-serialize it
                # (see `maybe_serialize_tool_calls`). Then apply Mistral's
                # tool-call-id truncation and request parameter validation.
                maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
                truncate_tool_call_ids(request)  # type: ignore[arg-type]
                validate_request_params(request)

            can_parse_tools = (
                tool_parser is not None or is_mistral or self.use_harmony
            )
            if not can_parse_tools and request.tool_choice not in (None, "none"):
                if request.tool_choice != "auto":
                    # "required" or a named tool cannot work without a parser.
                    return self.create_error_response(
                        f'tool_choice="{request.tool_choice}" requires '
                        "--tool-call-parser to be set"
                    )
                if not self.enable_auto_tools:
                    # For HF tokenizers, "auto" needs both server flags.
                    return self.create_error_response(
                        '"auto" tool choice requires '
                        "--enable-auto-tool-choice and --tool-call-parser to be set"
                    )

            if request.tools is not None and not (
                request.tool_choice == "none"
                and self.exclude_tools_when_tool_choice_none
            ):
                tool_dicts = [tool.model_dump() for tool in request.tools]
            else:
                tool_dicts = None

            if self.use_harmony:
                # GPT-OSS prompts are rendered by harmony, not the chat template.
                conversation, engine_prompts = self._make_request_with_harmony(
                    request, tool_dicts is not None
                )
            else:
                template_error = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if template_error is not None:
                    return template_error

                conversation, engine_prompts = await self._preprocess_chat(
                    request,
                    request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=self.default_chat_template_kwargs,
                    tool_dicts=tool_dicts,
                    tool_parser=tool_parser,
                )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(e)

        return conversation, engine_prompts

    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse:
        """
        Chat Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        Chat Completion API.

        Args:
            request: The parsed chat completion request.
            raw_request: The underlying HTTP request, if available. Used for
                request metadata, trace headers and the data-parallel rank.

        Returns:
            An async generator of server-sent-event strings when
            ``request.stream`` is set, a ``ChatCompletionResponse`` otherwise,
            or an ``ErrorResponse`` when the request cannot be served.
        """
        tokenizer = self.renderer.tokenizer
        if tokenizer is None:
            # Explicit check instead of ``assert``: an assert is stripped under
            # ``python -O`` and otherwise escapes as an unhandled exception
            # rather than an error response.
            return self.create_error_response(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Set up the reasoning parser (if configured); it is shared by both
        # the streaming and the non-streaming paths below.
        reasoning_parser: ReasoningParser | None = None
        try:
            if self.reasoning_parser_cls:
                # Pass the same chat template kwargs as used in tokenization
                chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
                    request.chat_template_kwargs,
                    self.default_chat_template_kwargs,
                )
                reasoning_parser = self.reasoning_parser_cls(
                    tokenizer,
                    chat_template_kwargs=chat_template_kwargs,  # type: ignore[call-arg]
                )
        except RuntimeError as e:
            logger.exception("Error in reasoning parser creation.")
            return self.create_error_response(str(e))

        result = await self.render_chat_request(request)
        if isinstance(result, ErrorResponse):
            return result

        conversation, engine_prompts = result

        request_id = (
            f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
        )

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(
                request, supports_default_mm_loras=True
            )

            model_name = self.models.model_name(lora_request)
        except (ValueError, TypeError, RuntimeError) as e:
            logger.exception("Error preparing request components")
            return self.create_error_response(e)

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        max_model_len = self.model_config.max_model_len
        # max_completion_tokens takes precedence over the legacy max_tokens.
        requested_max_tokens = (
            request.max_completion_tokens
            if request.max_completion_tokens is not None
            else request.max_tokens
        )
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            # Trace headers depend only on the HTTP request, not on the
            # prompt, so resolve them once rather than once per engine prompt.
            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            for i, engine_prompt in enumerate(engine_prompts):
                prompt_text = self._extract_prompt_text(engine_prompt)

                # If we are creating sub requests for multiple prompts, ensure that they
                # have unique request ids.
                sub_request_id = (
                    request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
                )

                max_tokens = get_max_tokens(
                    max_model_len,
                    requested_max_tokens,
                    self._extract_prompt_len(engine_prompt),
                    self.default_sampling_params,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.default_sampling_params,
                    )

                self._log_inputs(
                    sub_request_id,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=sub_request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                    )
                else:
                    tok_params = request.build_tok_params(self.model_config)
                    tokenization_kwargs = tok_params.get_encode_kwargs()

                    engine_request = self.input_processor.process_inputs(
                        sub_request_id,
                        engine_prompt,
                        sampling_params,
                        lora_request=lora_request,
                        tokenization_kwargs=tokenization_kwargs,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        data_parallel_rank=data_parallel_rank,
                    )

                    # Record on the engine request whether the prompt tokens
                    # already mark the end of the reasoning section.
                    if reasoning_parser:
                        engine_request.reasoning_ended = (
                            reasoning_parser.is_reasoning_end(
                                engine_request.prompt_token_ids or []  # type: ignore[attr-defined]
                            )
                        )

                    generator = self.engine_client.generate(
                        engine_request,
                        sampling_params,
                        sub_request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        prompt_text=prompt_text,
                        tokenization_kwargs=tokenization_kwargs,
                        data_parallel_rank=data_parallel_rank,
                    )

                generators.append(generator)
        except ValueError as e:
            return self.create_error_response(e)

        assert len(generators) == 1
        (result_generator,) = generators

        if request.stream:
            return self.chat_completion_stream_generator(
                request,
                result_generator,
                request_id,
                model_name,
                conversation,
                tokenizer,
                request_metadata,
                reasoning_parser,
            )

        try:
            return await self.chat_completion_full_generator(
                request,
                result_generator,
                request_id,
                model_name,
                conversation,
                tokenizer,
                request_metadata,
                reasoning_parser,
            )
        except GenerationError as e:
            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            return self.create_error_response(e)

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
496
        return request.messages[-1]["role"]
497

498
    @staticmethod
499
    def _bracket_level(s: str, opening="{", closing="}") -> int:
500
501
502
503
504
505
506
507
508
509
510
511
        """
        Calculate the current level of nested brackets in a given string.
        """
        level = 0
        for char in s:
            if char == opening:
                level += 1
            elif char == closing:
                level -= 1
        return level

    @staticmethod
512
    def _filter_delta_text(delta_text: str, previous_text: str) -> tuple[str, bool]:
513
514
515
516
517
518
519
520
521
        # remove last '},' of the tool definition stemming from the
        # "name"/"parameters" outer object or closing ']' of the tool list
        # count occurrences of opening and closing curly braces and
        # once level 0 is reached stop outputting text
        # if 0 is reached while parsing the delta_text we know the current
        # tool will finish in this current iteration
        bracket_level = OpenAIServingChat._bracket_level(previous_text)
        updated_delta, passed_zero = "", False
        for c in delta_text:
522
            if c == "{":
523
524
                bracket_level += 1
                passed_zero = bracket_level == 0
525
            elif c == "}":
526
527
528
529
530
531
532
                bracket_level -= 1
                passed_zero = bracket_level == 0

            if bracket_level != 0:
                updated_delta += c
            else:
                # if a comma is reached at level 0 we can stop
533
                if c == ",":
534
535
536
537
538
539
                    break
        return updated_delta, passed_zero

    def extract_tool_call_required_streaming(
        self,
        previous_text: str,
540
        current_text: str | None,
541
542
        delta_text: str,
        function_name_returned: bool,
543
544
        tool_call_idx: int | None = None,
    ) -> tuple[DeltaMessage | None, bool]:
545
546
547
        if current_text is None or current_text == "":
            # if the current text is empty, we cannot parse it
            return None, function_name_returned
548
        try:
549
550
551
552
553
554
            flags = Allow.ALL
            obj, _ = partial_json_loads(current_text, flags)
        except (
            partial_json_parser.core.exceptions.MalformedJSON,
            json.JSONDecodeError,
        ):
555
            logger.debug("not enough tokens to parse into JSON yet")
556
557
558
559
560
561
562
563
564
565
            obj = None

        # check if the current text is a valid array
        # containing a partial tool calling object
        # if not repeat
        if obj is None or not isinstance(obj, list) or not len(obj) > 0:
            function_name_returned = False
            delta_message = None
        else:
            _, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
566
567
                delta_text, previous_text
            )
568
569
570
571
            # take the last tool call from the generated list
            current_tool_call = obj[-1]

            # once parameters have been generated the name is complete as well
572
573
574
            if not finishes_previous_tool and (
                "name" not in current_tool_call or "parameters" not in current_tool_call
            ):
575
576
577
578
579
                function_name_returned = False
                delta_message = None
            else:
                if not function_name_returned:
                    # get partly generated arguments from the latest tool call
580
581
582
                    param_match = re.search(
                        r'.*"parameters":\s*(.*)', current_text, re.DOTALL
                    )
583
584
                    arguments = param_match.group(1) if param_match else ""
                    arguments, _ = OpenAIServingChat._filter_delta_text(
585
586
                        arguments, previous_text
                    )
587
588
589
590

                    # if this iteration finishes a previous tool call but a
                    # new incomplete tool is already generated, take the
                    # previous from the list
591
                    if finishes_previous_tool and "parameters" not in current_tool_call:
592
593
594
                        current_tool_call = obj[-2]

                    function_name_returned = True
595
596
597
                    tool_call_id = make_tool_call_id(
                        id_type=self.tool_call_id_type,
                        func_name=current_tool_call["name"],
598
599
600
601
602
603
604
605
606
607
608
609
610
611
                        idx=tool_call_idx,
                    )
                    delta_message = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                id=tool_call_id,
                                function=DeltaFunctionCall(
                                    name=current_tool_call["name"], arguments=arguments
                                ),
                                index=len(obj) - 1,
                                type="function",
                            )
                        ]
                    )
612
613
614

                else:
                    delta_text, _ = OpenAIServingChat._filter_delta_text(
615
616
                        delta_text, previous_text
                    )
617
618

                    if delta_text != "":
619
620
621
622
623
624
625
626
627
628
629
630
631
                        delta_message = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    function=DeltaFunctionCall(
                                        # OpenAI API returns None
                                        # instead of name every time
                                        name=None,
                                        arguments=delta_text,
                                    ),
                                    index=len(obj) - 1,
                                )
                            ]
                        )
632
633
634
635
636
                    else:
                        delta_message = None

        return delta_message, function_name_returned

637
    async def chat_completion_stream_generator(
638
639
640
641
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
642
        model_name: str,
643
        conversation: list[ConversationMessage],
644
        tokenizer: TokenizerLike,
645
        request_metadata: RequestResponseMetadata,
646
        reasoning_parser: ReasoningParser | None = None,
647
    ) -> AsyncGenerator[str, None]:
648
649
        from vllm.tokenizers.mistral import MistralTokenizer

650
        created_time = int(time.time())
651
        chunk_object_type: Final = "chat.completion.chunk"
652
        first_iteration = True
653
654

        # Send response for each token for each request.n (index)
655
656
657
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
658
        num_prompt_tokens = 0
659
        num_cached_tokens = None
660
661
        if self.use_harmony:
            harmony_parsers = [
662
                get_streamable_parser_for_assistant() for _ in range(num_choices)
663
            ]
664
665
            harmony_tools_streamed = [False] * num_choices
        tools_streamed = [False] * num_choices
666
667
668
669
670
671
672
673
674

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
675
676
            and self._should_stream_with_auto_tool_parsing(request)
        )
677

678
        all_previous_token_ids: list[list[int]] | None
679
        function_name_returned = [False] * num_choices
680
        if self.tool_call_id_type == "kimi_k2":
681
682
683
            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
        else:
            history_tool_call_cnt = 0
684

685
686
687
        # Always track previous_texts for comprehensive output logging
        previous_texts = [""] * num_choices

688
689
        # Only one of these will be used, thus previous_texts and
        # all_previous_token_ids will not be used twice in the same iteration.
690
        if tool_choice_auto or reasoning_parser:
691
692
            # These are only required in "auto" tool choice case
            all_previous_token_ids = [[]] * num_choices
693
694
695
            # For reasoning parser and tool call all enabled
            added_content_delta_arr = [False] * num_choices
            reasoning_end_arr = [False] * num_choices
696
            prompt_is_reasoning_end_arr: list[bool | None] = [None] * num_choices
697
        else:
698
            all_previous_token_ids = None
699

700
701
702
        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
703
704
705
706
707
                if tokenizer is None:
                    raise ValueError(
                        "Tokenizer not available when `skip_tokenizer_init=True`"
                    )

708
                tool_parsers: list[ToolParser | None] = [
709
710
711
712
                    self.tool_parser(tokenizer)
                ] * num_choices
            else:
                tool_parsers = [None] * num_choices
713
        except Exception as e:
714
            logger.exception("Error in tool parser creation.")
715
            data = self.create_streaming_error_response(e)
716
717
718
719
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

720
        stream_options = request.stream_options
721
722
723
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )
724

725
726
        try:
            async for res in result_generator:
727
728
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
729
730
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)
731

732
733
734
735
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
736
                    num_cached_tokens = res.num_cached_tokens
737
738
                    # Send first response for each request.n (index) with
                    # the role
739
                    role = self.get_chat_request_role(request)
740
741
742

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
743
                    for i in range(num_choices):
744
745
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
746
747
748
749
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
750
                            logprobs=None,
751
752
                            finish_reason=None,
                        )
753
754

                        # return prompt_token_ids at the first chunk ever
755
756
757
758
759
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
760
                            model=model_name,
761
762
763
764
765
766
                            prompt_token_ids=(
                                res.prompt_token_ids
                                if request.return_token_ids
                                else None
                            ),
                        )
767

768
769
770
771
772
                        # if continuous usage stats are requested, add it
                        if include_continuous_usage:
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=0,
773
774
                                total_tokens=num_prompt_tokens,
                            )
775

776
777
778
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

779
780
                    # Send response to echo the input portion of the
                    # last message
781
                    if request.echo:
782
                        last_msg_content: str | list[dict[str, str]] = ""
783
784
785
786
787
                        if (
                            conversation
                            and "content" in conversation[-1]
                            and conversation[-1].get("role") == role
                        ):
788
                            last_msg_content = conversation[-1]["content"] or ""
789
790

                        if last_msg_content:
791
                            for i in range(num_choices):
792
793
794
795
796
797
                                choice_data = ChatCompletionResponseStreamChoice(
                                    index=i,
                                    delta=DeltaMessage(content=last_msg_content),
                                    logprobs=None,
                                    finish_reason=None,
                                )
798
799
800
801
802
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
803
804
                                    model=model_name,
                                )
805
806
807
808
                                if include_continuous_usage:
                                    chunk.usage = UsageInfo(
                                        prompt_tokens=num_prompt_tokens,
                                        completion_tokens=0,
809
810
                                        total_tokens=num_prompt_tokens,
                                    )
811

812
                                data = chunk.model_dump_json(exclude_unset=True)
813
814
815
816
817
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index
818
                    tool_parser = tool_parsers[i]
819

820
                    if (
821
                        reasoning_parser
822
823
824
825
826
827
828
829
                        and res.prompt_token_ids
                        and prompt_is_reasoning_end_arr[i] is None
                    ):
                        # only check once per choice, because prompt_token_ids
                        # are the same for all deltas in that choice
                        prompt_is_reasoning_end_arr[i] = (
                            reasoning_parser.is_reasoning_end(res.prompt_token_ids)
                        )
830
831
832
                    if finish_reason_sent[i]:
                        continue

833
                    if request.logprobs and request.top_logprobs is not None:
834
                        assert output.logprobs is not None, "Did not output logprobs"
835
                        logprobs = self._create_chat_logprobs(
836
837
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
838
                            tokenizer=tokenizer,
839
                            num_output_top_logprobs=request.top_logprobs,
840
                            return_as_token_id=request.return_tokens_as_token_ids,
841
842
843
844
                        )
                    else:
                        logprobs = None

845
846
                    if self.use_harmony:
                        harmony_parser = harmony_parsers[i]
847
                        prev_recipient = harmony_parser.current_recipient
848
849
850

                        # Track accumulated content per token with their state
                        token_states: list[TokenState] = []
851
852
                        for token_id in output.token_ids:
                            harmony_parser.process(token_id)
853
854
855
856
857
858
859
860
861
                            token_delta = harmony_parser.last_content_delta or ""
                            token_states.append(
                                TokenState(
                                    harmony_parser.current_channel,
                                    harmony_parser.current_recipient,
                                    token_delta,
                                )
                            )
                        delta_text = "".join(delta for _, _, delta in token_states)
862
                        cur_channel = harmony_parser.current_channel
863

864
865
866
867
868
                        # handle the case where several tokens where generated at once
                        # including the final token, leading to a delta in the text
                        # but the current channel to be empty (start state)
                        if not cur_channel and delta_text:
                            cur_channel = "final"
869
870
                    else:
                        delta_text = output.text
871

872
873
874
875
876
                    if (
                        not delta_text
                        and not output.token_ids
                        and not previous_num_tokens[i]
                    ):
877
878
879
                        # Chunked prefill case, don't return empty chunks
                        continue

880
                    delta_message: DeltaMessage | None
881

882
                    # just update previous_texts and previous_token_ids
883
                    if tool_choice_auto or reasoning_parser:
884
885
886
887
888
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
889
890
                        # avoid the None + list error.
                        if previous_token_ids:
891
                            current_token_ids = previous_token_ids + as_list(
892
893
                                output.token_ids
                            )
894
                        else:
895
                            current_token_ids = as_list(output.token_ids)
896

897
                    if self.use_harmony:
898
899
900
                        delta_message, tools_streamed_flag = (
                            extract_harmony_streaming_delta(
                                harmony_parser=harmony_parser,
901
                                token_states=token_states,
902
903
904
905
906
                                prev_recipient=prev_recipient,
                                include_reasoning=request.include_reasoning,
                            )
                        )
                        harmony_tools_streamed[i] |= tools_streamed_flag
907
                    # handle streaming deltas for tools with named tool_choice
908
                    elif tool_choice_function_name:
909
                        if (
910
                            reasoning_parser
911
912
913
914
915
                            and not reasoning_end_arr[i]
                            and not reasoning_parser.is_reasoning_end(
                                previous_token_ids
                            )
                        ):
916
917
                            assert reasoning_parser is not None
                            delta_message = (
918
                                reasoning_parser.extract_reasoning_streaming(
919
920
921
922
923
924
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output.token_ids,
925
926
                                )
                            )
927
928
929
930
                            # When encountering think end id in delta_token_ids
                            # or think end id in prompt_token_ids
                            # i.e {"enable_thinking": False},
                            # set reasoning status to end.
931
                            # Only keep 'content', remove 'reasoning'.
932
933
934
                            if (
                                reasoning_parser.is_reasoning_end(
                                    as_list(output.token_ids)
935
                                )
936
                                or prompt_is_reasoning_end_arr[i]
937
                            ):
938
                                reasoning_end_arr[i] = True
939
940
941
942
943
944
945
946
                                if delta_message and delta_message.content:
                                    # This need to be added to next `delta_text`
                                    current_text = delta_message.content
                                    delta_message.content = None
                                else:
                                    current_text = ""
                        else:
                            # Just to add remaining `content`
947
                            if reasoning_parser:
948
949
950
                                delta_text = previous_text + delta_text
                                current_text = ""

951
952
                            if function_name_returned[i]:
                                delta_tool_call = DeltaToolCall(
953
954
955
                                    function=DeltaFunctionCall(arguments=delta_text),
                                    index=i,
                                )
956
                            else:
957
958
959
960
961
962
963
964
965
                                # Generate ID based on tokenizer type
                                if isinstance(tokenizer, MistralTokenizer):
                                    tool_call_id = MistralToolCall.generate_random_id()
                                else:
                                    tool_call_id = make_tool_call_id(
                                        id_type=self.tool_call_id_type,
                                        func_name=tool_choice_function_name,
                                        idx=history_tool_call_cnt,
                                    )
966
                                delta_tool_call = DeltaToolCall(
967
                                    id=tool_call_id,
968
969
970
                                    type="function",
                                    function=DeltaFunctionCall(
                                        name=tool_choice_function_name,
971
972
973
974
                                        arguments=delta_text,
                                    ),
                                    index=i,
                                )
975
                                function_name_returned[i] = True
976
                                history_tool_call_cnt += 1
977

978
979
980
981
982
                            delta_message = DeltaMessage(
                                tool_calls=[
                                    delta_tool_call,
                                ]
                            )
983
                            tools_streamed[i] = True
984

985
986
987
988
989
                    elif request.tool_choice == "required":
                        assert previous_texts is not None
                        previous_text = previous_texts[i]
                        current_text = previous_text + delta_text
                        fn_name_returned = function_name_returned[i]
990
991
992
                        output_token_ids = as_list(output.token_ids)

                        if (
993
                            reasoning_parser is not None
994
                            and not reasoning_end_arr[i]
995
                            and prompt_is_reasoning_end_arr[i]
996
997
                        ):
                            reasoning_end_arr[i] = True
998

999
                        if reasoning_parser and not reasoning_end_arr[i]:
1000
                            delta_message = (
1001
                                reasoning_parser.extract_reasoning_streaming(
1002
1003
1004
1005
1006
1007
1008
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output_token_ids,
                                )
1009
                            )
1010
1011
1012
1013
1014
1015
1016
1017
1018
                            if reasoning_parser.is_reasoning_end(output_token_ids):
                                reasoning_end_arr[i] = True
                                if delta_message and delta_message.content:
                                    current_text = delta_message.content
                                    delta_message.content = None
                                else:
                                    # reasoning ended
                                    current_text = ""

1019
                        else:
1020
                            # either finished reasoning or no reasoning at all
1021
                            content = current_text
1022
1023
1024
1025
1026
1027
1028
1029
1030

                            delta_message, function_name_returned[i] = (
                                self.extract_tool_call_required_streaming(
                                    previous_text=previous_text,
                                    current_text=content,
                                    delta_text=delta_text,
                                    function_name_returned=fn_name_returned,
                                    tool_call_idx=history_tool_call_cnt,
                                )
1031
                            )
1032
1033
1034
1035
1036
1037
1038
                            if (
                                delta_message
                                and delta_message.tool_calls
                                and delta_message.tool_calls[0].id is not None
                            ):
                                history_tool_call_cnt += 1
                                tools_streamed[i] = True
1039

1040
1041
                    # handle streaming deltas for tools with "auto" tool choice
                    # and reasoning parser
1042
                    elif tool_choice_auto and reasoning_parser:
1043
1044
1045
                        assert tool_parser is not None
                        assert added_content_delta_arr is not None
                        assert reasoning_end_arr is not None
1046
                        output_token_ids = as_list(output.token_ids)
1047
                        if not reasoning_end_arr[i]:
1048
1049
1050
                            # When encountering think end id in prompt_token_ids
                            # i.e {"enable_thinking": False},
                            # set reasoning status to end.
1051
                            if prompt_is_reasoning_end_arr[i]:
1052
                                reasoning_end_arr[i] = True
1053
                                current_token_ids = output_token_ids
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
                                # Don't update current_text, keep it as is from delta
                            else:
                                delta_message = (
                                    reasoning_parser.extract_reasoning_streaming(
                                        previous_text,
                                        current_text,
                                        delta_text,
                                        previous_token_ids,
                                        current_token_ids,
                                        output_token_ids,
1064
1065
                                    )
                                )
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082

                                # When encountering think end id in delta_token_ids,
                                # set reasoning status to end.
                                # Remove the text and token ids related
                                # to 'reasoning'.
                                if reasoning_parser.is_reasoning_end(output_token_ids):
                                    reasoning_end_arr[i] = True
                                    current_token_ids = (
                                        reasoning_parser.extract_content_ids(
                                            output_token_ids
                                        )
                                    )
                                    if delta_message and delta_message.content:
                                        current_text = delta_message.content
                                        delta_message.content = None
                                    else:
                                        current_text = ""
1083
1084

                        # handle tool calls only after reasoning is done,
1085
                        if reasoning_end_arr[i]:
1086
                            delta_token_ids = output_token_ids
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
                            # First time to tool call,
                            # add the remaining text and token ids
                            # to delta from previous
                            if not added_content_delta_arr[i]:
                                added_content_delta_arr[i] = True
                                previous_text = ""
                                previous_token_ids = []
                                delta_text = current_text
                                delta_token_ids = current_token_ids

1097
                            delta_message = tool_parser.extract_tool_calls_streaming(
1098
1099
                                previous_text=previous_text,
                                current_text=current_text,
1100
                                delta_text=delta_text,
1101
1102
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
                                delta_token_ids=delta_token_ids,
                                request=request,
                            )
                            if delta_message and delta_message.tool_calls:
                                tools_streamed[i] = True
                    # when only tool calls
                    elif tool_choice_auto:
                        assert tool_parser is not None
                        delta_message = tool_parser.extract_tool_calls_streaming(
                            previous_text=previous_text,
                            current_text=current_text,
                            delta_text=delta_text,
                            previous_token_ids=previous_token_ids,
                            current_token_ids=current_token_ids,
                            delta_token_ids=output.token_ids,
                            request=request,
                        )
1120
1121
                        if delta_message and delta_message.tool_calls:
                            tools_streamed[i] = True
1122

1123
                    # when only reasoning
1124
                    elif reasoning_parser:
1125
1126
1127
1128
1129
1130
1131
                        delta_message = reasoning_parser.extract_reasoning_streaming(
                            previous_text,
                            current_text,
                            delta_text,
                            previous_token_ids,
                            current_token_ids,
                            output.token_ids,
1132
                        )
1133
                    # handle streaming just a content delta
1134
1135
1136
                    else:
                        delta_message = DeltaMessage(content=delta_text)

1137
                    # update the previous values for the next iteration
1138
                    if (tool_choice_auto or reasoning_parser) and not self.use_harmony:
1139
1140
1141
1142
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids
1143
1144
1145
1146
                    else:
                        # Update for comprehensive logging even in simple case
                        assert previous_texts is not None
                        previous_texts[i] += delta_text
1147

1148
                    # set the previous values for the next iteration
1149
                    previous_num_tokens[i] += len(output.token_ids)
1150
1151
1152
1153
1154
1155

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token, then
                    #   get the next token without streaming a chunk
                    if delta_message is None:
1156
1157
1158
1159
1160
1161
1162
                        # NOTE: If return_token_ids is enabled, we still need to
                        # send a chunk with token_ids even if delta_message is None
                        # to ensure all tokens are included in the response
                        if (
                            output.finish_reason is None
                            and not request.return_token_ids
                        ):
1163
                            continue
1164
                        delta_message = DeltaMessage()
1165

1166
1167
                    # Log streaming delta if output logging is enabled
                    if self.enable_log_outputs and self.request_logger:
1168
                        delta_content_parts = []
1169
                        if delta_message.content:
1170
                            delta_content_parts.append(delta_message.content)
1171
1172
                        if delta_message.reasoning:
                            reasoning = delta_message.reasoning
1173
1174
1175
                            delta_content_parts.append(f"[reasoning: {reasoning}]")
                        if delta_message.tool_calls:
                            tool_args = "".join(
1176
1177
                                tc.function.arguments
                                for tc in delta_message.tool_calls
1178
1179
                                if tc.function and tc.function.arguments
                            )
1180
1181
                            if tool_args:
                                delta_content_parts.append(f"[tool_calls: {tool_args}]")
1182

1183
1184
                        if delta_content_parts and self.enable_log_deltas:
                            delta_content = " ".join(delta_content_parts)
1185
1186
1187
                            self.request_logger.log_outputs(
                                request_id=request_id,
                                outputs=delta_content,
1188
                                output_token_ids=as_list(output.token_ids),
1189
1190
1191
1192
1193
                                finish_reason=output.finish_reason,
                                is_streaming=True,
                                delta=True,
                            )

1194
1195
1196
1197
                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
1198
                            delta=delta_message,
1199
                            logprobs=logprobs,
1200
                            finish_reason=None,
1201
1202
1203
1204
1205
1206
                            token_ids=(
                                as_list(output.token_ids)
                                if request.return_token_ids
                                else None
                            ),
                        )
1207
1208

                    # if the model is finished generating
1209
                    else:
1210
1211
1212
1213
                        # check for error finish reason and abort streaming
                        # finish_reason='error' indicates a retryable error
                        self._raise_if_error(output.finish_reason, request_id)

1214
1215
1216
                        # check to make sure we haven't "forgotten" to stream
                        #   any tokens that were generated but previously
                        #   matched by partial json parsing
1217
                        # only happens if we are NOT using structured outputs
1218
                        auto_tools_called = False
1219
                        if tool_parser:
1220
1221
1222
1223
1224
1225
                            auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0
                            index = (
                                len(tool_parser.prev_tool_call_arr) - 1
                                if auto_tools_called
                                else 0
                            )
1226
1227
1228
                        else:
                            index = 0

1229
1230
1231
1232
1233
1234
                        if (
                            self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output
                            )
                            and tool_parser
                        ):
1235
                            latest_delta_len = 0
1236
1237
                            if (
                                isinstance(
1238
                                    delta_message.tool_calls[0].function,
1239
1240
1241
1242
1243
                                    DeltaFunctionCall,
                                )
                            ) and isinstance(
                                delta_message.tool_calls[0].function.arguments, str
                            ):
1244
                                latest_delta_len = len(
1245
1246
                                    delta_message.tool_calls[0].function.arguments
                                )
1247

1248
1249
1250
1251
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
1252
1253
1254
1255
                                    "arguments", {}
                                ),
                                ensure_ascii=False,
                            )
1256

1257
                            # get what we've streamed so far for arguments
1258
                            # for the current tool
1259
1260
                            actual_call = tool_parser.streamed_args_for_tool[index]
                            if latest_delta_len > 0:
1261
                                actual_call = actual_call[:-latest_delta_len]
1262
1263

                            # check to see if there's anything left to stream
1264
                            remaining_call = expected_call.replace(actual_call, "", 1)
1265
                            # set that as a delta message
1266
1267
                            delta_message = self._create_remaining_args_delta(
                                delta_message, remaining_call, index
1268
                            )
1269

1270
                        # Send the finish response for each request.n only once
1271
1272
1273
1274
                        # In OpenAI's API, when a tool is called, the
                        # finish_reason is:
                        # "tool_calls" for "auto" or "required" tool calls,
                        # and "stop" for named tool calls.
1275
1276
                        if (
                            auto_tools_called
1277
                            or (tools_streamed[i] and not tool_choice_function_name)
1278
1279
                            or (self.use_harmony and harmony_tools_streamed[i])
                        ):
1280
1281
                            finish_reason_ = "tool_calls"
                        else:
1282
1283
1284
                            finish_reason_ = (
                                output.finish_reason if output.finish_reason else "stop"
                            )
1285
1286
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
1287
                            delta=delta_message,
1288
                            logprobs=logprobs,
1289
                            finish_reason=finish_reason_,
1290
                            stop_reason=output.stop_reason,
1291
1292
1293
1294
1295
1296
                            token_ids=(
                                as_list(output.token_ids)
                                if request.return_token_ids
                                else None
                            ),
                        )
1297

1298
                        finish_reason_sent[i] = True
1299

1300
                    choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
1301
1302
1303
1304
1305
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
1306
1307
                        model=model_name,
                    )
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )

1318
                    data = chunk.model_dump_json(exclude_unset=True)
1319
1320
                    yield f"data: {data}\n\n"

1321
1322
            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
1323
1324
            if include_usage:
                completion_tokens = sum(previous_num_tokens)
1325
1326
1327
1328
1329
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )
1330
1331
                if self.enable_prompt_tokens_details and num_cached_tokens:
                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
1332
1333
                        cached_tokens=num_cached_tokens
                    )
1334
1335
1336
1337
1338
1339
1340

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
1341
1342
1343
1344
1345
                    usage=final_usage,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True
                )
1346
                yield f"data: {final_usage_data}\n\n"
1347

1348
1349
1350
1351
1352
            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
1353
1354
1355
1356
1357
1358
1359
1360
1361
                total_tokens=num_prompt_tokens + num_completion_tokens,
            )

            # Log complete streaming response if output logging is enabled
            if self.enable_log_outputs and self.request_logger:
                # Log the complete response for each choice
                for i in range(num_choices):
                    full_text = (
                        previous_texts[i]
1362
1363
                        if previous_texts and i < len(previous_texts)
                        else f"<streaming_complete: {previous_num_tokens[i]} tokens>"
1364
1365
1366
1367
                    )
                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs=full_text,
1368
                        output_token_ids=None,  # Consider also logging all token IDs
1369
1370
1371
1372
                        finish_reason="streaming_complete",
                        is_streaming=True,
                        delta=False,
                    )
1373

1374
1375
        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
1376
        except Exception as e:
1377
            logger.exception("Error in chat completion stream generator.")
1378
            data = self.create_streaming_error_response(e)
1379
            yield f"data: {data}\n\n"
1380
1381
1382
1383
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: list[ConversationMessage],
        tokenizer: TokenizerLike,
        request_metadata: RequestResponseMetadata,
        reasoning_parser: ReasoningParser | None = None,
    ) -> ErrorResponse | ChatCompletionResponse:
        """Drain ``result_generator`` and build a non-streaming chat response.

        Only the last ``RequestOutput`` yielded by the engine is used. For each
        output choice this extracts reasoning, content and tool calls (through
        ``parse_chat_output`` when ``self.use_harmony`` is set, otherwise through
        the optional ``reasoning_parser`` and ``_parse_tool_calls_from_content``),
        derives the OpenAI-compatible ``finish_reason``, and attaches usage.

        Args:
            request: The originating chat completion request.
            result_generator: Engine outputs for this request.
            request_id: ID used for the response and for output logging.
            model_name: Model name reported in the response.
            conversation: Parsed conversation; used to seed kimi_k2 tool call
                numbering and as the echo prefix when ``request.echo`` is set.
            tokenizer: Used for logprob decoding and tool parsing.
            request_metadata: Receives ``final_usage_info``.
            reasoning_parser: Optional parser that splits reasoning from content.

        Returns:
            A ``ChatCompletionResponse``, or an ``ErrorResponse`` when the client
            disconnected or the generator raised ``ValueError``.

        Raises:
            GenerationError: From ``self._raise_if_error`` when an output's
                finish_reason signals an error.
            ValueError: When a Harmony tool parser is configured but no
                tokenizer is available.
        """
        from vllm.tokenizers.mistral import MistralTokenizer

        created_time = int(time.time())
        final_res: RequestOutput | None = None

        try:
            # Only the last output yielded is used to build the response.
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(e)

        assert final_res is not None

        choices: list[ChatCompletionResponseChoice] = []
        # kimi_k2 tool call IDs carry a running index that continues after the
        # tool calls already present in the conversation history.
        if self.tool_call_id_type == "kimi_k2":
            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
        else:
            history_tool_call_cnt = 0

        # The tokenizer is fixed for the whole request, so resolve the
        # Mistral-specific tool call class once rather than per choice.
        is_mistral = isinstance(tokenizer, MistralTokenizer)
        tool_call_class = MistralToolCall if is_mistral else ToolCall

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            # check for error finish reason and raise GenerationError
            # finish_reason='error' indicates a retryable request-level internal error
            self._raise_if_error(output.finish_reason, request_id)
            token_ids = output.token_ids
            out_logprobs = output.logprobs
            tool_call_info = None

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                    return_as_token_id=request.return_tokens_as_token_ids,
                )
            else:
                logprobs = None

            if self.use_harmony:
                # Harmony outputs are parsed from token IDs and fully handled
                # in this branch; the text-based parsing below never sees them.
                reasoning, content, _ = parse_chat_output(token_ids)
                if not request.include_reasoning:
                    reasoning = None

                if self.tool_parser is not None:
                    if tokenizer is None:
                        raise ValueError(
                            "Tokenizer not available when `skip_tokenizer_init=True`"
                        )

                    tool_parser = self.tool_parser(tokenizer)
                    # NOTE: We use token_ids for openai tool parser
                    tool_call_info = tool_parser.extract_tool_calls(
                        "",
                        request=request,
                        token_ids=token_ids,  # type: ignore
                    )
                    content = tool_call_info.content
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                        tool_calls=tool_call_info.tool_calls,
                    )
                else:
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                    )

                choice_data = ChatCompletionResponseChoice(
                    index=output.index,
                    message=message,
                    logprobs=logprobs,
                    finish_reason=(
                        "tool_calls"
                        if (tool_call_info is not None and tool_call_info.tools_called)
                        else output.finish_reason
                        if output.finish_reason
                        else "stop"
                    ),
                    stop_reason=output.stop_reason,
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)
                continue

            if reasoning_parser:
                # If the reasoning parser is enabled,
                # tool calls are extracted exclusively from the content.
                reasoning, content = reasoning_parser.extract_reasoning(
                    output.text, request=request
                )
                if not request.include_reasoning:
                    reasoning = None
            else:
                reasoning = None
                content = output.text

            auto_tools_called = False
            tool_calls, content = self._parse_tool_calls_from_content(
                request=request,
                tokenizer=tokenizer,
                content=content,
                enable_auto_tools=self.enable_auto_tools,
                tool_parser_cls=self.tool_parser,
            )

            # if auto tools are not enabled, and a named tool choice using
            #   outlines is not being used
            if (not self.enable_auto_tools or not self.tool_parser) and (
                not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
                and request.tool_choice != "required"
            ):
                message = ChatMessage(role=role, reasoning=reasoning, content=content)

            # named tool choice: the parsed calls replace the text content
            elif isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
                assert tool_calls is not None and len(tool_calls) > 0
                tool_call_class_items, history_tool_call_cnt = (
                    self._build_tool_call_items(
                        tool_calls, tool_call_class, is_mistral, history_tool_call_cnt
                    )
                )
                message = ChatMessage(
                    role=role,
                    reasoning=reasoning,
                    content="",
                    tool_calls=tool_call_class_items,
                )

            elif request.tool_choice == "required":
                assert tool_calls is not None and len(tool_calls) > 0
                tool_call_class_items, history_tool_call_cnt = (
                    self._build_tool_call_items(
                        tool_calls, tool_call_class, is_mistral, history_tool_call_cnt
                    )
                )
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=tool_call_class_items,
                    reasoning=reasoning,
                )

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, reasoning=reasoning, content=content)

            # handle when there are tools and tool choice is auto
            elif (
                request.tools
                and (request.tool_choice == "auto" or request.tool_choice is None)
                and self.enable_auto_tools
                and self.tool_parser
            ):
                # In the OpenAI API the finish_reason is "tools_called"
                # if the tool choice is auto and the model produced a tool
                # call. The same is not true for named function calls
                auto_tools_called = tool_calls is not None and len(tool_calls) > 0
                if tool_calls:
                    tool_call_items, history_tool_call_cnt = (
                        self._build_tool_call_items(
                            tool_calls,
                            tool_call_class,
                            is_mistral,
                            history_tool_call_cnt,
                        )
                    )
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                        tool_calls=tool_call_items,
                    )
                else:
                    # No tool call detected: return the content produced by the
                    # tool parser (which may have modified it) as a plain
                    # chat message.
                    message = ChatMessage(
                        role=role,
                        reasoning=reasoning,
                        content=content,
                    )

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion."
                )
                message = ChatMessage(role=role, reasoning=reasoning, content=content)

            # In OpenAI's API, when a tool is called, the finish_reason is:
            # "tool_calls" for "auto" or "required" tool calls,
            # and "stop" for named tool calls.
            is_finish_reason_tool_calls = auto_tools_called or (
                request.tool_choice
                and request.tool_choice == "required"
                and output.finish_reason == "stop"
            )

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls"
                if is_finish_reason_tool_calls
                else output.finish_reason
                if output.finish_reason
                else "stop",
                stop_reason=output.stop_reason,
                token_ids=(
                    as_list(output.token_ids) if request.return_token_ids else None
                ),
            )
            choice_data = maybe_filter_parallel_tool_calls(choice_data, request)

            choices.append(choice_data)

        if request.echo:
            last_msg_content: str | list[dict[str, str]] = ""
            if (
                conversation
                and "content" in conversation[-1]
                and conversation[-1].get("role") == role
            ):
                last_msg_content = conversation[-1]["content"] or ""
            if isinstance(last_msg_content, list):
                # Only parts that carry a "text" key contribute to the echo;
                # other parts are skipped instead of raising KeyError.
                last_msg_content = "\n".join(
                    part["text"] for part in last_msg_content if "text" in part
                )

            for choice in choices:
                full_message = last_msg_content + (choice.message.content or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs
        )
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
            prompt_token_ids=(
                final_res.prompt_token_ids if request.return_token_ids else None
            ),
            kv_transfer_params=final_res.kv_transfer_params,
        )

        # Log complete response if output logging is enabled
        if self.enable_log_outputs and self.request_logger:
            for choice in choices:
                output_text = ""
                if choice.message.content:
                    output_text = choice.message.content
                elif choice.message.tool_calls:
                    # For tool calls, log the function name and arguments
                    tool_call_descriptions = []
                    for tc in choice.message.tool_calls:  # type: ignore
                        function_call: FunctionCall = tc.function  # type: ignore
                        tool_call_descriptions.append(
                            f"{function_call.name}({function_call.arguments})"
                        )
                    tool_calls_str = ", ".join(tool_call_descriptions)
                    output_text = f"[tool_calls: {tool_calls_str}]"

                if output_text:
                    # Get the corresponding output token IDs
                    output_token_ids = None
                    if choice.index < len(final_res.outputs):
                        output_token_ids = final_res.outputs[choice.index].token_ids

                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs=output_text,
                        output_token_ids=output_token_ids,
                        finish_reason=choice.finish_reason,
                        is_streaming=False,
                        delta=False,
                    )

        return response

    def _build_tool_call_items(
        self,
        tool_calls: GenericSequence[FunctionCall],
        tool_call_class: type[ToolCall],
        is_mistral: bool,
        history_tool_call_cnt: int,
    ) -> tuple[list[ToolCall], int]:
        """Wrap parsed function calls into response tool call objects.

        A call keeps its parser-provided ID when it has one (e.g. Kimi K2).
        Otherwise, for Mistral tokenizers the ID is left to ``tool_call_class``
        (to preserve 9-char IDs); for everything else an ID is generated with
        ``make_tool_call_id`` using ``self.tool_call_id_type`` and the running
        ``history_tool_call_cnt``.

        Returns:
            The wrapped tool calls and the updated running count, incremented
            once per call whether or not an ID was generated.
        """
        items: list[ToolCall] = []
        for tc in tool_calls:
            if tc.id:
                items.append(tool_call_class(id=tc.id, function=tc))
            elif is_mistral:
                items.append(tool_call_class(function=tc))
            else:
                generated_id = make_tool_call_id(
                    id_type=self.tool_call_id_type,
                    func_name=tc.name,
                    idx=history_tool_call_cnt,
                )
                items.append(tool_call_class(id=generated_id, function=tc))
            history_tool_call_cnt += 1
        return items, history_tool_call_cnt

    def _get_top_logprobs(
        self,
        logprobs: dict[int, Logprob],
        top_logprobs: int | None,
        tokenizer: TokenizerLike | None,
        should_return_as_token_id: bool,
    ) -> list[ChatCompletionLogProb]:
        """Build the ``top_logprobs`` candidate list for one sampled position.

        ``top_logprobs == -1`` keeps every entry of ``logprobs``. A positive
        value keeps that many leading entries in the dict's iteration order.
        ``None``, ``0`` and any other non-positive value keep nothing. Each
        logprob is floored at ``-9999.0``. ``bytes`` is the UTF-8 encoding
        (with replacement) of the decoded token string.
        """
        if top_logprobs == -1:
            limit = len(logprobs)
        elif top_logprobs and top_logprobs > 0:
            limit = top_logprobs
        else:
            limit = 0

        candidates: list[ChatCompletionLogProb] = []
        for position, (token_id, entry) in enumerate(logprobs.items()):
            if position >= limit:
                break
            text = self._get_decoded_token(
                entry,
                token_id,
                tokenizer,
                return_as_token_id=should_return_as_token_id,
            )
            candidates.append(
                ChatCompletionLogProb(
                    token=text,
                    logprob=max(entry.logprob, -9999.0),
                    bytes=list(text.encode("utf-8", errors="replace")),
                )
            )
        return candidates

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        tokenizer: TokenizerLike | None,
        num_output_top_logprobs: int | None = None,
        return_as_token_id: bool | None = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs for a sequence of sampled tokens.

        ``top_logprobs[i]`` holds the candidate logprobs for position ``i``.
        When that entry is ``None`` or lacks the sampled token, the position is
        emitted with only its token string and bytes (``logprob`` and
        ``top_logprobs`` are not set).

        Args:
            token_ids: Sampled token IDs, one per position.
            top_logprobs: Per-position candidate logprobs, or ``None``.
            tokenizer: Used to decode tokens missing from ``top_logprobs``.
            num_output_top_logprobs: Forwarded to ``_get_top_logprobs``.
            return_as_token_id: Render tokens as ``token_id:<id>``; when
                ``None``, ``self.return_tokens_as_token_ids`` is used.

        Raises:
            ValueError: If a token must be decoded but ``tokenizer`` is None.
        """
        as_token_id = (
            self.return_tokens_as_token_ids
            if return_as_token_id is None
            else return_as_token_id
        )

        entries: list[ChatCompletionLogProbsContent] = []
        for position, token_id in enumerate(token_ids):
            candidates = top_logprobs[position]
            sampled = None if candidates is None else candidates.get(token_id)

            if sampled is None:
                # No logprob recorded for the sampled token: emit the token
                # text and its bytes only.
                if as_token_id:
                    text = f"token_id:{token_id}"
                elif tokenizer is None:
                    raise ValueError(
                        "Unable to get tokenizer because `skip_tokenizer_init=True`"
                    )
                else:
                    text = tokenizer.decode(token_id)
                entries.append(
                    ChatCompletionLogProbsContent(
                        token=text,
                        bytes=list(text.encode("utf-8", errors="replace")),
                    )
                )
                continue

            decoded = sampled.decoded_token
            entries.append(
                ChatCompletionLogProbsContent(
                    token=self._get_decoded_token(
                        sampled,
                        token_id,
                        tokenizer,
                        return_as_token_id=as_token_id,
                    ),
                    logprob=max(sampled.logprob, -9999.0),
                    bytes=(
                        list(decoded.encode("utf-8", errors="replace"))
                        if decoded is not None
                        else None
                    ),
                    top_logprobs=self._get_top_logprobs(
                        candidates,
                        num_output_top_logprobs,
                        tokenizer,
                        as_token_id,
                    ),
                )
            )

        return ChatCompletionLogProbs(content=entries)

    def _should_stream_with_auto_tool_parsing(self, request: ChatCompletionRequest):
1867
1868
1869
1870
1871
1872
1873
1874
        """
        Utility function to check if streamed tokens should go through the tool
        call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used.
        """
1875
1876
1877
1878
1879
1880
        return (
            request.tools
            and self.tool_parser
            and self.enable_auto_tools
            and request.tool_choice in ["auto", None]
        )
1881
1882
1883

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
1884
        delta_message: DeltaMessage | None,
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
        output: CompletionOutput,
    ) -> bool:
        """
        Check to see if we should check for unstreamed tool arguments tokens.
        This is only applicable when auto tool parsing is enabled, the delta
        is a tool call with arguments.
        """

        return bool(
            # if there is a delta message that includes tool calls which
            # include a function that has arguments
1896
            output.finish_reason is not None
1897
1898
1899
1900
1901
            and self.enable_auto_tools
            and self.tool_parser
            and delta_message
            and delta_message.tool_calls
            and delta_message.tool_calls[0]
1902
1903
1904
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None
        )
1905

1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
    @staticmethod
    def _create_remaining_args_delta(
        delta_message: DeltaMessage,
        remaining_call: str,
        index: int,
    ) -> DeltaMessage:
        """
        Build a delta carrying only the not-yet-streamed tool arguments.

        The id, type and function name are copied from the tool call in
        ``delta_message`` whose ``index`` matches (first match wins). When no
        call matches, those fields are ``None``.
        """
        matched = None
        for candidate in delta_message.tool_calls:
            if candidate.index == index:
                matched = candidate
                break

        if matched:
            call_id = matched.id
            call_type = matched.type
            fn_name = matched.function.name if matched.function else None
        else:
            call_id = call_type = fn_name = None

        remaining_fn = DeltaFunctionCall(name=fn_name, arguments=remaining_call)
        remaining_tc = DeltaToolCall(
            index=index,
            id=call_id,
            type=call_type,
            function=remaining_fn,
        )
        return DeltaMessage(tool_calls=[remaining_tc])

    def _make_request_with_harmony(
        self,
        request: ChatCompletionRequest,
        should_include_tools: bool = True,
    ):
        """Render a chat request into Harmony messages and an engine prompt.

        The message list is, in order: the system message, a developer message
        (only when ``request.tools`` is set; it lists the tools only if
        ``should_include_tools``), then the request's own messages converted
        to Harmony format.

        Returns:
            ``(messages, [engine_prompt])``, where ``engine_prompt`` is a
            ``TokensPrompt`` over the rendered token IDs, with ``cache_salt``
            copied from the request when provided.
        """
        # because of issues with pydantic we need to potentially
        # re-serialize the tool_calls field of the request
        # for more info: see comment in `maybe_serialize_tool_calls`
        maybe_serialize_tool_calls(request)  # type: ignore[arg-type]

        # NOTE: In Chat Completion API, browsing is enabled by default
        # if the model supports it. TODO: Support browsing.
        assert not self.supports_browsing
        assert not self.supports_code_interpreter

        harmony_messages: list[OpenAIMessage] = [
            get_system_message(
                reasoning_effort=request.reasoning_effort,
                browser_description=None,
                python_description=None,
                with_custom_tools=should_include_tools,
            )
        ]

        if request.tools:
            developer_tools = request.tools if should_include_tools else None
            harmony_messages.append(
                get_developer_message(tools=developer_tools)  # type: ignore[arg-type]
            )

        harmony_messages.extend(
            parse_chat_inputs_to_harmony_messages(request.messages)
        )

        engine_prompt = TokensPrompt(
            prompt_token_ids=render_for_completion(harmony_messages)
        )
        if request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return harmony_messages, [engine_prompt]