serving_chat.py 34.7 KB
Newer Older
1
import asyncio
2
import json
3
import time
4
5
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
                    Optional)
6
from typing import Sequence as GenericSequence
7
from typing import Union
8

9
from fastapi import Request
10

11
from vllm.config import ModelConfig
12
from vllm.engine.protocol import AsyncEngineClient
13
from vllm.entrypoints.chat_utils import (ConversationMessage,
14
                                         apply_chat_template,
15
                                         load_chat_template,
16
                                         parse_chat_messages_futures)
17
from vllm.entrypoints.logger import RequestLogger
18
from vllm.entrypoints.openai.protocol import (
19
20
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
21
    ChatCompletionRequest, ChatCompletionResponse,
22
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
23
24
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
    DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
25
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
26
                                                    OpenAIServing,
27
28
                                                    PromptAdapterPath,
                                                    TextTokensPrompt)
29
30
31
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
                                                  MistralToolParser,
                                                  ToolParser)
32
from vllm.inputs import TokensPrompt
33
from vllm.logger import init_logger
34
from vllm.outputs import CompletionOutput, RequestOutput
35
from vllm.sequence import Logprob
36
37
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
38
from vllm.transformers_utils.tokenizer import AnyTokenizer
39
from vllm.utils import iterate_with_cancellation, random_uuid
40
41
42
43
44
45

logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):

46
47
48
49
50
51
52
53
54
55
56
57
58
    def __init__(self,
                 async_engine_client: AsyncEngineClient,
                 model_config: ModelConfig,
                 served_model_names: List[str],
                 response_role: str,
                 *,
                 lora_modules: Optional[List[LoRAModulePath]],
                 prompt_adapters: Optional[List[PromptAdapterPath]],
                 request_logger: Optional[RequestLogger],
                 chat_template: Optional[str],
                 return_tokens_as_token_ids: bool = False,
                 enable_auto_tools: bool = False,
                 tool_parser: Optional[str] = None):
        """Configure the chat-completion serving layer.

        Args:
            async_engine_client: Engine client, forwarded to the base class.
            model_config: Model configuration, forwarded to the base class.
            served_model_names: Model names, forwarded to the base class.
            response_role: Role reported for generated messages when the
                request asks for a generation prompt.
            lora_modules: Forwarded to the base class.
            prompt_adapters: Forwarded to the base class.
            request_logger: Forwarded to the base class.
            chat_template: Passed to ``load_chat_template``; the result is
                the fallback template for requests that do not bring one.
            return_tokens_as_token_ids: Forwarded to the base class.
            enable_auto_tools: Whether ``"auto"`` tool choice is enabled.
            tool_parser: Name of the tool-call parser, ``"mistral"`` or
                ``"hermes"``. Only consulted when ``enable_auto_tools`` is
                set.

        Raises:
            TypeError: If ``enable_auto_tools`` is set and ``tool_parser`` is
                not one of the supported names.
        """
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role
        self.use_tool_use_model_template = False
        self.chat_template = load_chat_template(chat_template)

        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        # Stored as a factory: a parser instance is created per request from
        # the request's tokenizer.
        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None

        if self.enable_auto_tools:
            logger.info(
                "\"auto\" tool choice has been enabled please note that while"
                " the parallel_tool_calls client option is preset for "
                "compatibility reasons, it will be ignored.")

            if tool_parser == "mistral":
                self.tool_parser = MistralToolParser
            elif tool_parser == "hermes":
                self.tool_parser = Hermes2ProToolParser
            else:
                raise TypeError("Error: --enable-auto-tool-choice requires "
                                "--tool-call-parser")

89
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        Args:
            request: The chat completion request.
            raw_request: The underlying HTTP request, if available. Its
                headers are inspected for trace headers and its
                ``is_disconnected`` callable is handed to
                ``iterate_with_cancellation`` together with the result
                stream.

        Returns:
            The streaming generator when ``request.stream`` is set, otherwise
            the full ``ChatCompletionResponse``. An ``ErrorResponse`` is
            returned when the model check, prompt construction, multi-modal
            loading, tool-choice validation or request submission fails.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_config = self.model_config
            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, model_config, tokenizer)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            # A template supplied with the request takes precedence over the
            # server-wide one.
            prompt = apply_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=request.chat_template or self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                **(request.chat_template_kwargs or {}),
            )
        except Exception as e:
            logger.error("Error in applying chat template from request: %s", e)
            return self.create_error_response(str(e))

        try:
            mm_data = await mm_data_future
        except Exception as e:
            logger.error("Error in loading multi-modal data: %s", e)
            return self.create_error_response(str(e))

        # validation for OpenAI tools
        # tool_choice = "required" is not supported
        if request.tool_choice == "required":
            return self.create_error_response(
                "tool_choice = \"required\" is not supported!")

        # "auto" tools requires --enable-auto-tool-choice
        # and --tool-call-parser
        if request.tool_choice == "auto" and not (
                self.enable_auto_tools and self.tool_parser is not None):
            return self.create_error_response(
                "\"auto\" tool choice requires "
                "--enable-auto-tool-choice and --tool-call-parser to be set")

        request_id = f"chat-{random_uuid()}"
        try:
            guided_decode_logits_processor = (
                await self._guided_decode_logits_processor(request, tokenizer))

            if isinstance(prompt, str):
                prompt_inputs = self._tokenize_prompt_input(
                    request,
                    tokenizer,
                    prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                # Raise ValueError rather than asserting so that a malformed
                # (or empty) template result is reported through the error
                # response below, and the check survives ``python -O``.
                if not (isinstance(prompt, list) and prompt
                        and isinstance(prompt[0], int)):
                    raise ValueError(
                        "Prompt has to be either a string or a list of "
                        "token ids")
                prompt_inputs = TextTokensPrompt(
                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

            assert prompt_inputs is not None

            # By default the completion may use whatever room the prompt
            # leaves in the model's context window.
            sampling_params = request.to_sampling_params(
                tokenizer,
                guided_decode_logits_processor,
                default_max_tokens=self.max_model_len -
                len(prompt_inputs["prompt_token_ids"]))

            self._log_inputs(request_id,
                             prompt_inputs,
                             params=sampling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            engine_inputs = TokensPrompt(
                prompt_token_ids=prompt_inputs["prompt_token_ids"])
            if mm_data is not None:
                engine_inputs["multi_modal_data"] = mm_data

            is_tracing_enabled = (
                await self.async_engine_client.is_tracing_enabled())
            trace_headers = None
            if is_tracing_enabled and raw_request:
                trace_headers = extract_trace_headers(raw_request.headers)
            if (not is_tracing_enabled and raw_request
                    and contains_trace_headers(raw_request.headers)):
                log_tracing_disabled_warning()

            result_generator = self.async_engine_client.generate(
                engine_inputs,
                sampling_params,
                request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
            )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, conversation, tokenizer)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
232
233
234
235
236

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        """Return the role to report for the generated message.

        When the request asks for a generation prompt the configured
        ``response_role`` is used; otherwise the role of the last request
        message is reused. A request without any messages also falls back to
        ``response_role`` instead of failing with an ``IndexError``.

        Args:
            request: The chat completion request.

        Returns:
            The role name for the response message.
        """
        if request.add_generation_prompt or not request.messages:
            return self.response_role
        return request.messages[-1]["role"]
238
239

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion as server-sent-event strings.

        On the first engine result a role-only chunk is emitted for every
        choice, followed by an echo of the last conversation message when
        ``request.echo`` is set. After that, each engine result produces a
        chunk per unfinished choice for which a delta is available, and a
        last chunk per choice carrying the finish reason. If usage was
        requested through ``stream_options.include_usage``, a usage-only
        chunk follows. The stream always ends with ``data: [DONE]``.

        Args:
            request: The chat completion request.
            result_generator: Async iterator of engine results.
            request_id: Identifier placed in every chunk.
            conversation: Parsed conversation, used for ``echo``.
            tokenizer: Tokenizer used for logprobs and to build the
                per-request tool parser.

        Yields:
            Strings of the form ``data: <payload>`` terminated by a blank
            line. A ``ValueError`` raised while streaming is reported as a
            streaming error payload instead of propagating.
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_texts = [""] * num_choices
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        # Prompt length of the most recent engine result; initialised so the
        # final usage chunk is well defined even if nothing was generated.
        prompt_tokens = 0

        include_usage = bool(request.stream_options
                             and request.stream_options.include_usage)
        continuous_usage = bool(
            include_usage and request.stream_options.continuous_usage_stats)

        tool_parser: Optional[ToolParser] = self.tool_parser(
            tokenizer) if self.tool_parser else None

        def render_chunk(choice: ChatCompletionResponseStreamChoice,
                         num_prompt_tokens: int,
                         num_completion_tokens: int) -> str:
            """Wrap one choice in a stream chunk and serialize it."""
            chunk = ChatCompletionStreamResponse(id=request_id,
                                                 object=chunk_object_type,
                                                 created=created_time,
                                                 choices=[choice],
                                                 model=model_name)
            if include_usage:
                if continuous_usage:
                    chunk.usage = UsageInfo(
                        prompt_tokens=num_prompt_tokens,
                        completion_tokens=num_completion_tokens,
                        total_tokens=num_prompt_tokens +
                        num_completion_tokens,
                    )
                else:
                    # Assigning None explicitly marks the field as set, so it
                    # is still serialized with exclude_unset=True.
                    chunk.usage = None
            return chunk.model_dump_json(exclude_unset=True)

        try:
            async for res in result_generator:
                prompt_tokens = len(res.prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(role=role),
                            logprobs=None,
                            finish_reason=None)
                        data = render_chunk(choice_data, prompt_tokens, 0)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content: Optional[str] = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                data = render_chunk(choice_data,
                                                    prompt_tokens, 0)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    # output.token_ids / output.text are cumulative here, so
                    # the delta is whatever follows what was already seen.
                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
                    out_logprobs = output.logprobs[
                        previous_num_tokens[i]:] if output.logprobs else None

                    if request.logprobs and request.top_logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    delta_message: Optional[DeltaMessage] = None

                    # handle streaming deltas for tools with named tool_choice
                    if isinstance(request.tool_choice,
                                  ChatCompletionNamedToolChoiceParam):
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(function=DeltaFunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text),
                                          index=i)
                        ])

                    # handle streaming deltas for tools with "auto" tool choice
                    elif (self._should_stream_with_auto_tool_parsing(request)
                          and tool_parser):
                        # Slice by the number of tokens already seen rather
                        # than by -len(delta): the latter yields an empty
                        # list when the delta itself is empty.
                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_texts[i],
                                current_text=output.text,
                                delta_text=delta_text,
                                previous_token_ids=output.
                                token_ids[:previous_num_tokens[i]],
                                current_token_ids=output.token_ids,
                                delta_token_ids=delta_token_ids))

                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # set the previous values for the next iteration
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token, then
                    #   get the next token without streaming a chunk
                    if delta_message is None:
                        continue

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        data = render_chunk(choice_data, prompt_tokens,
                                            len(output.token_ids))
                        yield f"data: {data}\n\n"

                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        #   any tokens that were generated but previously
                        #   matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        if tool_parser:
                            index = len(
                                tool_parser.prev_tool_call_arr) - 1 if len(
                                    tool_parser.prev_tool_call_arr) > 0 else 0
                        else:
                            index = 0

                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}))

                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)

                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason
                            if not (tool_parser
                                    and len(tool_parser.prev_tool_call_arr))
                            else "tool_calls",
                            stop_reason=output.stop_reason)
                        data = render_chunk(choice_data, prompt_tokens,
                                            len(output.token_ids))
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
            if include_usage:
                # Count the tokens of every choice, matching what the
                # non-streaming response reports.
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            logger.error("error in chat completion stream generator: %s", e)
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Build the complete (non-streaming) chat completion response.

        The engine results are drained and only the last one is used. Each
        of its outputs becomes one choice, as plain content or as tool calls
        depending on the request's tool choice and the server's tool
        configuration.

        Args:
            request: The chat completion request.
            result_generator: Async iterator of engine results.
            request_id: Identifier placed in the response.
            conversation: Parsed conversation, used for ``echo``.
            tokenizer: Tokenizer used for logprobs and to build the tool
                parser.

        Returns:
            The ``ChatCompletionResponse``, or an ``ErrorResponse`` if the
            iteration was cancelled.
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            # by default, tools are not used.
            tools_called = False
            named_tool_choice = isinstance(
                request.tool_choice, ChatCompletionNamedToolChoiceParam)

            # if auto tools are not enabled, and a named tool choice using
            #   outlines is not being used
            if (not self.enable_auto_tools
                    or not self.tool_parser) and not named_tool_choice:
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice
            elif named_tool_choice:
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
                tools_called = True

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            # (a missing tool choice was already handled by the branch above)
            elif (request.tools and request.tool_choice == "auto"
                  and self.enable_auto_tools and self.tool_parser):
                tool_parser = self.tool_parser(tokenizer)
                tool_call_info = tool_parser.extract_tool_calls(output.text)
                tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)
                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls" if tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            # Prefix every choice with the last conversation message when
            # that message has the same role as the response.
            last_msg_content = ""
            if conversation and conversation[-1].get(
                    "content") and conversation[-1].get("role") == role:
                last_msg_content = conversation[-1]["content"] or ""

            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response
685
686

    def _get_top_logprobs(
687
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
688
            tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
689
        return [
690
691
692
693
694
695
696
697
            ChatCompletionLogProb(token=(token := self._get_decoded_token(
                p[1],
                p[0],
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)),
                                  logprob=max(p[1].logprob, -9999.0),
                                  bytes=list(
                                      token.encode("utf-8", errors="replace")))
698
699
700
701
702
703
704
705
            for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
        ]

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        Args:
            token_ids: Generated token ids, one per output position.
            top_logprobs: Per-position mapping of token id to ``Logprob``;
                indexed in parallel with ``token_ids``. An entry may be
                ``None`` when no logprob data exists for that position.
            tokenizer: Tokenizer used to decode ids that carry no logprob
                data, and forwarded to the decoding helpers.
            num_output_top_logprobs: Maximum number of alternatives to
                attach to each position (forwarded to ``_get_top_logprobs``).

        Returns:
            A ``ChatCompletionLogProbs`` with one content entry per token id.
        """
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            # Use .get() so that a step whose mapping lacks the sampled token
            # degrades to the "no logprob data" entry instead of KeyError.
            step_token = (None if step_top_logprobs is None else
                          step_top_logprobs.get(token_id))

            if step_token is None:
                # Only decode when the decoded text is actually reported.
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                else:
                    token = tokenizer.decode(token_id)

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
                continue

            step_decoded = step_token.decoded_token

            logprobs_content.append(
                ChatCompletionLogProbsContent(
                    token=self._get_decoded_token(
                        step_token,
                        token_id,
                        tokenizer,
                        self.return_tokens_as_token_ids,
                    ),
                    # Floor the reported value at -9999.0.
                    logprob=max(step_token.logprob, -9999.0),
                    bytes=None if step_decoded is None else list(
                        step_decoded.encode("utf-8", errors="replace")),
                    top_logprobs=self._get_top_logprobs(
                        step_top_logprobs,
                        num_output_top_logprobs,
                        tokenizer,
                    ),
                ))

        return ChatCompletionLogProbs(content=logprobs_content)
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781

    def _should_stream_with_auto_tool_parsing(self,
                                              request: ChatCompletionRequest):
        """
        Utility function to check if streamed tokens should go through the tool
        call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used.
        """
        return (request.tools and self.tool_parser and self.enable_auto_tools
                and request.tool_choice in ['auto', None])

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
        output: CompletionOutput,
    ) -> bool:
        """
        Check to see if we should check for unstreamed tool arguments tokens.
        This is only applicable when auto tool parsing is enabled, the delta
        is a tool call with arguments.
        """

        # yapf: disable
        return bool(
            # if there is a delta message that includes tool calls which
            # include a function that has arguments
            self.enable_auto_tools and self.tool_parser and delta_message
            and delta_message.tool_calls and delta_message.tool_calls[0]
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None
            and output.finish_reason is not None
        )