serving_chat.py 26.5 KB
Newer Older
1
import codecs
2
import time
3
from dataclasses import dataclass, field
4
from functools import cached_property
5
6
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
                    List, Optional)
7
8
from typing import Sequence as GenericSequence
from typing import TypedDict, Union, cast, final
9

10
from fastapi import Request
11
12
from openai.types.chat import (ChatCompletionContentPartImageParam,
                               ChatCompletionContentPartTextParam)
13

14
from vllm.config import ModelConfig
15
16
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
17
18
    ChatCompletionContentPartParam, ChatCompletionLogProb,
    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
19
20
    ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest, ChatCompletionResponse,
21
22
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
23
    FunctionCall, ToolCall, UsageInfo)
24
25
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
26
from vllm.inputs import PromptInputs
27
from vllm.logger import init_logger
28
29
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
30
from vllm.multimodal import MultiModalDataDict
31
from vllm.multimodal.utils import async_get_and_parse_image
32
from vllm.outputs import RequestOutput
33
from vllm.sequence import Logprob
34
35
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
36
from vllm.utils import random_uuid
37
38
39
40

# Module-level logger used by the chat-serving code in this file.
logger = init_logger(__name__)


41
42
43
44
45
46
@final  # So that it should be compatible with Dict[str, str]
class ConversationMessage(TypedDict):
    """One text-only turn of a conversation, as handed to the chat template.

    Both keys hold plain strings, which is what allows instances to be used
    wherever a ``Dict[str, str]`` is expected.
    """

    # Role of the message author, copied from the incoming request message.
    role: str
    # Text of the turn (text parts of a multi-part message joined by "\n").
    content: str


47
48
49
@dataclass(frozen=True)
class ChatMessageParseResult:
    """Immutable result of parsing a single chat message from a request.

    ``messages`` carries the text conversation entries derived from the
    message; ``mm_futures`` carries pending multi-modal loads (e.g. image
    fetches) that the caller still has to await before generation.
    """

    # Text conversation entries produced from the parsed message.
    messages: List[ConversationMessage]
    # Awaitables yielding multi-modal data; left empty for text-only input.
    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
        default_factory=list)
52
53


54
55
56
57
class OpenAIServingChat(OpenAIServing):
    """OpenAI-compatible chat completion handler.

    Renders the request's messages into a prompt with the tokenizer's chat
    template, submits it to the engine, and returns the result either as a
    single :class:`ChatCompletionResponse` or as a stream of server-sent
    ``data:`` events.
    """

    def __init__(self,
                 engine: AsyncLLMEngine,
                 model_config: ModelConfig,
                 served_model_names: List[str],
                 response_role: str,
                 lora_modules: Optional[List[LoRAModulePath]] = None,
                 chat_template: Optional[str] = None):
        """
        Args:
            engine: engine used to generate completions.
            model_config: configuration of the served model.
            served_model_names: names the model is served under; the first
                one is reported in every response.
            response_role: role given to generated messages when the request
                asks for a generation prompt.
            lora_modules: optional LoRA adapters, forwarded to the base class.
            chat_template: optional chat template, either a file path or the
                template text itself.
        """
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules)

        self.response_role = response_role
        self._load_chat_template(chat_template)

    def _load_chat_template(self, chat_template: Optional[str]):
        """Install ``chat_template`` on the tokenizer, if one is supplied.

        ``chat_template`` is first tried as a file path. If the file cannot
        be opened, the string itself is used as the template (with backslash
        escapes interpreted) - unless it contains none of ``{``, ``}`` or a
        newline, in which case it was almost certainly meant as a path.

        Raises:
            ValueError: if ``chat_template`` looks like a path that could
                not be opened.
        """
        tokenizer = self.tokenizer

        if chat_template is not None:
            try:
                with open(chat_template, "r") as f:
                    tokenizer.chat_template = f.read()
            except OSError as e:
                JINJA_CHARS = "{}\n"
                if not any(c in chat_template for c in JINJA_CHARS):
                    msg = (f"The supplied chat template ({chat_template}) "
                           f"looks like a file path, but it failed to be "
                           f"opened. Reason: {e}")
                    raise ValueError(msg) from e

                # Opening a file failed, so treat the argument as the template
                # itself; decode it so escape sequences are interpreted.
                tokenizer.chat_template = codecs.decode(
                    chat_template, "unicode_escape")

            logger.info("Using supplied chat template:\n%s",
                        tokenizer.chat_template)
        elif tokenizer.chat_template is not None:
            logger.info("Using default chat template:\n%s",
                        tokenizer.chat_template)
        else:
            logger.warning(
                "No chat template provided. Chat API will not work.")

    @cached_property
    def image_token_str(self) -> Optional[str]:
        """Placeholder string marking an image's position in the prompt.

        Returns ``None`` for models that do not put image tokens in the
        prompt.

        Raises:
            TypeError: if the model type is not recognised.
        """
        # TODO: Let user specify how to insert image tokens into prompt
        # (similar to chat template)
        model_type = self.model_config.hf_config.model_type
        if model_type == "phi3_v":
            # Workaround since this token is not defined in the tokenizer
            return "<|image_1|>"
        if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
                          "paligemma"):
            # These models do not use image tokens in the prompt
            return None
        if model_type.startswith("llava"):
            return self.tokenizer.decode(
                self.model_config.hf_config.image_token_index)

        # f-string so the actual model type appears in the error message.
        raise TypeError(f"Unknown model type: {model_type}")

    # TODO: Let user specify how to insert image tokens into prompt
    # (similar to chat template)
    def _get_full_image_text_prompt(self, image_token_str: str,
                                    text_prompt: str) -> str:
        """Combine image and text prompts for vision language model"""

        # NOTE: For now we assume all model architectures use the same
        # image + text prompt format. This may change in the future.
        return f"{image_token_str}\n{text_prompt}"

    def _parse_chat_message_content_parts(
        self,
        role: str,
        parts: Iterable[ChatCompletionContentPartParam],
    ) -> ChatMessageParseResult:
        """Parse a multi-part message (text and at most one image).

        Text parts are joined with newlines. An ``image_url`` part starts an
        asynchronous image load and, if the model uses an image token that is
        not already in the text, prefixes the prompt with that token.

        Raises:
            NotImplementedError: for more than one image or an unknown part
                type.
        """
        texts: List[str] = []
        mm_futures: List[Awaitable[MultiModalDataDict]] = []

        for part in parts:
            part_type = part["type"]
            if part_type == "text":
                text = cast(ChatCompletionContentPartTextParam, part)["text"]
                texts.append(text)
            elif part_type == "image_url":
                if mm_futures:
                    raise NotImplementedError(
                        "Multiple 'image_url' input is currently not supported."
                    )

                image_url = cast(ChatCompletionContentPartImageParam,
                                 part)["image_url"]

                if image_url.get("detail", "auto") != "auto":
                    logger.warning(
                        "'image_url.detail' is currently not supported and "
                        "will be ignored.")

                image_future = async_get_and_parse_image(image_url["url"])
                mm_futures.append(image_future)
            else:
                raise NotImplementedError(f"Unknown part type: {part_type}")

        text_prompt = "\n".join(texts)

        if mm_futures:
            image_token_str = self.image_token_str
            if image_token_str is not None:
                if image_token_str in text_prompt:
                    logger.warning(
                        "Detected image token string in the text prompt. "
                        "Skipping prompt formatting.")
                else:
                    text_prompt = self._get_full_image_text_prompt(
                        image_token_str=image_token_str,
                        text_prompt=text_prompt,
                    )

        messages = [ConversationMessage(role=role, content=text_prompt)]

        return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)

    def _parse_chat_message_content(
        self,
        message: ChatCompletionMessageParam,
    ) -> ChatMessageParseResult:
        """Parse one request message into conversation entries.

        A message without content yields nothing; string content is used
        verbatim; any other content is parsed as a list of parts.
        """
        role = message["role"]
        content = message.get("content")

        if content is None:
            return ChatMessageParseResult(messages=[], mm_futures=[])
        if isinstance(content, str):
            messages = [ConversationMessage(role=role, content=content)]
            return ChatMessageParseResult(messages=messages, mm_futures=[])

        return self._parse_chat_message_content_parts(role, content)

    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None
    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
               ChatCompletionResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        NOTE: Currently we do not support the following feature:
            - function_call (Users should implement this by themselves)

        Returns an :class:`ErrorResponse` when validation or preprocessing
        fails, an async generator of server-sent event strings when
        ``request.stream`` is set, and a :class:`ChatCompletionResponse`
        otherwise.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        try:
            conversation: List[ConversationMessage] = []
            mm_futures: List[Awaitable[MultiModalDataDict]] = []

            for msg in request.messages:
                chat_parsed_result = self._parse_chat_message_content(msg)
                conversation.extend(chat_parsed_result.messages)
                mm_futures.extend(chat_parsed_result.mm_futures)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt = self.tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                chat_template=request.chat_template,
                **(request.chat_template_kwargs or {}),
            )
        except Exception as e:
            logger.error("Error in applying chat template from request: %s", e)
            return self.create_error_response(str(e))

        mm_data: Optional[MultiModalDataDict] = None
        try:
            if mm_futures:
                # Only a single piece of multi-modal data is supported for
                # now (it can still arrive from several messages). Checked
                # with an explicit raise rather than ``assert`` so the guard
                # is not stripped under ``python -O``.
                if len(mm_futures) > 1:
                    raise NotImplementedError(
                        "Multiple 'image_url' input is currently not "
                        "supported.")
                mm_data = await mm_futures[0]
        except Exception as e:
            logger.error("Error in loading multi-modal data: %s", e)
            return self.create_error_response(str(e))

        request_id = f"cmpl-{random_uuid()}"
        try:
            # Tokenize/detokenize depending on prompt format (string/token list)
            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
                request,
                prompt=prompt,
                add_special_tokens=request.add_special_tokens)
            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            decoding_config = await self.engine.get_decoding_config()
            guided_decoding_backend = request.guided_decoding_backend \
                or decoding_config.guided_decoding_backend
            guided_decode_logits_processor = (
                await get_guided_decoding_logits_processor(
                    guided_decoding_backend, request, await
                    self.engine.get_tokenizer()))
            if guided_decode_logits_processor:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logits_processor)
        except ValueError as e:
            return self.create_error_response(str(e))

        inputs: PromptInputs = {
            "prompt": prompt_text,
            "prompt_token_ids": prompt_ids,
        }
        if mm_data:
            inputs["multi_modal_data"] = mm_data

        is_tracing_enabled = await self.engine.is_tracing_enabled()
        trace_headers = None
        if is_tracing_enabled and raw_request:
            trace_headers = extract_trace_headers(raw_request.headers)
        if not is_tracing_enabled and raw_request and contains_trace_headers(
                raw_request.headers):
            log_tracing_disabled_warning()

        result_generator = self.engine.generate(
            inputs,
            sampling_params,
            request_id,
            lora_request,
            trace_headers=trace_headers,
        )

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation)
        else:
            try:
                return await self.chat_completion_full_generator(
                    request, raw_request, result_generator, request_id,
                    conversation)
            except ValueError as e:
                # TODO: Use a vllm-specific Validation Error
                return self.create_error_response(str(e))

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        """Role reported for the generated message.

        ``self.response_role`` when the request asks for a generation prompt;
        otherwise the role of the last input message (presumably because the
        model then continues that message).
        """
        if request.add_generation_prompt:
            return self.response_role
        else:
            return request.messages[-1]["role"]

    async def chat_completion_stream_generator(
            self, request: ChatCompletionRequest,
            result_generator: AsyncIterator[RequestOutput], request_id: str,
            conversation: List[ConversationMessage]
    ) -> AsyncGenerator[str, None]:
        """Yield the chat completion as server-sent ``data:`` events.

        The stream starts with one role-only chunk per choice (plus an echo
        of the last message when ``request.echo`` is set), then one delta
        chunk per choice per engine output until each choice has finished.
        It ends with an optional usage chunk and ``data: [DONE]``. A
        ``ValueError`` raised while streaming is emitted as an error event
        before ``[DONE]``.
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type = "chat.completion.chunk"
        first_iteration = True
        include_usage = bool(request.stream_options
                             and request.stream_options.include_usage)

        # Send response for each token for each request.n (index)
        assert request.n is not None
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        finish_reason_sent = [False] * request.n
        # Taken from the first engine output; stays 0 if the engine yields
        # nothing, so the usage chunk never refers to an unbound name.
        num_prompt_tokens = 0

        try:
            async for res in result_generator:
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    num_prompt_tokens = len(res.prompt_token_ids)

                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)
                    for i in range(request.n):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(role=role),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if include_usage:
                            chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for i in range(request.n):
                                # ``logprobs`` is a per-choice field, as in
                                # the role chunk above; it used to be passed
                                # to ChatCompletionStreamResponse instead.
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if include_usage:
                                    chunk.usage = None
                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
                    out_logprobs = output.logprobs[
                        previous_num_tokens[i]:] if output.logprobs else None

                    if request.logprobs and request.top_logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    # Outputs are cumulative, so only the new suffix is sent.
                    delta_text = output.text[len(previous_texts[i]):]
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)

                    if isinstance(request.tool_choice,
                                  ChatCompletionNamedToolChoiceParam):
                        delta_message = DeltaMessage(tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text))
                        ])
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                    else:
                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                            stop_reason=output.stop_reason)

                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)
                    if include_usage:
                        chunk.usage = None
                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

                    if output.finish_reason is not None:
                        finish_reason_sent[i] = True

            if include_usage:
                # Count the tokens of every choice; previously only the choice
                # whose index happened to be processed last was counted.
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self, request: ChatCompletionRequest, raw_request: Optional[Request],
        result_generator: AsyncIterator[RequestOutput], request_id: str,
        conversation: List[ConversationMessage]
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Drain the engine output and build a single chat response.

        Aborts the request and returns an error response if the client
        disconnects before generation completes.
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        async for res in result_generator:
            if raw_request is not None and await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return self.create_error_response("Client disconnected")
            final_res = res
        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                )
            else:
                logprobs = None

            if isinstance(request.tool_choice,
                          ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
            else:
                # Plain-text reply. This used to be an ``elif`` limited to a
                # falsy/"none" tool_choice, which left ``message`` unbound
                # for any other value.
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            if conversation and conversation[-1].get(
                    "content") and conversation[-1].get("role") == role:
                last_msg_content = conversation[-1]["content"]

            for choice in choices:
                full_message = last_msg_content + choice.message.content
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

        return response

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob],
            top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
        """Convert the first ``top_logprobs`` entries of ``logprobs``.

        Entries are taken in the dict's iteration order. Returns an empty
        list when ``top_logprobs`` is ``None`` or not positive. Log
        probabilities are clamped to -9999.0 from below.
        """
        if top_logprobs is None or top_logprobs <= 0:
            return []

        result: List[ChatCompletionLogProb] = []
        for rank, (token_id, logprob) in enumerate(logprobs.items()):
            # Stop once enough entries are collected instead of scanning
            # the rest of the dict.
            if rank >= top_logprobs:
                break
            token = self._get_decoded_token(logprob, token_id)
            result.append(
                ChatCompletionLogProb(
                    token=token,
                    logprob=max(logprob.logprob, -9999.0),
                    bytes=list(token.encode("utf-8", errors="replace"))))
        return result

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        ``top_logprobs[i]`` holds the candidates for ``token_ids[i]``. When
        it is ``None``, only the decoded token (without a logprob) is
        reported.
        """
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = self.tokenizer.decode(token_id)
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace"))))
            else:
                step_logprob = step_top_logprobs[token_id]
                # Decode through the same helper as _get_top_logprobs so the
                # sampled token and its alternatives are rendered
                # consistently (presumably this also covers a missing
                # ``decoded_token`` - confirm in the base class).
                token = self._get_decoded_token(step_logprob, token_id)
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        logprob=max(step_logprob.logprob, -9999.0),
                        bytes=list(token.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs, num_output_top_logprobs)))

        return ChatCompletionLogProbs(content=logprobs_content)