serving_chat.py 35.4 KB
Newer Older
1
import asyncio
2
import json
3
import time
4
5
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
                    Optional)
6
from typing import Sequence as GenericSequence
7
from typing import Union
8

9
from fastapi import Request
10

11
from vllm.config import ModelConfig
12
from vllm.engine.protocol import AsyncEngineClient
13
from vllm.entrypoints.chat_utils import (ConversationMessage,
14
15
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
16
                                         load_chat_template,
17
                                         parse_chat_messages_futures)
18
from vllm.entrypoints.logger import RequestLogger
19
from vllm.entrypoints.openai.protocol import (
20
21
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
22
    ChatCompletionRequest, ChatCompletionResponse,
23
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
24
25
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
    DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
26
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
27
                                                    OpenAIServing,
28
29
                                                    PromptAdapterPath,
                                                    TextTokensPrompt)
30
31
32
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
                                                  MistralToolParser,
                                                  ToolParser)
33
from vllm.inputs import TokensPrompt
34
from vllm.logger import init_logger
35
from vllm.outputs import CompletionOutput, RequestOutput
36
from vllm.sequence import Logprob
37
38
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
39
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
40
from vllm.utils import iterate_with_cancellation, random_uuid
41
42
43
44
45
46

logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):

47
48
49
50
51
52
53
54
55
56
57
58
59
    def __init__(self,
                 async_engine_client: AsyncEngineClient,
                 model_config: ModelConfig,
                 served_model_names: List[str],
                 response_role: str,
                 *,
                 lora_modules: Optional[List[LoRAModulePath]],
                 prompt_adapters: Optional[List[PromptAdapterPath]],
                 request_logger: Optional[RequestLogger],
                 chat_template: Optional[str],
                 return_tokens_as_token_ids: bool = False,
                 enable_auto_tools: bool = False,
                 tool_parser: Optional[str] = None):
        """Configure the OpenAI-compatible chat-completion handler.

        Args:
            async_engine_client: Engine client used for generation.
            model_config: Model configuration forwarded to the base class.
            served_model_names: Model names; the first is reported in
                responses.
            response_role: Role assigned to generated messages when the
                request sets ``add_generation_prompt``.
            lora_modules: Optional LoRA adapters.
            prompt_adapters: Optional prompt adapters.
            request_logger: Optional request logger.
            chat_template: Chat template (or path), loaded via
                ``load_chat_template``.
            return_tokens_as_token_ids: Forwarded to the base class.
            enable_auto_tools: Allow ``tool_choice="auto"``; requires a
                supported ``tool_parser``.
            tool_parser: Tool-call parser name, ``"mistral"`` or ``"hermes"``.

        Raises:
            TypeError: If ``enable_auto_tools`` is set and ``tool_parser`` is
                missing or not a supported parser name.
        """
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role
        self.use_tool_use_model_template = False
        self.chat_template = load_chat_template(chat_template)

        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if not self.enable_auto_tools:
            return

        logger.info(
            "\"auto\" tool choice has been enabled please note that while"
            " the parallel_tool_calls client option is preset for "
            "compatibility reasons, it will be ignored.")

        supported_parsers: Dict[str, Callable[[AnyTokenizer],
                                              ToolParser]] = {
            "mistral": MistralToolParser,
            "hermes": Hermes2ProToolParser,
        }
        if tool_parser is None:
            raise TypeError("Error: --enable-auto-tool-choice requires "
                            "--tool-call-parser")
        if tool_parser not in supported_parsers:
            # Previously an unknown name produced the "requires" message
            # above, which wrongly suggested the flag was missing.
            raise TypeError(
                f"Error: unsupported --tool-call-parser {tool_parser!r}; "
                f"expected one of {sorted(supported_parsers)}")
        self.tool_parser = supported_parsers[tool_parser]

90
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        Args:
            request: The parsed chat-completion request.
            raw_request: The underlying FastAPI request. It is used to read
                trace headers and to stop iterating results once the client
                disconnects.

        Returns:
            An async generator of SSE ``data:`` lines when ``request.stream``
            is set, otherwise a ``ChatCompletionResponse``. Model-check,
            template, multi-modal loading, tool-choice validation and
            ``ValueError`` failures are returned as an ``ErrorResponse``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_config = self.model_config
            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, model_config, tokenizer)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            # Mistral tokenizers may render straight to token ids; HF
            # templates render to text.
            prompt: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt = apply_mistral_chat_template(
                    tokenizer,
                    messages=request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    tools=tool_dicts,
                    documents=request.documents,
                    **(request.chat_template_kwargs or {}),
                )
            else:
                prompt = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    tools=tool_dicts,
                    documents=request.documents,
                    **(request.chat_template_kwargs or {}),
                )
        except Exception as e:
            logger.error("Error in applying chat template from request: %s", e)
            return self.create_error_response(str(e))

        try:
            mm_data = await mm_data_future
        except Exception as e:
            logger.error("Error in loading multi-modal data: %s", e)
            return self.create_error_response(str(e))

        # validation for OpenAI tools
        # tool_choice = "required" is not supported
        if request.tool_choice == "required":
            return self.create_error_response(
                "tool_choice = \"required\" is not supported!")

        # "auto" tools requires --enable-auto-tool-choice
        # and --tool-call-parser
        if request.tool_choice == "auto" and not (
                self.enable_auto_tools and self.tool_parser is not None):
            return self.create_error_response(
                "\"auto\" tool choice requires "
                "--enable-auto-tool-choice and --tool-call-parser to be set")

        request_id = f"chat-{random_uuid()}"
        try:
            guided_decode_logits_processor = (
                await self._guided_decode_logits_processor(request, tokenizer))

            if isinstance(prompt, str):
                prompt_inputs = self._tokenize_prompt_input(
                    request,
                    tokenizer,
                    prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                # Validate explicitly rather than with ``assert`` (stripped
                # under ``python -O``). An empty list used to raise an
                # uncaught IndexError on ``prompt[0]``; a ValueError is
                # turned into an error response by the handler below.
                if not (isinstance(prompt, list) and prompt
                        and isinstance(prompt[0], int)):
                    raise ValueError(
                        "Prompt has to be either a string or a list of "
                        "token ids")
                prompt_inputs = TextTokensPrompt(
                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

            assert prompt_inputs is not None

            # By default, allow generation up to the model's context length.
            sampling_params = request.to_sampling_params(
                tokenizer,
                guided_decode_logits_processor,
                default_max_tokens=self.max_model_len -
                len(prompt_inputs["prompt_token_ids"]))

            self._log_inputs(request_id,
                             prompt_inputs,
                             params=sampling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            engine_inputs = TokensPrompt(
                prompt_token_ids=prompt_inputs["prompt_token_ids"])
            if mm_data is not None:
                engine_inputs["multi_modal_data"] = mm_data

            is_tracing_enabled = (
                await self.async_engine_client.is_tracing_enabled())
            trace_headers = None
            if is_tracing_enabled and raw_request:
                trace_headers = extract_trace_headers(raw_request.headers)
            if (not is_tracing_enabled and raw_request
                    and contains_trace_headers(raw_request.headers)):
                log_tracing_disabled_warning()

            result_generator = self.async_engine_client.generate(
                engine_inputs,
                sampling_params,
                request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
            )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # Stop consuming engine output once the HTTP client goes away.
        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, conversation, tokenizer)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
245
246
247
248
249

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        else:
250
            return request.messages[-1]["role"]
251
252

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion as Server-Sent Events.

        Emits, in order:
          - one role chunk per choice;
          - optionally an echo of the last message's content
            (``request.echo``);
          - one chunk per new delta;
          - one finish chunk per choice;
          - an optional aggregate usage chunk
            (``stream_options.include_usage``);
          - finally ``data: [DONE]``.

        A ``ValueError`` raised while streaming is sent as an error chunk
        before ``[DONE]``.

        Args:
            request: The chat-completion request being served.
            result_generator: Engine outputs. Each ``RequestOutput`` carries
                the text/token ids generated so far for every choice; deltas
                are computed against what was already sent.
            request_id: Id reported in every chunk.
            conversation: Parsed conversation, used for ``echo``.
            tokenizer: Tokenizer for logprob decoding and tool parsing.

        Yields:
            SSE-formatted ``data:`` lines.
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_texts = [""] * num_choices
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        # Prompt length, recorded from the first engine output. It is kept
        # outside the loop so the final usage chunk never depends on
        # variables bound only inside a particular branch.
        num_prompt_tokens = 0

        stream_options = request.stream_options
        include_usage = bool(stream_options and stream_options.include_usage)
        continuous_usage = bool(include_usage and stream_options
                                and stream_options.continuous_usage_stats)

        def _attach_usage(chunk: ChatCompletionStreamResponse,
                          completion_tokens: int) -> None:
            # Only touch ``usage`` when the client asked for usage. With
            # ``exclude_unset=True`` an explicitly assigned None is emitted
            # as ``"usage": null``, whereas an untouched field is omitted.
            if not include_usage:
                return
            if continuous_usage:
                chunk.usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )
            else:
                chunk.usage = None

        tool_parser: Optional[ToolParser] = self.tool_parser(
            tokenizer) if self.tool_parser else None

        try:
            async for res in result_generator:
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    num_prompt_tokens = len(res.prompt_token_ids)

                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        _attach_usage(chunk, 0)
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content: str = ""
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                _attach_usage(chunk, 0)
                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    # Everything before previous_num_tokens[i] was already
                    # streamed for this choice.
                    num_sent_tokens = previous_num_tokens[i]
                    delta_token_ids = output.token_ids[num_sent_tokens:]
                    out_logprobs = output.logprobs[
                        num_sent_tokens:] if output.logprobs else None

                    if request.logprobs and request.top_logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    delta_message: Optional[DeltaMessage] = None

                    # handle streaming deltas for tools with named tool_choice
                    if (request.tool_choice and type(request.tool_choice) is
                            ChatCompletionNamedToolChoiceParam):
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(function=DeltaFunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text),
                                          index=i)
                        ])

                    # handle streaming deltas for tools with "auto" tool choice
                    elif (self._should_stream_with_auto_tool_parsing(request)
                          and tool_parser):
                        # Slice by the already-sent count: the previous
                        # `[:-len(delta_token_ids)]` form yielded an empty
                        # list whenever the delta was empty (`[:-0]`).
                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_texts[i],
                                current_text=output.text,
                                delta_text=delta_text,
                                previous_token_ids=output.token_ids[
                                    :num_sent_tokens],
                                current_token_ids=output.token_ids,
                                delta_token_ids=delta_token_ids,
                            ))

                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # set the previous values for the next iteration
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token), get the next token
                    # without streaming a chunk -- unless this is the final
                    # output, in which case the finish chunk must still go
                    # out or the client never receives a finish_reason.
                    if delta_message is None:
                        if output.finish_reason is None:
                            continue
                        delta_message = DeltaMessage()

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        _attach_usage(chunk, len(output.token_ids))
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        #   any tokens that were generated but previously
                        #   matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        if (tool_parser
                                and len(tool_parser.prev_tool_call_arr) > 0):
                            index = len(tool_parser.prev_tool_call_arr) - 1
                        else:
                            index = 0

                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}))

                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)

                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

                        # Any tool call seen by the parser turns the finish
                        # reason into "tool_calls".
                        if (tool_parser
                                and len(tool_parser.prev_tool_call_arr)):
                            finish_reason = "tool_calls"
                        else:
                            finish_reason = output.finish_reason

                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=finish_reason,
                            stop_reason=output.stop_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        _attach_usage(chunk, len(output.token_ids))
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
            if include_usage:
                # Aggregate over every choice. The previous code read the
                # leaked loop variable `i` and so counted only the last
                # choice for n > 1, and raised NameError when no output was
                # produced.
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            logger.error("error in chat completion stream generator: %s", e)
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"

        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
567
568
569
570
571
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
572
        tokenizer: AnyTokenizer,
573
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
574

575
        model_name = self.served_model_names[0]
576
        created_time = int(time.time())
577
        final_res: Optional[RequestOutput] = None
578

579
580
581
582
583
584
        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

585
586
        assert final_res is not None

587
        choices: List[ChatCompletionResponseChoice] = []
588

589
590
        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
591
            token_ids = output.token_ids
592
            out_logprobs = output.logprobs
593

594
595
            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
596
                logprobs = self._create_chat_logprobs(
597
                    token_ids=token_ids,
598
                    top_logprobs=out_logprobs,
599
                    num_output_top_logprobs=request.top_logprobs,
600
                    tokenizer=tokenizer,
601
602
603
604
                )
            else:
                logprobs = None

605
606
607
608
609
610
611
612
613
614
615
616
617
            # by default, tools are not used.
            tools_called = False

            # if auto tools are not enabled, and a named tool choice using
            #   outlines is not being used
            if not (self.enable_auto_tools
                    or not self.tool_parser) and not isinstance(
                        request.tool_choice,
                        ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
618
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
619

620
621
622
623
624
625
626
627
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
628
629
630
631
                tools_called = True

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
632
            elif not request.tool_choice or request.tool_choice == "none":
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660

                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:

                tool_parser = self.tool_parser(tokenizer)
                tool_call_info = tool_parser.extract_tool_calls(output.text)
                tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)

                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
661
662
                message = ChatMessage(role=role, content=output.text)

663
664
            choice_data = ChatCompletionResponseChoice(
                index=output.index,
665
                message=message,
666
                logprobs=logprobs,
667
668
                finish_reason="tool_calls" if tools_called else
                output.finish_reason if output.finish_reason else "stop",
669
                stop_reason=output.stop_reason)
670
671
672
673
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
674
675
            if conversation and "content" in conversation[-1] and conversation[
                    -1].get("role") == role:
676
                last_msg_content = conversation[-1]["content"] or ""
677
678

            for choice in choices:
679
680
                full_message = last_msg_content + (choice.message.content
                                                   or "")
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
697
            prompt_logprobs=final_res.prompt_logprobs,
698
699
        )

700
        return response
701
702

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
        """Build the ``top_logprobs`` entries for a single output position.

        Args:
            logprobs: Mapping from candidate token id to its ``Logprob``.
                Entries are taken in the dict's iteration order.
            top_logprobs: Maximum number of entries to return. ``None``, 0
                or a negative value yields an empty list.
            tokenizer: Tokenizer forwarded to ``_get_decoded_token`` to
                render each candidate's text.

        Returns:
            At most ``top_logprobs`` ``ChatCompletionLogProb`` objects, one
            for each leading entry of ``logprobs``.
        """
        if not top_logprobs:
            return []

        result: List[ChatCompletionLogProb] = []
        for rank, (token_id, candidate) in enumerate(logprobs.items()):
            # Only the leading ``top_logprobs`` entries are wanted; stop
            # instead of iterating over (and discarding) the remainder.
            if rank >= top_logprobs:
                break
            token = self._get_decoded_token(
                candidate,
                token_id,
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)
            result.append(
                ChatCompletionLogProb(
                    token=token,
                    # Clamp to a finite floor (presumably so that -inf never
                    # reaches the serialized response).
                    logprob=max(candidate.logprob, -9999.0),
                    bytes=list(token.encode("utf-8", errors="replace"))))
        return result

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        Args:
            token_ids: Generated token ids, one per output position.
            top_logprobs: Per-position candidate logprob dicts, indexed in
                step with ``token_ids``. A ``None`` entry means no logprob
                data is available for that position.
            tokenizer: Tokenizer used to render token text.
            num_output_top_logprobs: How many alternatives to attach to each
                position's ``top_logprobs`` list (forwarded to
                ``_get_top_logprobs``).

        Returns:
            A ``ChatCompletionLogProbs`` with one content entry per token id.

        Raises:
            IndexError: If ``top_logprobs`` is shorter than ``token_ids``.
            KeyError: If a position's candidate dict lacks the sampled token.
        """
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob data for this position: emit only the token text
                # and its bytes. Decode only when the decoded text is used.
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                else:
                    token = tokenizer.decode(token_id)

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token

                # NOTE(review): with return_tokens_as_token_ids, ``token`` is
                # presumably the id string while ``bytes`` still encodes the
                # decoded text -- confirm this differs intentionally from the
                # branch above.
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            self.return_tokens_as_token_ids,
                        ),
                        # Clamp to a finite floor.
                        logprob=max(step_token.logprob, -9999.0),
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)

    def _should_stream_with_auto_tool_parsing(
            self, request: ChatCompletionRequest) -> bool:
        """
        Utility function to check if streamed tokens should go through the tool
        call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used.

        Returns:
            ``True`` only when all four conditions hold. The result is
            always a real ``bool``, never one of the operands of the check.
        """
        # Wrap in bool() so callers get True/False rather than whichever
        # operand the `and` chain stopped on (a tools list, None, ...).
        return bool(request.tools and self.tool_parser
                    and self.enable_auto_tools
                    and request.tool_choice in ['auto', None])

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
        output: CompletionOutput,
    ) -> bool:
        """
        Decide whether to look for tool-argument tokens not yet streamed.

        This returns True only when all of the following hold:
        auto tool parsing is enabled and a parser is configured; the delta
        carries a first tool call whose function has non-None arguments;
        and the output has a finish reason set.
        """
        # Auto tool parsing must be switched on with a parser available.
        if not (self.enable_auto_tools and self.tool_parser):
            return False

        # There must be a delta message that carries tool calls.
        if not delta_message or not delta_message.tool_calls:
            return False

        # The first tool call must name a function with arguments present.
        first_call = delta_message.tool_calls[0]
        if not first_call or not first_call.function:
            return False
        if first_call.function.arguments is None:
            return False

        # Only relevant once this output has finished generating.
        return output.finish_reason is not None