serving_chat.py 25 KB
Newer Older
1
import asyncio
2
import time
3
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
4
from typing import Sequence as GenericSequence
5
from typing import Union
6

7
from fastapi import Request
8
from transformers import PreTrainedTokenizer
9

10
from vllm.config import ModelConfig
11
from vllm.engine.protocol import AsyncEngineClient
12
from vllm.entrypoints.chat_utils import (ConversationMessage,
13
                                         apply_chat_template,
14
                                         load_chat_template,
15
                                         parse_chat_messages)
16
from vllm.entrypoints.logger import RequestLogger
17
from vllm.entrypoints.openai.protocol import (
18
19
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
20
    ChatCompletionRequest, ChatCompletionResponse,
21
22
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
23
    FunctionCall, ToolCall, UsageInfo)
24
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
25
26
                                                    OpenAIServing,
                                                    PromptAdapterPath)
27
from vllm.inputs import PromptInputs
28
from vllm.logger import init_logger
29
from vllm.multimodal import MultiModalDataDict
30
from vllm.outputs import RequestOutput
31
from vllm.sequence import Logprob
32
33
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
34
from vllm.utils import iterate_with_cancellation, random_uuid
35
36
37
38
39
40

# Module-level logger, obtained from vLLM's ``init_logger`` for this module.
logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):

41
42
    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        response_role: str,
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        return_tokens_as_token_ids: bool = False,
    ):
        """Set up the chat-completion serving handler.

        Everything except ``response_role`` and ``chat_template`` is
        forwarded unchanged to the ``OpenAIServing`` base initializer.

        Args:
            async_engine_client: Engine client handed to the base class.
            model_config: Model configuration handed to the base class.
            served_model_names: Model names handed to the base class.
            response_role: Role stored on the instance and returned by
                ``get_chat_request_role`` when a generation prompt is added.
            lora_modules: LoRA module paths handed to the base class.
            prompt_adapters: Prompt adapter paths handed to the base class.
            request_logger: Request logger handed to the base class.
            chat_template: Value passed to ``load_chat_template``; the result
                is stored as ``self.chat_template``.
            return_tokens_as_token_ids: Flag handed to the base class.
        """
        base_kwargs = dict(
            async_engine_client=async_engine_client,
            model_config=model_config,
            served_model_names=served_model_names,
            lora_modules=lora_modules,
            prompt_adapters=prompt_adapters,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
        )
        super().__init__(**base_kwargs)

        self.response_role = response_role

        # When this ends up as None, the tokenizer's default chat template
        # is used instead.
        loaded_template = load_chat_template(chat_template)
        self.chat_template = loaded_template
66

67
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None
    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
               ChatCompletionResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        NOTE: Currently we do not support the following feature:
            - function_call (Users should implement this by themselves)

        Args:
            request: The parsed chat completion request.
            raw_request: The underlying HTTP request, if any. When given it
                is used to extract tracing headers and to cancel generation
                once the client disconnects.

        Returns:
            An ``ErrorResponse`` when the model check fails, when
            ``prompt_logprobs`` is invalid, or when templating, multi-modal
            loading or request preparation raises. Otherwise the streaming
            generator when ``request.stream`` is set, or else the complete
            ``ChatCompletionResponse``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        if request.prompt_logprobs is not None:
            if request.stream and request.prompt_logprobs > 0:
                return self.create_error_response(
                    "Prompt_logprobs are not available when stream is enabled")

            if request.prompt_logprobs < 0:
                return self.create_error_response(
                    f"Prompt_logprobs set to invalid "
                    f"negative value: {request.prompt_logprobs}")

        # Step 1: render the conversation into a prompt string.
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_config = self.model_config
            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            conversation, mm_futures = parse_chat_messages(
                request.messages, model_config, tokenizer)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt = apply_chat_template(
                tokenizer,
                conversation=conversation,
                # A per-request template takes precedence over the server's.
                chat_template=request.chat_template or self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                **(request.chat_template_kwargs or {}),
            )
        except Exception as e:
            logger.error("Error in applying chat template from request: %s", e)
            return self.create_error_response(str(e))

        # Step 2: resolve multi-modal data (at most one item is supported).
        mm_data: Optional[MultiModalDataDict] = None
        try:
            if mm_futures:
                # Raise explicitly instead of using ``assert`` so the check
                # still runs when Python is started with -O. The message is
                # unchanged, so the error response stays the same.
                if len(mm_futures) != 1:
                    raise ValueError(
                        "Multiple 'image_url' input is currently not "
                        "supported.")
                mm_data = await mm_futures[0]
        except Exception as e:
            logger.error("Error in loading multi-modal data: %s", e)
            return self.create_error_response(str(e))

        # Step 3: tokenize, build sampling params and submit to the engine.
        request_id = f"chat-{random_uuid()}"
        try:
            guided_decode_logits_processor = (
                await self._guided_decode_logits_processor(request, tokenizer))

            prompt_inputs = self._tokenize_prompt_input(
                request,
                tokenizer,
                prompt,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )

            # By default the completion may use whatever context is left
            # after the prompt.
            sampling_params = request.to_sampling_params(
                tokenizer,
                guided_decode_logits_processor,
                default_max_tokens=self.max_model_len -
                len(prompt_inputs["prompt_token_ids"]))

            self._log_inputs(request_id,
                             prompt_inputs,
                             params=sampling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            engine_inputs: PromptInputs = {
                "prompt_token_ids": prompt_inputs["prompt_token_ids"],
            }
            if mm_data is not None:
                engine_inputs["multi_modal_data"] = mm_data

            is_tracing_enabled = (
                await self.async_engine_client.is_tracing_enabled())
            trace_headers = None
            if is_tracing_enabled and raw_request:
                trace_headers = extract_trace_headers(raw_request.headers)
            if (not is_tracing_enabled and raw_request
                    and contains_trace_headers(raw_request.headers)):
                log_tracing_disabled_warning()

            result_generator = self.async_engine_client.generate(
                engine_inputs,
                sampling_params,
                request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
            )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # Stop consuming engine output once the client has disconnected.
        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, conversation, tokenizer)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
204
205
206
207
208

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        """Return the role to attach to the generated message.

        If the request does not add a generation prompt, the role of the
        last request message is returned; otherwise the configured
        ``response_role`` is returned.
        """
        if not request.add_generation_prompt:
            last_message = request.messages[-1]
            return last_message["role"]
        return self.response_role
210
211

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: PreTrainedTokenizer,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion as ``data: ...`` event strings.

        Emission order:
          1. one role-only chunk per choice (on the first engine result),
          2. if ``request.echo`` is set and the last conversation message has
             the response role, one chunk per choice echoing its content,
          3. one delta chunk per output per engine result, the last of which
             carries ``finish_reason`` / ``stop_reason``,
          4. if usage reporting is requested, one chunk with empty
             ``choices`` and the total usage,
          5. the terminating ``data: [DONE]`` event.

        A ``ValueError`` raised while streaming is turned into a streaming
        error response instead of propagating; ``[DONE]`` is still sent.

        Args:
            request: The chat completion request being served.
            result_generator: Async iterator of engine outputs.
            request_id: Identifier placed in every chunk.
            conversation: Parsed conversation, used only for ``echo``.
            tokenizer: Tokenizer used when building logprob entries.

        Yields:
            Event strings of the form ``"data: <json>\\n\\n"``.
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_texts = [""] * num_choices
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices

        # Request-level settings that do not change while streaming.
        include_usage = bool(request.stream_options
                             and request.stream_options.include_usage)
        continuous_usage = bool(
            include_usage and request.stream_options.continuous_usage_stats)
        named_tool_choice = isinstance(request.tool_choice,
                                       ChatCompletionNamedToolChoiceParam)

        # Bound up front so the final usage chunk is well-defined even if
        # the engine yields no results at all.
        prompt_tokens = 0

        def render_chunk(choice_data: ChatCompletionResponseStreamChoice,
                         res: RequestOutput, completion_tokens: int) -> str:
            """Wrap one choice into a stream chunk and serialize it."""
            chunk = ChatCompletionStreamResponse(id=request_id,
                                                 object=chunk_object_type,
                                                 created=created_time,
                                                 choices=[choice_data],
                                                 model=model_name)
            if include_usage:
                if continuous_usage:
                    num_prompt_tokens = len(res.prompt_token_ids)
                    chunk.usage = UsageInfo(
                        prompt_tokens=num_prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=num_prompt_tokens + completion_tokens,
                    )
                else:
                    # Explicit assignment of None is kept from the original.
                    # NOTE(review): presumably this marks the field as set
                    # so it survives exclude_unset as "usage": null; confirm.
                    chunk.usage = None
            data = chunk.model_dump_json(exclude_unset=True)
            return f"data: {data}\n\n"

        try:
            async for res in result_generator:
                prompt_tokens = len(res.prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(role=role),
                            logprobs=None,
                            finish_reason=None)
                        yield render_chunk(choice_data, res, 0)

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                yield render_chunk(choice_data, res, 0)
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    # Only the part generated since the previous result.
                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
                    out_logprobs = output.logprobs[
                        previous_num_tokens[i]:] if output.logprobs else None

                    if request.logprobs and request.top_logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)

                    if named_tool_choice:
                        # With a named tool choice the generated text is
                        # streamed as the arguments of that function.
                        delta_message = DeltaMessage(tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text))
                        ])
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    finished = output.finish_reason is not None
                    if finished:
                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                            stop_reason=output.stop_reason)
                    else:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)

                    yield render_chunk(choice_data, res,
                                       len(output.token_ids))
                    if finished:
                        finish_reason_sent[i] = True

            if include_usage:
                # Total over every choice, matching the non-streaming
                # response, rather than only the last choice processed.
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: PreTrainedTokenizer,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Consume the engine output and build a non-streaming response.

        Only the last result yielded by ``result_generator`` is used.

        Args:
            request: The chat completion request being served.
            result_generator: Async iterator of engine outputs.
            request_id: Identifier placed in the response.
            conversation: Parsed conversation, used only for ``echo``.
            tokenizer: Tokenizer used when building logprob entries.

        Returns:
            The ``ChatCompletionResponse``, or an ``ErrorResponse`` with the
            message "Client disconnected" if iteration was cancelled.
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        named_tool_choice = isinstance(request.tool_choice,
                                       ChatCompletionNamedToolChoiceParam)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            if named_tool_choice:
                # With a named tool choice the generated text becomes the
                # arguments of that function call.
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
            else:
                # Covers an unset tool_choice and "none". Any other value
                # also lands here, so ``message`` is never left unbound or
                # stale from the previous output.
                # NOTE(review): plain content is assumed to be the right
                # fallback for other tool_choice values; confirm.
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            if conversation and conversation[-1].get(
                    "content") and conversation[-1].get("role") == role:
                last_msg_content = conversation[-1]["content"]

            for choice in choices:
                # Tolerate a missing content value when prepending the echo.
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response
523
524

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
        """Convert the leading entries of ``logprobs`` to API objects.

        Entries are taken in the dict's iteration order and at most
        ``top_logprobs`` of them are converted. A falsy ``top_logprobs``
        (``None`` or ``0``) yields an empty list. Each logprob value is
        clamped from below at -9999.0.
        """
        entries: List[ChatCompletionLogProb] = []
        if not top_logprobs:
            return entries

        for position, (token_id, step_logprob) in enumerate(
                logprobs.items()):
            if position >= top_logprobs:
                break
            token = self._get_decoded_token(
                step_logprob,
                token_id,
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)
            clamped_logprob = max(step_logprob.logprob, -9999.0)
            token_bytes = list(token.encode("utf-8", errors="replace"))
            entries.append(
                ChatCompletionLogProb(token=token,
                                      logprob=clamped_logprob,
                                      bytes=token_bytes))
        return entries

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: PreTrainedTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        Args:
            token_ids: Generated token ids, one per step.
            top_logprobs: Per-step mapping from token id to ``Logprob``,
                indexed in parallel with ``token_ids``. A step may be
                ``None``, in which case only the token text and its bytes
                are reported.
            tokenizer: Used to decode a token id when no decoded text is
                available for it.
            num_output_top_logprobs: Maximum number of alternatives to
                report per step (passed to ``_get_top_logprobs``).

        Returns:
            A ``ChatCompletionLogProbs`` with one content entry per token.
        """
        logprobs_content = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # Decode only when the decoded text is actually reported.
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                else:
                    token = tokenizer.decode(token_id)

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace"))))
            else:
                step_logprob = step_top_logprobs[token_id]

                # ``decoded_token`` may be missing; decode the id directly
                # in that case, as the branch above does, instead of
                # failing on ``None.encode``.
                decoded_token = step_logprob.decoded_token
                if decoded_token is None:
                    decoded_token = tokenizer.decode(token_id)

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_logprob, token_id, tokenizer,
                            self.return_tokens_as_token_ids),
                        logprob=max(step_logprob.logprob, -9999.0),
                        bytes=list(
                            decoded_token.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs, num_output_top_logprobs,
                            tokenizer)))

        return ChatCompletionLogProbs(content=logprobs_content)