serving_chat.py 21.5 KB
Newer Older
1
import time
2
3
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
                    Optional)
4
from typing import Sequence as GenericSequence
5
from typing import Union
6

7
from fastapi import Request
8
from transformers import PreTrainedTokenizer
9

10
from vllm.config import ModelConfig
11
from vllm.engine.async_llm_engine import AsyncLLMEngine
12
13
14
from vllm.entrypoints.openai.chat_utils import (ConversationMessage,
                                                load_chat_template,
                                                parse_chat_message_content)
15
from vllm.entrypoints.openai.protocol import (
16
17
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
18
    ChatCompletionRequest, ChatCompletionResponse,
19
20
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
21
    FunctionCall, ToolCall, UsageInfo)
22
23
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
24
from vllm.inputs import PromptInputs
25
from vllm.logger import init_logger
26
27
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
28
from vllm.multimodal import MultiModalDataDict
29
from vllm.outputs import RequestOutput
30
from vllm.sequence import Logprob
31
32
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
33
from vllm.utils import random_uuid
34
35
36
37
38
39
40
41

logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):

    def __init__(self,
                 engine: AsyncLLMEngine,
                 model_config: ModelConfig,
                 served_model_names: List[str],
                 response_role: str,
                 lora_modules: Optional[List[LoRAModulePath]] = None,
                 chat_template: Optional[str] = None):
        """Set up the chat-completion serving layer.

        ``engine``, ``model_config``, ``served_model_names`` and
        ``lora_modules`` are forwarded unchanged to ``OpenAIServing``.

        Args:
            response_role: Role returned by ``get_chat_request_role`` when
                the request asks for a generation prompt.
            chat_template: Passed through ``load_chat_template``; the result
                is stored as ``self.chat_template``.
        """
        super().__init__(
            engine=engine,
            model_config=model_config,
            served_model_names=served_model_names,
            lora_modules=lora_modules,
        )

        # If this ends up as None we use the tokenizer's default chat template
        self.chat_template = load_chat_template(chat_template)
        self.response_role = response_role
56

57
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None
    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
               ChatCompletionResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        The request messages are parsed, rendered into a prompt with the
        chat template (the request's own template takes precedence over
        ``self.chat_template``), tokenized, and submitted to the engine.

        Args:
            request: The chat completion request.
            raw_request: The underlying HTTP request, if any. Used to read
                trace headers and, for non-streaming requests, to detect
                client disconnects.

        Returns:
            An ``ErrorResponse`` if the model check, chat templating,
            multi-modal loading or prompt validation fails; the streaming
            generator when ``request.stream`` is set; otherwise a
            ``ChatCompletionResponse``.

        NOTE: Currently we do not support the following feature:
            - function_call (Users should implement this by themselves)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        try:
            _, lora_request = self._maybe_get_adapter(request)
            tokenizer = await self.engine.get_tokenizer(lora_request)

            conversation: List[ConversationMessage] = []
            mm_futures: List[Awaitable[MultiModalDataDict]] = []

            for msg in request.messages:
                chat_parsed_result = parse_chat_message_content(
                    msg, self.model_config, tokenizer)

                conversation.extend(chat_parsed_result.messages)
                mm_futures.extend(chat_parsed_result.mm_futures)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt = tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                chat_template=request.chat_template or self.chat_template,
                **(request.chat_template_kwargs or {}),
            )
        except Exception as e:
            logger.error("Error in applying chat template from request: %s", e)
            return self.create_error_response(str(e))

        mm_data: Optional[MultiModalDataDict] = None
        try:
            if mm_futures:
                # since we support only single mm data currently
                # Raise explicitly instead of asserting: asserts are removed
                # under ``python -O``, which would silently drop every image
                # after the first one.
                if len(mm_futures) != 1:
                    raise ValueError(
                        "Multiple 'image_url' input is currently not "
                        "supported.")
                mm_data = await mm_futures[0]
        except Exception as e:
            logger.error("Error in loading multi-modal data: %s", e)
            return self.create_error_response(str(e))

        request_id = f"cmpl-{random_uuid()}"
        try:
            # Tokenize/detokenize depending on prompt format (string/token list)
            prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
                request,
                tokenizer,
                prompt=prompt,
                add_special_tokens=request.add_special_tokens)
            sampling_params = request.to_sampling_params()
            decoding_config = await self.engine.get_decoding_config()
            # A backend named in the request overrides the engine default
            guided_decoding_backend = request.guided_decoding_backend \
                or decoding_config.guided_decoding_backend
            guided_decode_logits_processor = (
                await
                get_guided_decoding_logits_processor(guided_decoding_backend,
                                                     request, tokenizer))
            if guided_decode_logits_processor:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logits_processor)
        except ValueError as e:
            return self.create_error_response(str(e))

        inputs: PromptInputs = {
            "prompt": prompt_text,
            "prompt_token_ids": prompt_ids,
        }
        if mm_data:
            inputs["multi_modal_data"] = mm_data

        is_tracing_enabled = await self.engine.is_tracing_enabled()
        trace_headers = None
        if is_tracing_enabled and raw_request:
            trace_headers = extract_trace_headers(raw_request.headers)
        # The client sent trace headers but tracing is off: warn about it
        if not is_tracing_enabled and raw_request and contains_trace_headers(
                raw_request.headers):
            log_tracing_disabled_warning()

        result_generator = self.engine.generate(
            inputs,
            sampling_params,
            request_id,
            lora_request,
            trace_headers=trace_headers,
        )
        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer)
        else:
            try:
                return await self.chat_completion_full_generator(
                    request, raw_request, result_generator, request_id,
                    conversation, tokenizer)
            except ValueError as e:
                # TODO: Use a vllm-specific Validation Error
                return self.create_error_response(str(e))
177
178
179
180
181

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        """Return the role the generated message should be attributed to.

        When the request adds a generation prompt the model starts a new
        message, so the configured ``response_role`` is used. Otherwise the
        model continues the last message and that message's role is used.
        An empty message list also falls back to ``response_role`` instead
        of raising ``IndexError``.
        """
        if request.add_generation_prompt or not request.messages:
            return self.response_role
        return request.messages[-1]["role"]
183
184

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: PreTrainedTokenizer,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion as ``data: ...`` event strings.

        For every choice index the stream carries, in order: one chunk with
        the role, optionally one chunk echoing the last conversation message
        (``request.echo``), one chunk per engine result with the new text,
        and one chunk carrying the finish reason. If usage reporting is
        requested via ``request.stream_options.include_usage``, a final
        chunk with no choices carries the token usage. The stream always
        ends with ``data: [DONE]``.

        Args:
            request: The chat completion request being served.
            result_generator: Engine results for this request.
            request_id: Identifier placed on every chunk.
            conversation: Parsed conversation, used only for ``echo``.
            tokenizer: Tokenizer used to decode logprob tokens.

        Yields:
            Event strings of the form ``data: <json>`` followed by a blank
            line. A ``ValueError`` raised while streaming is reported as a
            streaming error response instead of propagating.
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        assert request.n is not None
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        finish_reason_sent = [False] * request.n

        # Request-level flags do not change while streaming; compute once.
        include_usage = bool(request.stream_options
                             and request.stream_options.include_usage)
        named_tool_choice = isinstance(request.tool_choice,
                                       ChatCompletionNamedToolChoiceParam)
        # Initialised up front so the final usage chunk never reads an
        # unbound name, even if the engine yields no finished output.
        prompt_tokens = 0

        def render_chunk(
                choice_data: ChatCompletionResponseStreamChoice) -> str:
            """Wrap one choice in a stream response and serialise it."""
            chunk = ChatCompletionStreamResponse(id=request_id,
                                                 object=chunk_object_type,
                                                 created=created_time,
                                                 choices=[choice_data],
                                                 model=model_name)
            if include_usage:
                # Mark ``usage`` as explicitly set so it is serialised as
                # null on every chunk except the final usage chunk.
                chunk.usage = None
            return chunk.model_dump_json(exclude_unset=True)

        try:
            async for res in result_generator:
                prompt_tokens = len(res.prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)
                    for i in range(request.n):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(role=role),
                            logprobs=None,
                            finish_reason=None)
                        data = render_chunk(choice_data)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for i in range(request.n):
                                # ``logprobs`` belongs on the choice, as in
                                # the role chunk above, not on the response.
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                data = render_chunk(choice_data)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    # Only the tokens produced since the previous chunk
                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
                    out_logprobs = output.logprobs[
                        previous_num_tokens[i]:] if output.logprobs else None

                    if request.logprobs and request.top_logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)

                    if named_tool_choice:
                        # The generated text is streamed as the arguments
                        # of the named function.
                        delta_message = DeltaMessage(tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text))
                        ])
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        data = render_chunk(choice_data)
                        yield f"data: {data}\n\n"
                    else:
                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                            stop_reason=output.stop_reason)
                        data = render_chunk(choice_data)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            if include_usage:
                # Count the tokens of every choice, not only the last index
                # the loop happened to visit.
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request],
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: PreTrainedTokenizer,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Drain the engine results and build one complete response.

        Args:
            request: The chat completion request being served.
            raw_request: The underlying HTTP request, if any; polled for
                client disconnects while results arrive.
            result_generator: Engine results for this request; only the
                last one is used to build the response.
            request_id: Identifier of the request, also used to abort it.
            conversation: Parsed conversation, used only for ``echo``.
            tokenizer: Tokenizer used to decode logprob tokens.

        Returns:
            A ``ChatCompletionResponse`` with one choice per engine output
            and the token usage, or an ``ErrorResponse`` if the client
            disconnected (the engine request is aborted in that case).
        """
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        async for res in result_generator:
            if raw_request is not None and await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return self.create_error_response("Client disconnected")
            final_res = res
        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        named_tool_choice = isinstance(request.tool_choice,
                                       ChatCompletionNamedToolChoiceParam)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            if named_tool_choice:
                # The generated text is returned as the arguments of the
                # named function; the content stays empty.
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
            else:
                # Every other tool_choice value yields a plain content
                # message, matching the streaming path. Using ``else``
                # guarantees ``message`` is bound for each output.
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            # Echo only when the last message has content and carries the
            # same role as the response.
            if conversation and conversation[-1].get(
                    "content") and conversation[-1].get("role") == role:
                last_msg_content = conversation[-1]["content"]

            for choice in choices:
                full_message = last_msg_content + choice.message.content
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

        return response
453
454

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
        """Convert the leading entries of ``logprobs`` to OpenAI-style items.

        At most ``top_logprobs`` entries are taken, in the dict's iteration
        order. A ``top_logprobs`` of ``None`` or ``0`` yields an empty list.
        Log probabilities are clamped to a minimum of ``-9999.0``.
        """
        entries: List[ChatCompletionLogProb] = []
        if not top_logprobs:
            return entries

        for position, (token_id, logprob) in enumerate(logprobs.items()):
            if position >= top_logprobs:
                # Every later entry is past the limit as well
                break
            decoded = self._get_decoded_token(logprob, token_id, tokenizer)
            entries.append(
                ChatCompletionLogProb(
                    token=decoded,
                    logprob=max(logprob.logprob, -9999.0),
                    bytes=list(decoded.encode("utf-8", errors="replace"))))
        return entries

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: PreTrainedTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs.

        Args:
            token_ids: The generated token ids.
            top_logprobs: For each position in ``token_ids``, either
                ``None`` or a mapping from token id to its ``Logprob``.
            tokenizer: Used to decode a token id when no decoded text is
                available for it.
            num_output_top_logprobs: How many alternatives to report per
                position; forwarded to ``_get_top_logprobs``.

        Returns:
            A ``ChatCompletionLogProbs`` with one content entry per token.
            Positions whose ``top_logprobs`` entry is ``None`` carry only
            the token text and bytes.
        """
        logprobs_content = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace"))))
                continue

            sampled = step_top_logprobs[token_id]
            token = sampled.decoded_token
            if token is None:
                # NOTE(review): assumes ``decoded_token`` may be unset; decode
                # the id here rather than fail on ``None.encode`` below.
                token = tokenizer.decode(token_id)
            logprobs_content.append(
                ChatCompletionLogProbsContent(
                    token=token,
                    logprob=max(sampled.logprob, -9999.0),
                    bytes=list(token.encode("utf-8", errors="replace")),
                    top_logprobs=self._get_top_logprobs(
                        step_top_logprobs, num_output_top_logprobs,
                        tokenizer)))

        return ChatCompletionLogProbs(content=logprobs_content)