serving_chat.py 39.6 KB
Newer Older
1
import asyncio
2
import json
3
import time
4
5
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
                    Optional)
6
from typing import Sequence as GenericSequence
7
from typing import Union
8

9
from fastapi import Request
10

11
from vllm.config import ModelConfig
12
from vllm.engine.async_llm_engine import AsyncLLMEngine
13
from vllm.engine.multiprocessing.client import MQLLMEngineClient
14
from vllm.engine.protocol import EngineClient
15
from vllm.entrypoints.chat_utils import (ConversationMessage,
16
17
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
18
                                         load_chat_template,
19
                                         parse_chat_messages_futures)
20
from vllm.entrypoints.logger import RequestLogger
21
from vllm.entrypoints.openai.protocol import (
22
23
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
24
    ChatCompletionRequest, ChatCompletionResponse,
25
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
26
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
27
28
    DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
    ToolCall, UsageInfo)
29
30
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                    LoRAModulePath,
31
                                                    OpenAIServing,
32
33
                                                    PromptAdapterPath,
                                                    TextTokensPrompt)
34
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
35
from vllm.inputs import TokensPrompt
36
from vllm.logger import init_logger
37
from vllm.outputs import CompletionOutput, RequestOutput
38
from vllm.sampling_params import BeamSearchParams, SamplingParams
39
from vllm.sequence import Logprob
40
41
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
42
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
43
from vllm.utils import iterate_with_cancellation, random_uuid
44
45
46
47
48
49

# Module-level logger for this serving module, obtained via vLLM's init_logger.
logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):

50
    def __init__(self,
                 engine_client: EngineClient,
                 model_config: ModelConfig,
                 base_model_paths: List[BaseModelPath],
                 response_role: str,
                 *,
                 lora_modules: Optional[List[LoRAModulePath]],
                 prompt_adapters: Optional[List[PromptAdapterPath]],
                 request_logger: Optional[RequestLogger],
                 chat_template: Optional[str],
                 return_tokens_as_token_ids: bool = False,
                 enable_auto_tools: bool = False,
                 tool_parser: Optional[str] = None):
        """Build the OpenAI-compatible chat completion handler.

        Args:
            engine_client: Engine client; forwarded to ``OpenAIServing``.
            model_config: Model configuration; forwarded to the base class.
            base_model_paths: Served model paths; forwarded to the base class.
            response_role: Role reported for generated messages when the
                request asks for a generation prompt.
            lora_modules: Forwarded unchanged to the base class.
            prompt_adapters: Forwarded unchanged to the base class.
            request_logger: Forwarded unchanged to the base class.
            chat_template: Server-wide chat template, resolved with
                ``load_chat_template``.
            return_tokens_as_token_ids: Forwarded unchanged to the base class.
            enable_auto_tools: Turn on "auto" tool choice support.
            tool_parser: Name of the tool parser to look up in
                ``ToolParserManager``; only consulted when
                ``enable_auto_tools`` is true.

        Raises:
            TypeError: ``enable_auto_tools`` is set but ``ToolParserManager``
                fails to return a parser for ``tool_parser``.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role
        self.use_tool_use_model_template = False
        self.chat_template = load_chat_template(chat_template)

        # Tool-use configuration: the parser factory stays None unless
        # "auto" tool choice has been switched on.
        self.enable_auto_tools: bool = enable_auto_tools
        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if not self.enable_auto_tools:
            return

        logger.info(
            "\"auto\" tool choice has been enabled please note that while"
            " the parallel_tool_calls client option is preset for "
            "compatibility reasons, it will be ignored.")
        try:
            self.tool_parser = ToolParserManager.get_tool_parser(tool_parser)
        except Exception as e:
            raise TypeError("Error: --enable-auto-tool-choice requires "
                            f"tool_parser:'{tool_parser}' which has not "
                            "been registered") from e
92

93
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        Args:
            request: The parsed chat completion request.
            raw_request: The underlying FastAPI request, if available. Used
                to attach request metadata, read trace headers, and stop
                iteration when the client disconnects.

        Returns:
            An async generator of SSE ``data:`` lines when ``request.stream``
            is set, otherwise a ``ChatCompletionResponse``. An
            ``ErrorResponse`` is returned for model-check failures,
            chat-template / multi-modal loading errors, unsupported tool
            choices, and ``ValueError``s raised while preparing or running
            the request.

        Raises:
            The engine client's ``dead_error`` when the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_config = self.model_config
            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, model_config, tokenizer)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt: Union[str, List[int]]
            is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
            if is_mistral_tokenizer:
                # Mistral tokenizers render the raw request messages
                # themselves (possibly returning token ids directly).
                prompt = apply_mistral_chat_template(
                    tokenizer,
                    messages=request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    tools=tool_dicts,
                    documents=request.documents,
                    **(request.chat_template_kwargs or {}),
                )
            else:
                prompt = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    tools=tool_dicts,
                    documents=request.documents,
                    **(request.chat_template_kwargs or {}),
                )
        except Exception as e:
            logger.exception("Error in applying chat template from request")
            return self.create_error_response(str(e))

        try:
            mm_data = await mm_data_future
        except Exception as e:
            logger.exception("Error in loading multi-modal data")
            return self.create_error_response(str(e))

        # validation for OpenAI tools
        # tool_choice = "required" is not supported
        if request.tool_choice == "required":
            return self.create_error_response(
                "tool_choice = \"required\" is not supported!")

        if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
                self.enable_auto_tools and self.tool_parser is not None):
            # for hf tokenizers, "auto" tools requires
            # --enable-auto-tool-choice and --tool-call-parser
            return self.create_error_response(
                "\"auto\" tool choice requires "
                "--enable-auto-tool-choice and --tool-call-parser to be set")

        request_id = f"chat-{random_uuid()}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            if self.enable_auto_tools and self.tool_parser:
                request = self.tool_parser(tokenizer).adjust_request(
                    request=request)

            if isinstance(prompt, str):
                prompt_inputs = self._tokenize_prompt_input(
                    request,
                    tokenizer,
                    prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                assert isinstance(prompt, list) and isinstance(
                    prompt[0], int
                ), "Prompt has to be either a string or a list of token ids"
                prompt_inputs = TextTokensPrompt(
                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

            assert prompt_inputs is not None

            # By default, allow generation to fill the rest of the context.
            sampling_params: Union[SamplingParams, BeamSearchParams]
            default_max_tokens = self.max_model_len - len(
                prompt_inputs["prompt_token_ids"])
            if request.use_beam_search:
                sampling_params = request.to_beam_search_params(
                    default_max_tokens)
            else:
                sampling_params = request.to_sampling_params(
                    default_max_tokens)

            self._log_inputs(request_id,
                             prompt_inputs,
                             params=sampling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            engine_inputs = TokensPrompt(
                prompt_token_ids=prompt_inputs["prompt_token_ids"])
            if mm_data is not None:
                engine_inputs["multi_modal_data"] = mm_data

            is_tracing_enabled = (await
                                  self.engine_client.is_tracing_enabled())
            trace_headers = None
            if is_tracing_enabled and raw_request:
                trace_headers = extract_trace_headers(raw_request.headers)
            if (not is_tracing_enabled and raw_request
                    and contains_trace_headers(raw_request.headers)):
                log_tracing_disabled_warning()

            if isinstance(sampling_params, BeamSearchParams):
                # Note the trailing space: the two literals are concatenated.
                assert isinstance(self.engine_client,
                                  (AsyncLLMEngine, MQLLMEngineClient)), (
                                      "Beam search is only supported with "
                                      "AsyncLLMEngine and MQLLMEngineClient.")
                result_generator = self.engine_client.beam_search(
                    engine_inputs['prompt_token_ids'],
                    request_id,
                    sampling_params,
                )
            else:
                result_generator = self.engine_client.generate(
                    engine_inputs,
                    sampling_params,
                    request_id,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    prompt_adapter_request=prompt_adapter_request,
                    priority=request.priority,
                )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer,
                request_metadata)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, conversation, tokenizer,
                request_metadata)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
281
282
283
284

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        """Return the role to report for the generated message.

        With ``add_generation_prompt`` the server's configured
        ``response_role`` is used; otherwise the role of the request's last
        message is reused (presumably because that message is being
        continued rather than answered).
        """
        if not request.add_generation_prompt:
            last_message = request.messages[-1]
            return last_message["role"]
        return self.response_role
286
287

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion as server-sent ``data:`` lines.

        Emits, in order: one role-announcing chunk per choice, optional echo
        chunks for the last input message, per-delta content / tool-call
        chunks, one finish chunk per choice, an optional final usage chunk,
        and finally ``data: [DONE]``.

        Args:
            request: The chat completion request being served.
            result_generator: Engine outputs; each ``output`` here carries
                only the newly generated text/token ids (a delta).
            request_id: Id reported in every chunk.
            conversation: Parsed conversation, used for echoing the last
                message.
            tokenizer: Tokenizer used for logprobs and tool parsers.
            request_metadata: Receives aggregate usage for the middleware.

        Yields:
            SSE-formatted strings. ``ValueError``s raised while streaming are
            reported as a streaming error chunk before ``[DONE]``.
        """
        model_name = self.base_model_paths[0].name
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        num_prompt_tokens = 0

        # Whether chunks carry a "usage" field at all, and whether that field
        # holds running totals (otherwise it is an explicit null and only the
        # final usage chunk reports numbers).
        include_usage = bool(request.stream_options
                             and request.stream_options.include_usage)
        include_continuous_usage = bool(
            include_usage and request.stream_options
            and request.stream_options.continuous_usage_stats)

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))

        all_previous_token_ids: Optional[List[List[int]]]
        if tool_choice_auto:
            # These are only required in "auto" tool choice case
            previous_texts = [""] * num_choices
            all_previous_token_ids = [[] for _ in range(num_choices)]
        else:
            previous_texts, all_previous_token_ids = None, None

        # Prepare the tool parser if it's needed. Build one parser per
        # choice: parser attributes such as prev_tool_call_arr and
        # streamed_args_for_tool are read below as per-choice state, so a
        # single shared instance would mix tool calls of different choices
        # when n > 1.
        try:
            if tool_choice_auto and self.tool_parser:
                tool_parsers: List[Optional[ToolParser]] = [
                    self.tool_parser(tokenizer) for _ in range(num_choices)
                ]
            else:
                tool_parsers = [None] * num_choices
        except RuntimeError as e:
            logger.error("Error in tool parser creation: %s", e)
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

        try:
            async for res in result_generator:
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # Setting usage (even to None) marks it as set, so it
                        # is serialized despite exclude_unset=True.
                        if include_usage:
                            if include_continuous_usage:
                                chunk.usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=0,
                                    total_tokens=num_prompt_tokens)
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo or request.continue_final_message:
                        last_msg_content: str = ""
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if include_usage:
                                    if include_continuous_usage:
                                        chunk.usage = UsageInfo(
                                            prompt_tokens=num_prompt_tokens,
                                            completion_tokens=0,
                                            total_tokens=num_prompt_tokens)
                                    else:
                                        chunk.usage = None

                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index
                    tool_parser = tool_parsers[i]

                    if finish_reason_sent[i]:
                        continue

                    if request.logprobs and request.top_logprobs is not None:
                        assert output.logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text
                    delta_message: Optional[DeltaMessage]

                    # handle streaming deltas for tools with named tool_choice
                    if tool_choice_function_name:
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(function=DeltaFunctionCall(
                                name=tool_choice_function_name,
                                arguments=delta_text),
                                          index=i)
                        ])

                    # handle streaming deltas for tools with "auto" tool choice
                    elif tool_choice_auto:
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        assert tool_parser is not None
                        #TODO optimize manipulation of these lists
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_text,
                                current_text=current_text,
                                delta_text=delta_text,
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
                                delta_token_ids=output.token_ids,
                                request=request))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids

                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # Running count of generated tokens for this choice;
                    # output.token_ids only holds this step's delta.
                    previous_num_tokens[i] += len(output.token_ids)

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token, then
                    #   get the next token without streaming a chunk
                    if delta_message is None:
                        continue

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # handle usage stats if requested & if continuous;
                        # report cumulative completion tokens, not the delta
                        if include_usage:
                            if include_continuous_usage:
                                completion_tokens = previous_num_tokens[i]
                                chunk.usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=num_prompt_tokens +
                                    completion_tokens,
                                )
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        #   any tokens that were generated but previously
                        #   matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        if tool_parser:
                            index = len(
                                tool_parser.prev_tool_call_arr) - 1 if len(
                                    tool_parser.prev_tool_call_arr) > 0 else 0
                        else:
                            index = 0

                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}))

                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)

                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason
                            if not (tool_parser
                                    and len(tool_parser.prev_tool_call_arr))
                            else "tool_calls",
                            stop_reason=output.stop_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if include_usage:
                            if include_continuous_usage:
                                completion_tokens = previous_num_tokens[i]
                                chunk.usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=num_prompt_tokens +
                                    completion_tokens,
                                )
                            else:
                                chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage aggregated over all choices (never rely
            # on the loop variable ``i``, which only names the last choice
            # and is unbound when no output was produced)
            if include_usage:
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
                total_tokens=num_prompt_tokens + num_completion_tokens)

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            logger.error("error in chat completion stream generator: %s", e)
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Drain the engine output and build a non-streaming chat response.

        Consumes ``result_generator`` to completion, keeping only the last
        ``RequestOutput`` it yields. Each completion in that final output is
        converted into a ``ChatCompletionResponseChoice``. Logprobs are
        attached when requested, and tool calls are extracted according to
        ``request.tool_choice``.

        Args:
            request: The incoming chat completion request.
            result_generator: Engine outputs for this request. The last item
                yielded is treated as the final result.
            request_id: ID echoed back in the response.
            conversation: Parsed conversation. When ``echo`` or
                ``continue_final_message`` is set and the last message has
                the same role as the response, its content is prepended to
                every choice.
            tokenizer: Tokenizer used for logprob decoding and tool parsing.
            request_metadata: Receives the final ``UsageInfo`` so that
                middleware can report usage.

        Returns:
            A ``ChatCompletionResponse``. Returns an ``ErrorResponse`` instead
            if the generator is cancelled (client disconnected) or if the
            tool parser cannot be constructed.
        """
        model_name = self.base_model_paths[0].name
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        # One isinstance() check, shared by both branches below, so they
        # agree on what counts as a named tool choice.
        named_tool_choice = isinstance(request.tool_choice,
                                       ChatCompletionNamedToolChoiceParam)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            # by default, tools are not used.
            tools_called = False

            # if auto tools are not enabled, and a named tool choice using
            #   outlines is not being used
            if (not self.enable_auto_tools
                    or not self.tool_parser) and not named_tool_choice:
                message = ChatMessage(role=role, content=output.text)

            # the request names a specific tool: the generated text is used
            # verbatim as that function's arguments
            elif named_tool_choice:
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
                tools_called = True

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            elif (request.tools
                  and (request.tool_choice == "auto"
                       or request.tool_choice is None)
                  and self.enable_auto_tools and self.tool_parser):
                try:
                    tool_parser = self.tool_parser(tokenizer)
                except RuntimeError as e:
                    logger.error("Error in tool parser creation: %s", e)
                    return self.create_error_response(str(e))

                tool_call_info = tool_parser.extract_tool_calls(
                    output.text, request=request)
                tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)
                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role, content=output.text)

            # A tool call overrides the engine's finish reason. Without a
            # tool call, fall back to "stop" when no reason was reported.
            if tools_called:
                finish_reason = "tool_calls"
            elif output.finish_reason:
                finish_reason = output.finish_reason
            else:
                finish_reason = "stop"

            choices.append(
                ChatCompletionResponseChoice(
                    index=output.index,
                    message=message,
                    logprobs=logprobs,
                    finish_reason=finish_reason,
                    stop_reason=output.stop_reason))

        if request.echo or request.continue_final_message:
            # Prepend the last conversation message to every choice, but only
            # if it has the same role as the response.
            last_msg_content = ""
            if conversation and "content" in conversation[-1] and conversation[
                    -1].get("role") == role:
                last_msg_content = conversation[-1]["content"] or ""

            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        # Encoder prompt tokens, when present, also count toward the prompt
        # (presumably for encoder-decoder models).
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        # report to FastAPI middleware aggregate usage across all choices
        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
        """Convert the leading entries of ``logprobs`` to API objects.

        At most ``top_logprobs`` entries are converted, taken in the dict's
        iteration order. For each entry:

        - the token text comes from ``_get_decoded_token``, which honours
          ``self.return_tokens_as_token_ids``;
        - the logprob is clamped so it is never below -9999.0;
        - ``bytes`` is the UTF-8 encoding of the token text.

        A falsy or negative ``top_logprobs`` produces an empty list.
        """
        entries: List[ChatCompletionLogProb] = []
        if not top_logprobs:
            return entries

        for rank, (token_id, candidate) in enumerate(logprobs.items()):
            if rank >= top_logprobs:
                break
            decoded = self._get_decoded_token(
                candidate,
                token_id,
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)
            entries.append(
                ChatCompletionLogProb(
                    token=decoded,
                    logprob=max(candidate.logprob, -9999.0),
                    bytes=list(decoded.encode("utf-8", errors="replace"))))
        return entries

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        Args:
            token_ids: Generated token IDs, one per step.
            top_logprobs: Per-step logprob dicts, indexed in step with
                ``token_ids``. A ``None`` entry is reported with only the
                token text and its bytes (no logprob, no alternatives).
            tokenizer: Tokenizer used to decode token text.
            num_output_top_logprobs: Maximum number of alternatives to report
                per step (forwarded to ``_get_top_logprobs``).

        Returns:
            ``ChatCompletionLogProbs`` with one content entry per token.
        """
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # Only run the tokenizer when the decoded text is actually
                # reported; the token-id form never needs it.
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                else:
                    token = tokenizer.decode(token_id)

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            self.return_tokens_as_token_ids,
                        ),
                        # Clamp so the reported value is never below -9999.0.
                        logprob=max(step_token.logprob, -9999.0),
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)

    def _should_stream_with_auto_tool_parsing(
            self, request: ChatCompletionRequest) -> bool:
        """
        Utility function to check if streamed tokens should go through the tool
        call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used
        (``"auto"`` or unset/``None``).

        Returns:
            A real ``bool``. The bare ``and`` chain would otherwise leak
            whichever operand was falsy or last (a list, the parser class,
            ...).
        """
        return bool(request.tools and self.tool_parser
                    and self.enable_auto_tools
                    and request.tool_choice in ['auto', None])

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
        output: CompletionOutput,
    ) -> bool:
        """
        Decide whether to check for tool-argument tokens not yet streamed.

        Returns True only when all of the following hold:

        - the output has finished (``output.finish_reason`` is set);
        - auto tool parsing is enabled and a tool parser is configured;
        - ``delta_message`` carries a first tool call with a function;
        - that function's ``arguments`` is not ``None``.
        """
        # yapf: disable
        if output.finish_reason is None:
            return False
        if not (self.enable_auto_tools and self.tool_parser):
            return False
        if not delta_message:
            return False

        calls = delta_message.tool_calls
        if not calls or not calls[0]:
            return False

        first_function = calls[0].function
        if not first_function:
            return False
        return first_function.arguments is not None