serving_chat.py 43.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
import json
5
import time
6
7
8
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Callable, Final, Optional, Union
9

10
from fastapi import Request
11

12
from vllm.config import ModelConfig
13
from vllm.engine.protocol import EngineClient
14
15
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                         ConversationMessage)
16
from vllm.entrypoints.logger import RequestLogger
17
from vllm.entrypoints.openai.protocol import (
18
19
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
20
    ChatCompletionRequest, ChatCompletionResponse,
21
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
22
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
23
24
    DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
    RequestResponseMetadata, ToolCall, UsageInfo)
25
26
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                       ReasoningParserManager)
27
28
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                    clamp_prompt_logprobs)
29
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
30
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
31
32
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
    MistralToolCall)
33
from vllm.logger import init_logger
34
from vllm.outputs import CompletionOutput, RequestOutput
35
from vllm.sampling_params import BeamSearchParams, SamplingParams
36
from vllm.sequence import Logprob
37
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
38
39
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
                                                truncate_tool_call_ids)
40
41
42
43
44
45

# Module-level logger for this chat-completion handler, obtained through
# vLLM's ``init_logger`` helper and named after this module.
logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):

46
47
48
49
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        response_role: str,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        return_tokens_as_token_ids: bool = False,
        enable_reasoning: bool = False,
        reasoning_parser: Optional[str] = None,
        enable_auto_tools: bool = False,
        tool_parser: Optional[str] = None,
        enable_prompt_tokens_details: bool = False,
    ) -> None:
        """Configure the chat-completion handler.

        Args:
            engine_client: Client used to submit generation / beam-search
                requests to the engine.
            model_config: Configuration of the served model; also the source
                of the default sampling parameters.
            models: Served models, forwarded to ``OpenAIServing``.
            response_role: Role reported for generated messages when a
                request sets ``add_generation_prompt``.
            request_logger: Optional request logger, forwarded to
                ``OpenAIServing``.
            chat_template: Server-wide chat template; a template supplied on
                the request takes precedence.
            chat_template_content_format: Content format passed to prompt
                preprocessing.
            return_tokens_as_token_ids: Forwarded to ``OpenAIServing``.
            enable_reasoning: Enable reasoning-content extraction; requires
                a registered ``reasoning_parser``.
            reasoning_parser: Name looked up in ``ReasoningParserManager``.
            enable_auto_tools: Enable ``tool_choice="auto"``; requires a
                registered ``tool_parser``.
            tool_parser: Name looked up in ``ToolParserManager``.
            enable_prompt_tokens_details: Report cached prompt tokens in the
                usage block.

        Raises:
            TypeError: If reasoning or auto tool choice is enabled but the
                requested parser name cannot be resolved.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        if self.enable_auto_tools:
            logger.info(
                "\"auto\" tool choice has been enabled; please note that "
                "while the parallel_tool_calls client option is present for "
                "compatibility reasons, it will be ignored.")

        # Set up reasoning-content extraction. The parser is stored as a
        # factory that is later called with the request's tokenizer.
        self.enable_reasoning: bool = enable_reasoning
        self.reasoning_parser: Optional[Callable[[AnyTokenizer],
                                                 ReasoningParser]] = None
        if self.enable_reasoning:
            try:
                self.reasoning_parser = (
                    ReasoningParserManager.get_reasoning_parser(
                        reasoning_parser))
            except Exception as e:
                raise TypeError("Error: --enable-reasoning requires "
                                f"reasoning_parser:'{reasoning_parser}' "
                                "which has not been registered") from e

        # Tool parser is likewise stored as a per-tokenizer factory.
        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if self.enable_auto_tools:
            # Advisory only; kept outside the try so a failure here is not
            # misreported as an unregistered tool parser.
            if (tool_parser == "pythonic" and
                    model_config.model.startswith("meta-llama/Llama-3.2")):
                logger.warning(
                    "Llama3.2 models may struggle to emit valid pythonic"
                    " tool calls")
            try:
                self.tool_parser = ToolParserManager.get_tool_parser(
                    tool_parser)
            except Exception as e:
                raise TypeError("Error: --enable-auto-tool-choice requires "
                                f"tool_parser:'{tool_parser}' which has not "
                                "been registered") from e

        self.enable_prompt_tokens_details = enable_prompt_tokens_details

        # Sampling defaults derived from the model config; passed to
        # request.to_sampling_params / to_beam_search_params per request.
        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            logger.info("Overwriting default chat sampling param with: %s",
                        self.default_sampling_params)

115
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """
        Chat Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        Chat Completion API.

        Returns:
            An async generator of SSE ``data:`` strings when
            ``request.stream`` is set, otherwise the awaited
            ``ChatCompletionResponse``. An ``ErrorResponse`` is returned
            instead when the model check fails, ``tool_choice`` is
            unsupported, prompt preprocessing fails, or sampling parameters
            are invalid.

        Raises:
            The engine client's ``dead_error`` if the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_name = self._get_model_name(request.model, lora_request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            tool_parser = self.tool_parser

            # validation for OpenAI tools
            # tool_choice = "required" is not supported
            if request.tool_choice == "required":
                return self.create_error_response(
                    "tool_choice = \"required\" is not supported!")

            if isinstance(tokenizer, MistralTokenizer):
                # because of issues with pydantic we need to potentially
                # re-serialize the tool_calls field of the request
                # for more info: see comment in `maybe_serialize_tool_calls`
                maybe_serialize_tool_calls(request)
                truncate_tool_call_ids(request)

            if (request.tool_choice == "auto" and
                    not (self.enable_auto_tools and tool_parser is not None)
                    and not isinstance(tokenizer, MistralTokenizer)):
                # for hf tokenizers, "auto" tools requires
                # --enable-auto-tool-choice and --tool-call-parser
                return self.create_error_response(
                    "\"auto\" tool choice requires "
                    "--enable-auto-tool-choice and --tool-call-parser to be set"
                )

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            (
                conversation,
                request_prompts,
                engine_prompts,
            ) = await self._preprocess_chat(
                request,
                tokenizer,
                request.messages,
                chat_template=request.chat_template or self.chat_template,
                chat_template_content_format=self.chat_template_content_format,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
                tool_dicts=tool_dicts,
                documents=request.documents,
                chat_template_kwargs=request.chat_template_kwargs,
                tool_parser=tool_parser,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        except (ValueError, TypeError) as e:
            # TypeError is included so that type mismatches raised while
            # building the prompt (e.g. presumably by a chat template fed
            # unexpected message content) become a client-facing error
            # response instead of an unhandled server error.
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        request_id = (
            f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
        )

        # Expose the metadata on the raw request so it can be read outside
        # this handler (final usage is filled in later by the generators).
        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Without an explicit max_tokens, allow generation to fill
                # whatever context remains after the prompt.
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, self.default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params)

                self._log_inputs(request_id,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        prompt_adapter_request=prompt_adapter_request,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # A chat request is preprocessed into exactly one engine prompt.
        assert len(generators) == 1
        result_generator, = generators

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, model_name,
                conversation, tokenizer, request_metadata)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, model_name,
                conversation, tokenizer, request_metadata)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
275
        return request.messages[-1]["role"]
276
277

    async def chat_completion_stream_generator(
278
279
280
281
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
282
        model_name: str,
283
        conversation: list[ConversationMessage],
284
        tokenizer: AnyTokenizer,
285
        request_metadata: RequestResponseMetadata,
286
    ) -> AsyncGenerator[str, None]:
287
        created_time = int(time.time())
288
        chunk_object_type: Final = "chat.completion.chunk"
289
        first_iteration = True
290
291

        # Send response for each token for each request.n (index)
292
293
294
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
295
        num_prompt_tokens = 0
296
        num_cached_tokens = None
297
298
299
300
301
302
303
304
305
306
307

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))

308
309
310
        should_stream_with_reasoning_parsing = (
            self._should_stream_with_reasoning_parsing(request))

311
        all_previous_token_ids: Optional[list[list[int]]]
312
313
314
315

        # Only one of these will be used, thus previous_texts and
        # all_previous_token_ids will not be used twice in the same iteration.
        if tool_choice_auto or should_stream_with_reasoning_parsing:
316
317
318
319
320
321
            # These are only required in "auto" tool choice case
            previous_texts = [""] * num_choices
            all_previous_token_ids = [[]] * num_choices
        else:
            previous_texts, all_previous_token_ids = None, None

322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
        try:
            # There is no need to check if the reasoning_parser is None
            # because the should_stream_with_reasoning_parsing check
            # already ensures that the reasoning_parser is not None.
            # but the pre-commit hook requires it.
            if should_stream_with_reasoning_parsing and \
                self.reasoning_parser is not None:
                reasoning_parser = self.reasoning_parser(tokenizer)
        except RuntimeError as e:
            logger.exception("Error in reasoning parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

337
338
339
        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
340
                tool_parsers: list[Optional[ToolParser]] = [
341
342
343
344
                    self.tool_parser(tokenizer)
                ] * num_choices
            else:
                tool_parsers = [None] * num_choices
345
        except Exception as e:
346
            logger.exception("Error in tool parser creation.")
347
348
349
350
351
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

352
353
354
355
356
357
358
359
        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage
            include_continuous_usage = include_usage and \
                                       stream_options.continuous_usage_stats
        else:
            include_usage, include_continuous_usage = False, False

360
361
        try:
            async for res in result_generator:
362
363
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
364
365
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)
366

367
368
369
370
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
371
                    num_cached_tokens = res.num_cached_tokens
372
373
                    # Send first response for each request.n (index) with
                    # the role
374
                    role = self.get_chat_request_role(request)
375
376
377

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
378
                    for i in range(num_choices):
379
380
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
381
382
383
384
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
385
386
387
388
389
390
391
392
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
393

394
395
396
397
398
399
                        # if continuous usage stats are requested, add it
                        if include_continuous_usage:
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=0,
                                total_tokens=num_prompt_tokens)
400

401
402
403
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

404
405
                    # Send response to echo the input portion of the
                    # last message
406
                    if request.echo:
407
                        last_msg_content: Union[str, list[dict[str, str]]] = ""
408
409
410
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""
411
412

                        if last_msg_content:
413
                            for i in range(num_choices):
414
415
416
417
418
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
419
                                        logprobs=None,
420
                                        finish_reason=None))
421
422
423
424
425
426
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
427
428
429
430
431
                                if include_continuous_usage:
                                    chunk.usage = UsageInfo(
                                        prompt_tokens=num_prompt_tokens,
                                        completion_tokens=0,
                                        total_tokens=num_prompt_tokens)
432

433
434
435
436
437
438
439
                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index
440
                    tool_parser = tool_parsers[i]
441
442
443
444

                    if finish_reason_sent[i]:
                        continue

445
                    if request.logprobs and request.top_logprobs is not None:
446
                        assert output.logprobs is not None, (
447
                            "Did not output logprobs")
448
                        logprobs = self._create_chat_logprobs(
449
450
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
451
                            tokenizer=tokenizer,
452
                            num_output_top_logprobs=request.top_logprobs,
453
454
455
456
                        )
                    else:
                        logprobs = None

457
                    delta_text = output.text
458
459
460
461
462
463

                    if not delta_text and not output.token_ids and \
                        not previous_num_tokens[i]:
                        # Chunked prefill case, don't return empty chunks
                        continue

464
                    delta_message: Optional[DeltaMessage]
465

466
                    # handle streaming deltas for tools with named tool_choice
467
                    if tool_choice_function_name:
468
                        delta_message = DeltaMessage(tool_calls=[
469
                            DeltaToolCall(function=DeltaFunctionCall(
470
                                name=tool_choice_function_name,
471
472
                                arguments=delta_text),
                                          index=i)
473
                        ])
474
475

                    # handle streaming deltas for tools with "auto" tool choice
476
477
478
479
480
481
482
483
484
485
486
                    elif tool_choice_auto:
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        assert tool_parser is not None
                        #TODO optimize manipulation of these lists
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

487
488
                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
489
490
                                previous_text=previous_text,
                                current_text=current_text,
491
                                delta_text=delta_text,
492
493
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
494
495
                                delta_token_ids=output.token_ids,
                                request=request))
496
497
498
499

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
                    # reasoning_content cannot be enabled with tool_choice.
                    # If it is, the tool_choice will be used instead.
                    elif self.enable_reasoning:
                        # handle reasoning_content delta
                        assert reasoning_parser is not None
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

                        delta_message = (reasoning_parser.
                                         extract_reasoning_content_streaming(
                                             previous_text,
                                             current_text,
                                             delta_text,
                                             previous_token_ids,
                                             current_token_ids,
                                             output.token_ids,
                                         ))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids
526
527

                    # handle streaming just a content delta
528
529
530
                    else:
                        delta_message = DeltaMessage(content=delta_text)

531
                    # set the previous values for the next iteration
532
                    previous_num_tokens[i] += len(output.token_ids)
533
534
535
536
537
538
539
540

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token, then
                    #   get the next token without streaming a chunk
                    if delta_message is None:
                        continue

541
542
543
544
                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
545
                            delta=delta_message,
546
547
                            logprobs=logprobs,
                            finish_reason=None)
548
549

                    # if the model is finished generating
550
                    else:
551
552
553
554
                        # check to make sure we haven't "forgotten" to stream
                        #   any tokens that were generated but previously
                        #   matched by partial json parsing
                        # only happens if we are NOT using guided decoding
555
                        auto_tools_called = False
556
                        if tool_parser:
557
558
559
560
                            auto_tools_called = len(
                                tool_parser.prev_tool_call_arr) > 0
                            index = len(tool_parser.prev_tool_call_arr
                                        ) - 1 if auto_tools_called else 0
561
562
563
564
565
                        else:
                            index = 0

                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
566
567
568
569
570
571
572
573
574
575
                            latest_delta_len = 0
                            if ((isinstance(
                                    delta_message.tool_calls[0].function,
                                    DeltaFunctionCall)) and isinstance(
                                        delta_message.tool_calls[0].function.
                                        arguments, str)):
                                latest_delta_len = len(
                                    delta_message.tool_calls[0].function.
                                    arguments)

576
577
578
579
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
580
581
                                    "arguments", {}),
                                ensure_ascii=False)
582

583
                            # get what we've streamed so far for arguments
584
585
586
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]
587
588
                            if (latest_delta_len > 0):
                                actual_call = actual_call[:-latest_delta_len]
589
590
591
592
593
594
595
596
597
598
599
600

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)
                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

601
602
603
                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
604
                            delta=delta_message,
605
                            logprobs=logprobs,
606
                            finish_reason=output.finish_reason
607
                            if not auto_tools_called else "tool_calls",
608
                            stop_reason=output.stop_reason)
609

610
                        finish_reason_sent[i] = True
611

612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

631
632
            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
633
634
            if include_usage:
                completion_tokens = sum(previous_num_tokens)
635
636
637
638
639
640
641
                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                                        completion_tokens=completion_tokens,
                                        total_tokens=num_prompt_tokens +
                                        completion_tokens)
                if self.enable_prompt_tokens_details and num_cached_tokens:
                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
                        cached_tokens=num_cached_tokens)
642
643
644
645
646
647
648
649
650
651
652

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"
653

654
655
656
657
658
659
660
            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
                total_tokens=num_prompt_tokens + num_completion_tokens)

661
        except Exception as e:
662
            # TODO: Use a vllm-specific Validation Error
663
            logger.exception("Error in chat completion stream generator.")
664
665
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
666
667
668
669
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: list[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Drain the engine output and build a non-streaming chat response.

        Consumes ``result_generator`` to completion, keeps only the last
        ``RequestOutput`` it yields, and converts each of its completion
        outputs into a ``ChatCompletionResponseChoice``. Depending on the
        request and the server configuration, the generated text is split
        into reasoning/content by the reasoning parser, wrapped as a named
        tool call, run through the auto tool-call parser, or returned as
        plain content.

        Args:
            request: The chat completion request being served.
            result_generator: Async iterator of engine outputs for the request.
            request_id: Identifier reported as the response ``id``.
            model_name: Model name reported in the response.
            conversation: Parsed conversation; when ``request.echo`` is set
                and its last message has the response role, that message's
                content is prepended to every choice.
            tokenizer: Tokenizer used for logprob decoding and by the parsers.
            request_metadata: Receives the final usage info (read by the
                FastAPI middleware).

        Returns:
            A ``ChatCompletionResponse`` on success, or an ``ErrorResponse``
            if the client disconnected, the engine raised ``ValueError``, or
            a reasoning/tool parser could not be created.
        """
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        assert final_res is not None

        choices: list[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        # Depends only on the request, so evaluate it once instead of once
        # per output choice.
        should_stream_with_reasoning_parsing = (
            self._should_stream_with_reasoning_parsing(request))

        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            # In the OpenAI API the finish_reason is "tool_calls"
            # if the tool choice is auto and the model produced a tool
            # call. The same is not true for named function calls
            auto_tools_called = False

            if should_stream_with_reasoning_parsing and \
                self.reasoning_parser is not None:
                try:
                    reasoning_parser = self.reasoning_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in reasoning parser creation.")
                    return self.create_error_response(str(e))

                reasoning_content, content = (
                    reasoning_parser.extract_reasoning_content(
                        output.text, request=request))

                if reasoning_content:
                    message = ChatMessage(role=role,
                                          content=content,
                                          reasoning_content=reasoning_content)
                else:
                    message = ChatMessage(role=role, content=output.text)

            # if auto tools are not enabled, and a named tool choice using
            #   outlines is not being used
            elif (not self.enable_auto_tools
                  or not self.tool_parser) and not isinstance(
                      request.tool_choice, ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice.
            # isinstance (not an exact type check) so this agrees with the
            # branch above, which already treats subclasses as named choices.
            elif isinstance(request.tool_choice,
                            ChatCompletionNamedToolChoiceParam):
                tool_call_class = MistralToolCall if isinstance(
                    tokenizer, MistralTokenizer) else ToolCall
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        tool_call_class(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:

                try:
                    tool_parser = self.tool_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in tool parser creation.")
                    return self.create_error_response(str(e))

                tool_call_info = tool_parser.extract_tool_calls(
                    output.text, request=request)
                # In the OpenAI API the finish_reason is "tool_calls"
                # if the tool choice is auto and the model produced a tool
                # call. The same is not true for named function calls
                auto_tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)

                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls" if auto_tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content: Union[str, list[dict[str, str]]] = ""
            if conversation and "content" in conversation[-1] and conversation[
                    -1].get("role") == role:
                last_msg_content = conversation[-1]["content"] or ""
            if isinstance(last_msg_content, list):
                # List content is a sequence of parts; only text parts carry
                # a "text" key, so skip the others rather than raise KeyError.
                last_msg_content = "\n".join(part["text"]
                                             for part in last_msg_content
                                             if "text" in part)

            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        # Encoder-decoder models also count the encoder prompt.
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                          completion_tokens=num_generated_tokens,
                          total_tokens=num_prompt_tokens +
                          num_generated_tokens)
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)

        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
        )

        return response
852
853

    def _get_top_logprobs(
            self, logprobs: dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: AnyTokenizer) -> list[ChatCompletionLogProb]:
        """Convert the leading ``top_logprobs`` entries of a logprob dict.

        Entries are taken in the dict's iteration order. Each one becomes a
        ``ChatCompletionLogProb`` whose token text comes from
        ``_get_decoded_token``, whose logprob is floored at -9999.0, and whose
        ``bytes`` is the UTF-8 encoding of that token text (undecodable
        characters replaced). A falsy ``top_logprobs`` (``None`` or 0) yields
        an empty list.
        """
        entries: list[ChatCompletionLogProb] = []
        if not top_logprobs:
            return entries

        for rank, (token_id, logprob_info) in enumerate(logprobs.items()):
            if rank >= top_logprobs:
                break
            decoded = self._get_decoded_token(
                logprob_info,
                token_id,
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)
            entries.append(
                ChatCompletionLogProb(
                    token=decoded,
                    logprob=max(logprob_info.logprob, -9999.0),
                    bytes=list(decoded.encode("utf-8", errors="replace"))))
        return entries

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        Builds one ``ChatCompletionLogProbsContent`` per generated token.
        ``top_logprobs[i]`` is the logprob dict for ``token_ids[i]`` (it is
        indexed by that token id). When it is ``None``, only the token text
        and its bytes are reported for that position.

        Args:
            token_ids: Generated token ids.
            top_logprobs: Per-position logprob dicts (or ``None``), indexed in
                parallel with ``token_ids``.
            tokenizer: Used to decode token text.
            num_output_top_logprobs: Number of alternatives to report per
                position; forwarded to ``_get_top_logprobs``.

        Returns:
            The assembled ``ChatCompletionLogProbs``.
        """
        as_token_ids = self.return_tokens_as_token_ids
        logprobs_content: list[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob info for this position. Detokenize only when the
                # decoded text is what gets reported.
                if as_token_ids:
                    token = f"token_id:{token_id}"
                else:
                    token = tokenizer.decode(token_id)

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            as_token_ids,
                        ),
                        logprob=max(step_token.logprob, -9999.0),
                        # bytes come from the cached decoded text, not from
                        # the reported token string; None when no decoded
                        # text is cached on the Logprob.
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)
914
915
916
917
918
919
920
921
922
923
924
925
926
927

    def _should_stream_with_auto_tool_parsing(
            self, request: ChatCompletionRequest) -> bool:
        """
        Utility function to check if streamed tokens should go through the tool
        call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used (either
        explicitly ``"auto"`` or left unset as ``None``).

        Always returns a ``bool`` rather than one of the and-chain operands.
        """
        return bool(request.tools and self.tool_parser
                    and self.enable_auto_tools
                    and request.tool_choice in ['auto', None])

928
929
930
931
932
933
934
935
936
937
938
    def _should_stream_with_reasoning_parsing(
            self, request: ChatCompletionRequest) -> bool:
        """
        Utility function to check if streamed tokens should go through the
        reasoning parser that was configured.

        We only want to do this IF reasoning is enabled and a reasoning
        parser is configured. ``request`` is not consulted by this check.

        Always returns a ``bool``.
        """
        return bool(self.enable_reasoning
                    and self.reasoning_parser is not None)

939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
        output: CompletionOutput,
    ) -> bool:
        """
        Decide whether to look for tool-argument tokens not yet streamed.

        Returns True only when all of the following hold: the output has
        finished (its ``finish_reason`` is set), auto tool parsing is enabled
        and a tool parser is configured, and ``delta_message`` carries a
        first tool call whose function has ``arguments`` that are not
        ``None``.
        """
        # Only relevant once generation for this output has ended.
        if output.finish_reason is None:
            return False
        if not (self.enable_auto_tools and self.tool_parser
                and delta_message):
            return False

        # The delta must include a tool call whose function has arguments.
        tool_calls = delta_message.tool_calls
        if not tool_calls:
            return False
        first_call = tool_calls[0]
        if not first_call:
            return False
        function = first_call.function
        return bool(function and function.arguments is not None)