serving_chat.py 34.8 KB
Newer Older
1
import asyncio
2
import json
3
import time
4
5
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
                    Optional)
6
from typing import Sequence as GenericSequence
7
from typing import Union
8

9
from fastapi import Request
10

11
from vllm.config import ModelConfig
12
from vllm.engine.protocol import AsyncEngineClient
13
from vllm.entrypoints.chat_utils import (ConversationMessage,
14
                                         apply_chat_template,
15
                                         load_chat_template,
16
                                         parse_chat_messages_futures)
17
from vllm.entrypoints.logger import RequestLogger
18
from vllm.entrypoints.openai.protocol import (
19
20
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
21
    ChatCompletionRequest, ChatCompletionResponse,
22
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
23
24
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
    DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
25
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
26
                                                    OpenAIServing,
27
28
                                                    PromptAdapterPath,
                                                    TextTokensPrompt)
29
30
31
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
                                                  MistralToolParser,
                                                  ToolParser)
32
from vllm.inputs import TokensPrompt
33
from vllm.logger import init_logger
34
from vllm.outputs import CompletionOutput, RequestOutput
35
from vllm.sequence import Logprob
36
37
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
38
from vllm.transformers_utils.tokenizer import AnyTokenizer
39
from vllm.utils import iterate_with_cancellation, random_uuid
40
41
42
43
44
45

# Module-level logger used by the chat completion serving code below.
logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):

46
47
48
49
50
51
52
53
54
55
56
57
58
    def __init__(self,
                 async_engine_client: AsyncEngineClient,
                 model_config: ModelConfig,
                 served_model_names: List[str],
                 response_role: str,
                 *,
                 lora_modules: Optional[List[LoRAModulePath]],
                 prompt_adapters: Optional[List[PromptAdapterPath]],
                 request_logger: Optional[RequestLogger],
                 chat_template: Optional[str],
                 return_tokens_as_token_ids: bool = False,
                 enable_auto_tools: bool = False,
                 tool_parser: Optional[str] = None):
        """Set up the OpenAI-compatible chat completion handler.

        Args:
            async_engine_client: Engine client, forwarded to the base class.
            model_config: Model configuration, forwarded to the base class.
            served_model_names: Model names served by this endpoint.
            response_role: Role given to generated messages when the request
                asks for a generation prompt.
            lora_modules: Optional LoRA modules, forwarded to the base class.
            prompt_adapters: Optional prompt adapters, forwarded.
            request_logger: Optional request logger, forwarded.
            chat_template: Chat template argument passed to
                ``load_chat_template``.
            return_tokens_as_token_ids: Forwarded to the base class.
            enable_auto_tools: Enable ``tool_choice="auto"`` support.
            tool_parser: Tool-call parser name (``"mistral"`` or
                ``"hermes"``); required when ``enable_auto_tools`` is set.

        Raises:
            TypeError: ``enable_auto_tools`` is set but ``tool_parser`` is
                missing or not a supported parser name.
        """
        super().__init__(
            async_engine_client=async_engine_client,
            model_config=model_config,
            served_model_names=served_model_names,
            lora_modules=lora_modules,
            prompt_adapters=prompt_adapters,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
        )

        self.response_role = response_role
        self.use_tool_use_model_template = False
        self.chat_template = load_chat_template(chat_template)

        # Tool-use configuration: no parser unless auto tools are enabled.
        self.enable_auto_tools: bool = enable_auto_tools
        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if not self.enable_auto_tools:
            return

        logger.info(
            "\"auto\" tool choice has been enabled please note that while"
            " the parallel_tool_calls client option is preset for "
            "compatibility reasons, it will be ignored.")

        # Map the CLI parser name onto its parser class; anything else
        # (including a missing name) is a configuration error.
        parsers_by_name = {
            "mistral": MistralToolParser,
            "hermes": Hermes2ProToolParser,
        }
        selected_parser = parsers_by_name.get(tool_parser or "")
        if selected_parser is None:
            raise TypeError("Error: --enable-auto-tool-choice requires "
                            "--tool-call-parser")
        self.tool_parser = selected_parser

89
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        Args:
            request: The parsed chat completion request.
            raw_request: The underlying HTTP request, if any. Used to read
                tracing headers and to stop iterating results once the
                client disconnects.

        Returns:
            An async generator of SSE lines when ``request.stream`` is set,
            otherwise a ``ChatCompletionResponse``. Validation and template
            failures are reported as an ``ErrorResponse``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # Validate the OpenAI tool options first. These checks only read the
        # request and server config, so fail fast before the tokenizer
        # lookup, chat templating and multi-modal loading below.
        # tool_choice = "required" is not supported
        if request.tool_choice == "required":
            return self.create_error_response(
                "tool_choice = \"required\" is not supported!")

        # "auto" tools requires --enable-auto-tool-choice
        # and --tool-call-parser
        if request.tool_choice == "auto" and not (
                self.enable_auto_tools and self.tool_parser is not None):
            return self.create_error_response(
                "\"auto\" tool choice requires "
                "--enable-auto-tool-choice and --tool-call-parser to be set")

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_config = self.model_config
            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, model_config, tokenizer)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt = apply_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=request.chat_template or self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                **(request.chat_template_kwargs or {}),
            )
        except Exception as e:
            logger.error("Error in applying chat template from request: %s", e)
            return self.create_error_response(str(e))

        try:
            mm_data = await mm_data_future
        except Exception as e:
            logger.error("Error in loading multi-modal data: %s", e)
            return self.create_error_response(str(e))

        request_id = f"chat-{random_uuid()}"
        try:
            guided_decode_logits_processor = (
                await self._guided_decode_logits_processor(request, tokenizer))

            if isinstance(prompt, str):
                prompt_inputs = self._tokenize_prompt_input(
                    request,
                    tokenizer,
                    prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                # The chat template may also yield token ids. Validate with a
                # ValueError (not an assert) so a malformed or empty prompt
                # becomes an error response via the handler below; asserts
                # disappear under -O, and `prompt[0]` raised IndexError on [].
                if not (isinstance(prompt, list) and prompt
                        and all(isinstance(tok, int) for tok in prompt)):
                    raise ValueError("Prompt has to be either a string or a "
                                     "list of token ids")
                prompt_inputs = TextTokensPrompt(
                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

            assert prompt_inputs is not None

            sampling_params = request.to_sampling_params(
                tokenizer,
                guided_decode_logits_processor,
                default_max_tokens=self.max_model_len -
                len(prompt_inputs["prompt_token_ids"]))

            self._log_inputs(request_id,
                             prompt_inputs,
                             params=sampling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            engine_inputs = TokensPrompt(
                prompt_token_ids=prompt_inputs["prompt_token_ids"])
            if mm_data is not None:
                engine_inputs["multi_modal_data"] = mm_data

            is_tracing_enabled = (
                await self.async_engine_client.is_tracing_enabled())
            trace_headers = None
            if is_tracing_enabled and raw_request:
                trace_headers = extract_trace_headers(raw_request.headers)
            if (not is_tracing_enabled and raw_request
                    and contains_trace_headers(raw_request.headers)):
                log_tracing_disabled_warning()

            result_generator = self.async_engine_client.generate(
                engine_inputs,
                sampling_params,
                request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
            )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # Stop consuming engine output once the client disconnects.
        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, conversation, tokenizer)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
232
233
234
235
236

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        """Return the role to attach to the generated message.

        When the request asks for a generation prompt, this is the server's
        configured ``response_role``; otherwise it is the ``role`` of the
        request's last message.
        """
        if not request.add_generation_prompt:
            last_message = request.messages[-1]
            return last_message["role"]
        return self.response_role
238
239

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion as server-sent ``data:`` lines.

        On the first engine output, one role chunk is sent per choice (plus
        an echo chunk per choice when ``request.echo`` applies). After that,
        each new delta of each choice is sent, followed by one finish chunk
        per choice. A final usage chunk is sent when
        ``stream_options.include_usage`` is set, and the stream always ends
        with ``data: [DONE]``.

        Args:
            request: The chat completion request being served.
            result_generator: Incremental engine outputs for this request.
            request_id: Id reported in every chunk.
            conversation: Parsed conversation; its last message is used when
                echoing.
            tokenizer: Tokenizer for logprob rendering and tool parsing.

        Yields:
            SSE-formatted strings (``"data: ...\\n\\n"``).
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_texts = [""] * num_choices
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        # Prompt length of the most recent engine output, for usage stats.
        num_prompt_tokens = 0

        stream_options = request.stream_options
        include_usage = bool(stream_options and stream_options.include_usage)
        include_continuous_usage = bool(
            stream_options and stream_options.include_usage
            and stream_options.continuous_usage_stats)

        tool_parser: Optional[ToolParser] = self.tool_parser(
            tokenizer) if self.tool_parser else None

        def make_chunk(
            choice: ChatCompletionResponseStreamChoice
        ) -> ChatCompletionStreamResponse:
            # Wrap a single choice in a stream chunk for this request.
            return ChatCompletionStreamResponse(id=request_id,
                                                object=chunk_object_type,
                                                created=created_time,
                                                choices=[choice],
                                                model=model_name)

        def set_chunk_usage(chunk: ChatCompletionStreamResponse,
                            completion_tokens: int) -> None:
            # Touch `usage` only when the client asked for usage: attach
            # running totals when continuous stats are requested, otherwise
            # set it to None explicitly.
            if not include_usage:
                return
            if include_continuous_usage:
                chunk.usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )
            else:
                chunk.usage = None

        try:
            async for res in result_generator:
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        chunk = make_chunk(
                            ChatCompletionResponseStreamChoice(
                                index=i,
                                delta=DeltaMessage(role=role, content=""),
                                logprobs=None,
                                finish_reason=None))
                        set_chunk_usage(chunk, completion_tokens=0)
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content: Optional[str] = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for i in range(num_choices):
                                chunk = make_chunk(
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                set_chunk_usage(chunk, completion_tokens=0)
                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    # Everything past the already-streamed token count is new
                    # for this choice.
                    num_prev_tokens = previous_num_tokens[i]
                    delta_token_ids = output.token_ids[num_prev_tokens:]
                    out_logprobs = output.logprobs[
                        num_prev_tokens:] if output.logprobs else None

                    if request.logprobs and request.top_logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    delta_message: Optional[DeltaMessage]

                    # handle streaming deltas for tools with named tool_choice
                    if isinstance(request.tool_choice,
                                  ChatCompletionNamedToolChoiceParam):
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(function=DeltaFunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text),
                                          index=i)
                        ])

                    # handle streaming deltas for tools with "auto" tool choice
                    elif (self._should_stream_with_auto_tool_parsing(request)
                          and tool_parser):
                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_texts[i],
                                current_text=output.text,
                                delta_text=delta_text,
                                # Slice by the streamed-token count so this is
                                # also correct when the step added no tokens
                                # (a `[:-len(delta)]` slice would give []).
                                previous_token_ids=output.token_ids[
                                    :num_prev_tokens],
                                current_token_ids=output.token_ids,
                                delta_token_ids=delta_token_ids))

                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # set the previous values for the next iteration
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)

                    # If the delta is None (e.g. a "control token" for tool
                    # calls, or the parser isn't ready to emit yet), skip this
                    # chunk -- unless the choice just finished, in which case
                    # the finish chunk must still go out with an empty delta.
                    if delta_message is None:
                        if output.finish_reason is None:
                            continue
                        delta_message = DeltaMessage()

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        chunk = make_chunk(
                            ChatCompletionResponseStreamChoice(
                                index=i,
                                delta=delta_message,
                                logprobs=logprobs,
                                finish_reason=None))
                        set_chunk_usage(chunk, len(output.token_ids))
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        continue

                    # The model finished generating this choice.

                    # check to make sure we haven't "forgotten" to stream
                    #   any tokens that were generated but previously
                    #   matched by partial json parsing
                    # only happens if we are NOT using guided decoding
                    if tool_parser and len(tool_parser.prev_tool_call_arr) > 0:
                        index = len(tool_parser.prev_tool_call_arr) - 1
                    else:
                        index = 0

                    if self._should_check_for_unstreamed_tool_arg_tokens(
                            delta_message, output) and tool_parser:
                        # get the expected call based on partial JSON
                        # parsing which "autocompletes" the JSON
                        expected_call = json.dumps(
                            tool_parser.prev_tool_call_arr[index].get(
                                "arguments", {}))

                        # get what we've streamed so far for arguments
                        # for the current tool
                        actual_call = tool_parser.streamed_args_for_tool[
                            index]

                        # check to see if there's anything left to stream
                        remaining_call = expected_call.replace(
                            actual_call, "", 1)

                        # set that as a delta message
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=index,
                                          function=DeltaFunctionCall(
                                              arguments=remaining_call).
                                          model_dump(exclude_none=True))
                        ])

                    # Send the finish response for each request.n only once
                    called_tools = bool(
                        tool_parser and len(tool_parser.prev_tool_call_arr))
                    chunk = make_chunk(
                        ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason="tool_calls"
                            if called_tools else output.finish_reason,
                            stop_reason=output.stop_reason))
                    set_chunk_usage(chunk, len(output.token_ids))
                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"
                    finish_reason_sent[i] = True

            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
            if include_usage:
                # Count the completion tokens of every choice (n > 1), not
                # just whichever choice happened to be processed last.
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            logger.error("error in chat completion stream generator: %s", e)
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
555
556
557
558
559
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
560
        tokenizer: AnyTokenizer,
561
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
562

563
        model_name = self.served_model_names[0]
564
        created_time = int(time.time())
565
        final_res: Optional[RequestOutput] = None
566

567
568
569
570
571
572
        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

573
574
        assert final_res is not None

575
        choices: List[ChatCompletionResponseChoice] = []
576

577
578
        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
579
            token_ids = output.token_ids
580
            out_logprobs = output.logprobs
581

582
583
            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
584
                logprobs = self._create_chat_logprobs(
585
                    token_ids=token_ids,
586
                    top_logprobs=out_logprobs,
587
                    num_output_top_logprobs=request.top_logprobs,
588
                    tokenizer=tokenizer,
589
590
591
592
                )
            else:
                logprobs = None

593
594
595
596
597
598
599
600
601
602
603
604
605
            # by default, tools are not used.
            tools_called = False

            # if auto tools are not enabled, and a named tool choice using
            #   outlines is not being used
            if not (self.enable_auto_tools
                    or not self.tool_parser) and not isinstance(
                        request.tool_choice,
                        ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
606
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
607

608
609
610
611
612
613
614
615
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
616
617
618
619
                tools_called = True

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
620
            elif not request.tool_choice or request.tool_choice == "none":
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648

                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:

                tool_parser = self.tool_parser(tokenizer)
                tool_call_info = tool_parser.extract_tool_calls(output.text)
                tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)

                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
649
650
                message = ChatMessage(role=role, content=output.text)

651
652
            choice_data = ChatCompletionResponseChoice(
                index=output.index,
653
                message=message,
654
                logprobs=logprobs,
655
656
                finish_reason="tool_calls" if tools_called else
                output.finish_reason if output.finish_reason else "stop",
657
                stop_reason=output.stop_reason)
658
659
660
661
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
662
663
            if conversation and conversation[-1].get(
                    "content") and conversation[-1].get("role") == role:
664
                last_msg_content = conversation[-1]["content"] or ""
665
666

            for choice in choices:
667
668
                full_message = last_msg_content + (choice.message.content
                                                   or "")
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
685
            prompt_logprobs=final_res.prompt_logprobs,
686
687
        )

688
        return response
689
690

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
        """Build the list of alternative-token logprob entries for one step.

        Takes the first ``top_logprobs`` entries of ``logprobs`` in the dict's
        iteration order. Each entry is rendered via ``_get_decoded_token``,
        and its logprob is floored at -9999.0. Returns an empty list when
        ``top_logprobs`` is None or 0.
        """
        entries: List[ChatCompletionLogProb] = []
        if not top_logprobs:
            return entries

        for rank, (token_id, step_logprob) in enumerate(logprobs.items()):
            # Entries past the requested count are never emitted, so stop
            # walking the mapping as soon as the limit is reached.
            if rank >= top_logprobs:
                break
            decoded = self._get_decoded_token(
                step_logprob,
                token_id,
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)
            entries.append(
                ChatCompletionLogProb(
                    token=decoded,
                    # Floor the value (presumably so -inf never reaches the
                    # serialized response).
                    logprob=max(step_logprob.logprob, -9999.0),
                    bytes=list(decoded.encode("utf-8", errors="replace"))))
        return entries

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        Args:
            token_ids: Generated token ids, one per position.
            top_logprobs: Per-position logprob mappings, indexed in lockstep
                with ``token_ids``. An entry may be None, in which case only
                the token text/bytes are reported for that position.
            tokenizer: Used to decode token ids into text.
            num_output_top_logprobs: Number of alternatives per position,
                forwarded to ``_get_top_logprobs``. None/0 yields no
                alternatives.

        Returns:
            ChatCompletionLogProbs with one content entry per token id.

        Raises:
            IndexError: presumably, if ``top_logprobs`` is shorter than
                ``token_ids`` (it is indexed by position).
            KeyError: if a non-None mapping lacks the sampled token id.
        """
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob data for this position: report only the token.
                # Decode only when the decoded text is what gets returned;
                # in token-id mode the decoded text was previously computed
                # and then discarded.
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                else:
                    token = tokenizer.decode(token_id)

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token

                # NOTE(review): ``bytes`` is derived from the raw
                # ``decoded_token``, not from the ``token`` string returned
                # below. It is None whenever decoded_token is None, and it may
                # differ from ``token`` in token-id mode. The branch above and
                # _get_top_logprobs both encode the returned token instead.
                # Confirm this asymmetry is intended.
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            return_as_token_id=self.return_tokens_as_token_ids,
                        ),
                        # Floor the value (presumably so -inf never reaches
                        # the serialized response).
                        logprob=max(step_token.logprob, -9999.0),
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)

    def _should_stream_with_auto_tool_parsing(
            self, request: ChatCompletionRequest) -> bool:
        """
        Return whether streamed tokens should go through the configured tool
        call parser.

        This is the case only when all of the following hold:
        - the request supplies tools;
        - a tool parser is configured;
        - "auto" tool choice is enabled;
        - the request's tool_choice is "auto" or unset (None).
        """
        # Wrap in bool() so callers always receive a real bool rather than
        # whichever operand short-circuited the chain (e.g. an empty tools
        # list or None). This matches
        # _should_check_for_unstreamed_tool_arg_tokens.
        return bool(request.tools and self.tool_parser
                    and self.enable_auto_tools
                    and request.tool_choice in ("auto", None))

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
        output: CompletionOutput,
    ) -> bool:
        """
        Decide whether to check for tool-argument tokens that were not yet
        streamed.

        All of the following must hold:
        - auto tool parsing is enabled and a tool parser is configured;
        - the delta's first tool call has a function with non-None arguments;
        - the output has a finish_reason, i.e. generation has finished.
        """
        # Auto tool parsing must be both switched on and configured.
        if not (self.enable_auto_tools and self.tool_parser):
            return False

        # The delta must carry at least one tool call...
        if not (delta_message and delta_message.tool_calls):
            return False
        first_call = delta_message.tool_calls[0]

        # ...whose function has arguments.
        if not (first_call and first_call.function):
            return False
        if first_call.function.arguments is None:
            return False

        # Finally, generation must have finished for this choice.
        return output.finish_reason is not None