import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple, Union, cast

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
                                              ErrorResponse,
                                              RequestResponseMetadata,
                                              UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                    LoRAModulePath,
                                                    OpenAIServing,
                                                    PromptAdapterPath)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators

logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
    """OpenAI-compatible ``/v1/completions`` handler.

    Turns a ``CompletionRequest`` into one engine request per prompt and
    returns either a full ``CompletionResponse`` or an async generator of
    server-sent-event ``data:`` lines built from ``CompletionStreamResponse``
    chunks.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)
        # Sampling-parameter overrides coming from the model config. They are
        # only reported here; create_completion passes them to the request's
        # to_sampling_params / to_beam_search_params.
        diff_sampling_param = self.model_config.get_diff_sampling_param()
        if diff_sampling_param:
            logger.info(
                "Overwriting default completion sampling param with: %s",
                diff_sampling_param)

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request,
    ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        Returns an ``ErrorResponse`` for requests that fail validation or
        preprocessing, an async generator of SSE ``data:`` lines when
        ``request.stream`` is set, and a ``CompletionResponse`` otherwise.
        Raises the engine's dead error if the engine has already errored.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        request_id = f"cmpl-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            request_prompts, engine_prompts = await self._preprocess_completion(
                request,
                tokenizer,
                request.prompt,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            # Neither value depends on the individual prompt, so compute them
            # once for the whole batch rather than once per prompt.
            default_sampling_params = (
                self.model_config.get_diff_sampling_param())
            # raw_request may be absent (it is guarded above as well); in that
            # case there are no incoming headers to propagate.
            trace_headers = (None if raw_request is None else await
                             self._get_trace_headers(raw_request.headers))

            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Without an explicit max_tokens, allow generation up to the
                # model's context length minus the prompt length.
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])

                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens,
                        self.model_config.logits_processor_pattern,
                        default_sampling_params)

                # Every prompt in the batch gets its own engine request id.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                if isinstance(sampling_params, BeamSearchParams):
                    # Use the per-prompt id, as the generate() branch does, so
                    # beam searches for different prompts of the same batch
                    # are not submitted under one shared request id.
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id_item,
                        params=sampling_params,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        prompt_adapter_request=prompt_adapter_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # Yields (prompt_index, RequestOutput) pairs across all prompts.
        result_generator = merge_async_iterators(*generators)

        model_name = self._get_model_name(lora_request)
        num_prompts = len(engine_prompts)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when use
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata)

        # Non-streaming response: keep the latest output seen per prompt.
        final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    final_res.prompt = request_prompts[i]["prompt"]

            final_res_batch_checked = cast(List[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Yield SSE ``data:`` lines for a streamed completion.

        ``result_generator`` yields ``(prompt_index, RequestOutput)`` pairs.
        Every output becomes one ``CompletionStreamResponse`` chunk whose
        choice index is ``output.index + prompt_index * n``. When
        ``request.echo`` is set, the prompt is prepended to the first chunk of
        each choice. With ``stream_options.include_usage`` a final usage-only
        chunk is emitted, and ``continuous_usage_stats`` adds running usage to
        every chunk. A ``ValueError`` is reported as an error event. The
        stream always ends with ``data: [DONE]``. Aggregate usage is stored on
        ``request_metadata.final_usage_info`` when the stream completes
        normally.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice state, indexed by the global choice index.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts

        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage
            include_continuous_usage = (include_usage and
                                        stream_options.continuous_usage_stats)
        else:
            include_usage, include_continuous_usage = False, False

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs
                prompt_text = res.prompt

                # Prompt details are excluded from later streamed outputs
                if res.prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: Optional[GenericSequence[Optional[Dict[
                    int, Logprob]]]]

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids, *output.token_ids
                            ]
                            # Logprobs are only consumed when the client asked
                            # for them (mirrors the non-streaming path), so do
                            # not require prompt logprobs otherwise.
                            if request.logprobs is None:
                                out_logprobs = None
                            else:
                                assert prompt_logprobs is not None
                                out_logprobs = [
                                    *prompt_logprobs,
                                    *(output.logprobs or []),
                                ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        if (not delta_text and not delta_token_ids
                                and not previous_num_tokens[i]):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            # Continue text offsets from where this choice's
                            # previous chunks ended.
                            initial_text_offset=previous_text_lens[i],
                        )
                    else:
                        logprobs = None

                    previous_text_lens[i] += len(output.text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens)

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Assemble a non-streaming ``CompletionResponse``.

        Emits one choice per output of every ``RequestOutput`` in
        ``final_res_batch``, with choice indices assigned sequentially, and
        prepends the prompt when ``request.echo`` is set. Aggregate token
        usage is stored on ``request_metadata.final_usage_info``.

        Note: ``-inf`` prompt logprobs are rewritten in place to ``-9999.0``
        (the JSON-serializable floor also used for generated tokens).
        """
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = final_res.prompt_logprobs
            if prompt_logprobs:
                for logprob_dict in prompt_logprobs:
                    if logprob_dict:
                        for logprob_values in logprob_dict.values():
                            if logprob_values.logprob == float('-inf'):
                                logprob_values.logprob = -9999.0
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: Optional[GenericSequence[Optional[Dict[int,
                                                                 Logprob]]]]

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo:
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        # Echo-only request: the prompt is the whole output.
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                )
                choices.append(choice_data)

                # Usage counts generated tokens only, never echoed prompt ones.
                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        request_metadata.final_usage_info = usage

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: AnyTokenizer,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Token ids to report, one entry per position.
            top_logprobs: Candidate logprobs per position, keyed by token id
                and aligned with ``token_ids``. A ``None`` entry yields
                ``None`` logprob fields for that position.
            num_output_top_logprobs: Requested number of alternatives. Up to
                ``num_output_top_logprobs + 1`` candidates are reported per
                position.
            tokenizer: Used to decode token ids into text.
            initial_text_offset: Character offset of the first token. Used to
                continue offsets across streamed chunks.
        """
        out_text_offset: List[int] = []
        out_token_logprobs: List[Optional[float]] = []
        out_tokens: List[str] = []
        out_top_logprobs: List[Optional[Dict[str, float]]] = []

        last_token_len = 0

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = (f"token_id:{token_id}"
                         if self.return_tokens_as_token_ids else
                         tokenizer.decode(token_id))

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=self.return_tokens_as_token_ids,
                )
                # Convert float("-inf") to the JSON-serializable float that
                # OpenAI uses.
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append({
                    self._get_decoded_token(
                        top_lp,
                        top_id,
                        tokenizer,
                        return_as_token_id=self.return_tokens_as_token_ids):
                    max(top_lp.logprob, -9999.0)
                    for rank, (top_id, top_lp) in enumerate(
                        step_top_logprobs.items())
                    if num_output_top_logprobs >= rank
                })

            # Offsets are cumulative character positions of each token.
            if not out_text_offset:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )