# serving_chat.py
1
import codecs
2
import time
3
from dataclasses import dataclass, field
4
from functools import cached_property
5
6
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
                    List, Optional)
7
8
from typing import Sequence as GenericSequence
from typing import TypedDict, Union, cast, final
9

10
from fastapi import Request
11
12
from openai.types.chat import (ChatCompletionContentPartImageParam,
                               ChatCompletionContentPartTextParam)
13

14
from vllm.config import ModelConfig
15
16
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
17
18
    ChatCompletionContentPartParam, ChatCompletionLogProb,
    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
19
20
    ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest, ChatCompletionResponse,
21
22
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
23
    FunctionCall, ToolCall, UsageInfo)
24
25
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
26
from vllm.inputs import PromptInputs
27
from vllm.logger import init_logger
28
29
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
30
from vllm.multimodal import MultiModalDataDict
31
from vllm.multimodal.utils import async_get_and_parse_image
32
from vllm.outputs import RequestOutput
33
from vllm.sequence import Logprob
34
35
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
36
from vllm.utils import random_uuid
37
38
39
40

logger = init_logger(__name__)


41
42
43
44
45
46
@final  # So that it should be compatible with Dict[str, str]
class ConversationMessage(TypedDict):
    """One chat turn in the plain-dict form handed to the chat template."""
    # Name of the speaker role for this turn.
    role: str
    # Text of the turn.
    content: str


47
48
49
@dataclass(frozen=True)
class ChatMessageParseResult:
    """Immutable result of parsing a single chat message."""

    # Conversation entries produced from the message.
    messages: List[ConversationMessage]
    # Awaitables resolving to multi-modal data referenced by the message;
    # each instance gets its own empty list by default.
    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
        default_factory=list)
52
53


54
55
56
57
class OpenAIServingChat(OpenAIServing):

    def __init__(self,
                 engine: AsyncLLMEngine,
                 model_config: ModelConfig,
                 served_model_names: List[str],
                 response_role: str,
                 lora_modules: Optional[List[LoRAModulePath]] = None,
                 chat_template: Optional[str] = None):
        """Initialise the chat-completion handler.

        Args:
            engine: Forwarded to the base class.
            model_config: Forwarded to the base class.
            served_model_names: Forwarded to the base class.
            response_role: Role reported for generated messages when the
                request asks for a generation prompt.
            lora_modules: Forwarded to the base class.
            chat_template: Optional chat template, given either as a file
                path or as an inline template string.
        """
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules)

        self.response_role = response_role
        self._load_chat_template(chat_template)

    def _load_chat_template(self, chat_template: Optional[str]):
        tokenizer = self.tokenizer

        if chat_template is not None:
            try:
                with open(chat_template, "r") as f:
                    tokenizer.chat_template = f.read()
            except OSError as e:
                JINJA_CHARS = "{}\n"
                if not any(c in chat_template for c in JINJA_CHARS):
                    msg = (f"The supplied chat template ({chat_template}) "
                           f"looks like a file path, but it failed to be "
                           f"opened. Reason: {e}")
                    raise ValueError(msg) from e

                # If opening a file fails, set chat template to be args to
                # ensure we decode so our escape are interpreted correctly
                tokenizer.chat_template = codecs.decode(
                    chat_template, "unicode_escape")

            logger.info("Using supplied chat template:\n%s",
                        tokenizer.chat_template)
        elif tokenizer.chat_template is not None:
            logger.info("Using default chat template:\n%s",
                        tokenizer.chat_template)
        else:
            logger.warning(
                "No chat template provided. Chat API will not work.")
99

100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    @cached_property
    def image_token_str(self) -> Optional[str]:
        # TODO: Let user specify how to insert image tokens into prompt
        # (similar to chat template)
        model_type = self.model_config.hf_config.model_type
        if model_type == "phi3_v":
            # Workaround since this token is not defined in the tokenizer
            return "<|image_1|>"
        if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
                          "paligemma"):
            # These models do not use image tokens in the prompt
            return None

        # The default behaviour assumes that the image token is
        # available to the tokenizer.
        # (Suitable for LLaVA, Idefics2, DeepSeek-VL)
        vlm_config = self.model_config.multimodal_config
        if vlm_config is None:
            raise ValueError(
                "'image_url' input is not supported as the loaded "
                "model is not multimodal.")

        image_token_id = vlm_config.image_token_id
        if vlm_config.image_token_id is None:
            raise ValueError(
                "'image_url' input is not supported as the loaded "
                "model does not specify an image token.")

        return self.tokenizer.decode(image_token_id)

130
131
132
133
134
135
136
137
138
139
    # TODO: Let user specify how to insert image tokens into prompt
    # (similar to chat template)
    def _get_full_image_text_prompt(self, image_token_str: str,
                                    text_prompt: str) -> str:
        """Combine image and text prompts for vision language model"""

        # NOTE: For now we assume all model architectures use the same
        # image + text prompt format. This may change in the future.
        return f"{image_token_str}\n{text_prompt}"

140
    def _parse_chat_message_content_parts(
        self,
        role: str,
        parts: Iterable[ChatCompletionContentPartParam],
    ) -> ChatMessageParseResult:
        """Parse a message whose content is a list of typed parts.

        Text parts are joined with newlines into one prompt string. At most
        one ``image_url`` part is accepted; its image is fetched through an
        awaitable, and the image token string (when the model uses one) is
        prepended to the text unless the text already contains it.

        Args:
            role: Role to record on the resulting conversation message.
            parts: Content parts of the message.

        Returns:
            A result holding one conversation message and the collected
            multi-modal awaitables.

        Raises:
            NotImplementedError: More than one ``image_url`` part was
                given, or a part has an unknown type.
        """
        texts: List[str] = []
        mm_futures: List[Awaitable[MultiModalDataDict]] = []

        for part in parts:
            part_type = part["type"]
            if part_type == "text":
                text = cast(ChatCompletionContentPartTextParam, part)["text"]
                texts.append(text)
            elif part_type == "image_url":
                if mm_futures:
                    raise NotImplementedError(
                        "Multiple 'image_url' input is currently not supported."
                    )

                image_url = cast(ChatCompletionContentPartImageParam,
                                 part)["image_url"]

                if image_url.get("detail", "auto") != "auto":
                    logger.warning(
                        "'image_url.detail' is currently not supported and "
                        "will be ignored.")

                image_future = async_get_and_parse_image(image_url["url"])
                mm_futures.append(image_future)
            else:
                raise NotImplementedError(f"Unknown part type: {part_type}")

        text_prompt = "\n".join(texts)

        if mm_futures:
            image_token_str = self.image_token_str
            if image_token_str is not None:
                if image_token_str in text_prompt:
                    # The caller already placed the token; leave the text
                    # untouched.
                    logger.warning(
                        "Detected image token string in the text prompt. "
                        "Skipping prompt formatting.")
                else:
                    text_prompt = self._get_full_image_text_prompt(
                        image_token_str=image_token_str,
                        text_prompt=text_prompt,
                    )

        messages = [ConversationMessage(role=role, content=text_prompt)]

        return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
190
191
192
193
194
195
196
197
198

    def _parse_chat_message_content(
        self,
        message: ChatCompletionMessageParam,
    ) -> ChatMessageParseResult:
        """Parse one request message into conversation entries.

        Returns:
            An empty result when the message has no content, a single
            conversation message when the content is a plain string, and
            otherwise the result of parsing the content as a list of parts.
        """
        role = message["role"]
        content = message.get("content")

        if content is None:
            return ChatMessageParseResult(messages=[], mm_futures=[])

        if isinstance(content, str):
            messages = [ConversationMessage(role=role, content=content)]
            return ChatMessageParseResult(messages=messages, mm_futures=[])

        return self._parse_chat_message_content_parts(role, content)
205

206
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None
    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
               ChatCompletionResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        NOTE: Currently we do not support the following feature:
            - function_call (Users should implement this by themselves)

        Args:
            request: The chat completion request.
            raw_request: The underlying HTTP request, used for trace
                headers and client-disconnect detection when provided.

        Returns:
            An error response if validation or preprocessing fails, an
            async generator of event strings when ``request.stream`` is
            set, and a full chat completion response otherwise.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Step 1: flatten the request messages and render the prompt.
        try:
            conversation: List[ConversationMessage] = []
            mm_futures: List[Awaitable[MultiModalDataDict]] = []

            for msg in request.messages:
                chat_parsed_result = self._parse_chat_message_content(msg)

                conversation.extend(chat_parsed_result.messages)
                mm_futures.extend(chat_parsed_result.mm_futures)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt = self.tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                chat_template=request.chat_template,
                **(request.chat_template_kwargs or {}),
            )
        except Exception as e:
            logger.error("Error in applying chat template from request: %s", e)
            return self.create_error_response(str(e))

        # Step 2: resolve multi-modal data, if any.
        mm_data: Optional[MultiModalDataDict] = None
        try:
            if mm_futures:
                # since we support only single mm data currently
                # (explicit raise rather than assert so that the check
                # survives `python -O`; the client sees the same message)
                if len(mm_futures) != 1:
                    raise NotImplementedError(
                        "Multiple 'image_url' input is currently "
                        "not supported.")
                mm_data = await mm_futures[0]
        except Exception as e:
            logger.error("Error in loading multi-modal data: %s", e)
            return self.create_error_response(str(e))

        # Step 3: tokenize and build sampling parameters.
        request_id = f"cmpl-{random_uuid()}"
        try:
            # Tokenize/detokenize depending on prompt format (string/token list)
            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
                request,
                prompt=prompt,
                add_special_tokens=request.add_special_tokens)
            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            decoding_config = await self.engine.get_decoding_config()
            # A backend named in the request wins over the engine default.
            guided_decoding_backend = request.guided_decoding_backend \
                or decoding_config.guided_decoding_backend
            guided_decode_logits_processor = (
                await get_guided_decoding_logits_processor(
                    guided_decoding_backend, request, await
                    self.engine.get_tokenizer()))
            if guided_decode_logits_processor:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logits_processor)
        except ValueError as e:
            return self.create_error_response(str(e))

        inputs: PromptInputs = {
            "prompt": prompt_text,
            "prompt_token_ids": prompt_ids,
        }
        if mm_data:
            inputs["multi_modal_data"] = mm_data

        # Step 4: tracing. Headers are forwarded only when tracing is on;
        # otherwise a warning is logged if the client sent trace headers.
        is_tracing_enabled = await self.engine.is_tracing_enabled()
        trace_headers = None
        if is_tracing_enabled and raw_request:
            trace_headers = extract_trace_headers(raw_request.headers)
        if not is_tracing_enabled and raw_request and contains_trace_headers(
                raw_request.headers):
            log_tracing_disabled_warning()

        result_generator = self.engine.generate(
            inputs,
            sampling_params,
            request_id,
            lora_request,
            trace_headers=trace_headers,
        )

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation)

        try:
            return await self.chat_completion_full_generator(
                request, raw_request, result_generator, request_id,
                conversation)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
322
323
324
325
326

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        else:
327
            return request.messages[-1]["role"]
328
329
330

    async def chat_completion_stream_generator(
            self, request: ChatCompletionRequest,
            result_generator: AsyncIterator[RequestOutput], request_id: str,
            conversation: List[ConversationMessage]
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion as server-sent-event strings.

        Emitted in order: one role chunk per choice; optionally one echo
        chunk per choice; one delta chunk per choice for every engine
        output, with the last one carrying the finish reason; optionally a
        usage chunk; and finally the ``[DONE]`` marker. A ``ValueError``
        raised while streaming is reported as an error event instead of
        propagating.

        Args:
            request: The chat completion request being served.
            result_generator: Engine outputs for this request.
            request_id: Identifier placed on every chunk.
            conversation: Parsed conversation, used for ``echo``.
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type = "chat.completion.chunk"
        first_iteration = True

        # Request options that do not change while streaming.
        include_usage = bool(request.stream_options
                             and request.stream_options.include_usage)
        uses_named_tool = isinstance(request.tool_choice,
                                     ChatCompletionNamedToolChoiceParam)

        # Send response for each token for each request.n (index)
        assert request.n is not None
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        finish_reason_sent = [False] * request.n
        # Stays 0 if no choice ever finishes, so the usage chunk below can
        # always be built.
        prompt_tokens = 0

        try:
            async for res in result_generator:
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)
                    for i in range(request.n):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(role=role),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        # Assigning marks the field as set, so it survives
                        # exclude_unset and is serialised as null.
                        if include_usage:
                            chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for i in range(request.n):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        finish_reason=None))
                                # NOTE(review): `logprobs` is passed to the
                                # response here, not to the choice as in
                                # the other chunks - confirm intended.
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    logprobs=None,
                                    model=model_name)
                                if include_usage:
                                    chunk.usage = None
                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    # Only the part produced since the previous chunk.
                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
                    out_logprobs = output.logprobs[
                        previous_num_tokens[i]:] if output.logprobs else None

                    if request.logprobs and request.top_logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)

                    if uses_named_tool:
                        # With a named tool, the generated text is sent as
                        # that function's arguments.
                        delta_message = DeltaMessage(tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text))
                        ])
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if include_usage:
                            chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                    else:
                        # Send the finish response for each request.n only once
                        prompt_tokens = len(res.prompt_token_ids)
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                            stop_reason=output.stop_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if include_usage:
                            chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            if include_usage:
                # Count the tokens of every choice, not only those of the
                # choice that happened to be processed last.
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self, request: ChatCompletionRequest, raw_request: Optional[Request],
        result_generator: AsyncIterator[RequestOutput], request_id: str,
        conversation: List[ConversationMessage]
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Consume the engine outputs and build a non-streaming response.

        Args:
            request: The chat completion request being served.
            raw_request: The underlying HTTP request; when given, the
                engine request is aborted if the client disconnects.
            result_generator: Engine outputs for this request.
            request_id: Identifier placed on the response.
            conversation: Parsed conversation, used for ``echo``.

        Returns:
            The complete chat completion response, or an error response if
            the client disconnected while outputs were being produced.
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        # Only the last engine output is used to build the response.
        async for res in result_generator:
            if raw_request is not None and await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return self.create_error_response("Client disconnected")
            final_res = res
        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []
        uses_named_tool = isinstance(request.tool_choice,
                                     ChatCompletionNamedToolChoiceParam)

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                )
            else:
                logprobs = None

            if uses_named_tool:
                # With a named tool, the generated text is returned as
                # that function's arguments.
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
            else:
                # Every other tool_choice value yields a plain message, so
                # `message` is always bound before it is used below.
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            # Echo the last message only when it has content and belongs
            # to the role being generated.
            last_msg_content = ""
            if conversation and conversation[-1].get(
                    "content") and conversation[-1].get("role") == role:
                last_msg_content = conversation[-1]["content"]

            for choice in choices:
                full_message = last_msg_content + choice.message.content
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

        return response
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob],
            top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
        return [
            ChatCompletionLogProb(
                token=self._get_decoded_token(p[1], p[0]),
                logprob=max(p[1].logprob, -9999.0),
                bytes=list(
                    self._get_decoded_token(p[1],
                                            p[0]).encode("utf-8",
                                                         errors="replace")))
            for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
        ]

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        Args:
            token_ids: Generated token ids.
            top_logprobs: Per-position log-probability mappings, indexed in
                step with ``token_ids``; an entry may be ``None``.
            num_output_top_logprobs: Maximum number of alternatives to
                report per position.

        Returns:
            One content entry per token. Where the position's mapping is
            ``None``, the entry carries only the decoded token and its
            bytes.
        """

        logprobs_content = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprob data for this position: decode the token
                # directly (once) and report only its text and bytes.
                token = self.tokenizer.decode(token_id)
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8",
                                                errors="replace"))))
            else:
                # Look the sampled token's record up a single time.
                sampled = step_top_logprobs[token_id]
                token = sampled.decoded_token
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        logprob=max(sampled.logprob, -9999.0),
                        bytes=list(token.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs, num_output_top_logprobs)))

        return ChatCompletionLogProbs(content=logprobs_content)