# SPDX-License-Identifier: Apache-2.0

import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple, Union, cast

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
                                              ErrorResponse,
                                              RequestResponseMetadata,
                                              UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators

# Module-level logger for this serving endpoint.
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
    """Serve the OpenAI-compatible Completion API on top of an EngineClient.

    Requests are preprocessed into engine prompts, scheduled on the engine
    (regular generation or beam search), and the results are returned either
    as a streamed sequence of server-sent-event lines or as a single
    ``CompletionResponse``.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)
        # Model-level sampling defaults. Computed once here and reused for
        # every request instead of being re-read for each prompt.
        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            logger.info(
                "Overwriting default completion sampling param with: %s",
                self.default_sampling_params)

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Args:
            request: The parsed completion request.
            raw_request: The underlying FastAPI request, if any. Its headers
                are used for tracing and ``request_metadata`` is attached to
                its ``state``.

        Returns:
            An async generator of SSE lines when ``request.stream`` is set,
            otherwise a ``CompletionResponse``. An ``ErrorResponse`` is
            returned for model-check failures, unsupported features and
            ``ValueError``s raised during preprocessing or scheduling.

        Raises:
            The engine client's ``dead_error`` if the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        request_id = f"cmpl-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            request_prompts, engine_prompts = await self._preprocess_completion(
                request,
                tokenizer,
                request.prompt,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            # Trace headers depend only on the raw request, so fetch them
            # once instead of once per prompt.
            trace_headers = (None if raw_request is None else await
                             self._get_trace_headers(raw_request.headers))

            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Whatever context remains after the prompt.
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])

                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, self.default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params)

                # One engine request per prompt, suffixed by prompt index.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id_item,
                        params=sampling_params,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        prompt_adapter_request=prompt_adapter_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # Yields (prompt_index, RequestOutput) pairs from all prompts.
        result_generator = merge_async_iterators(*generators)

        model_name = self._get_model_name(request.model, lora_request)
        num_prompts = len(engine_prompts)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when use
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata)

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    final_res.prompt = request_prompts[i]["prompt"]

            final_res_batch_checked = cast(List[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Yield server-sent-event lines for a streamed completion.

        ``result_generator`` yields ``(prompt_idx, RequestOutput)`` pairs.
        Each choice is emitted with the flattened index
        ``output.index + prompt_idx * n``. If ``stream_options.include_usage``
        is set, a final usage-only chunk is sent. Exceptions are reported as
        a streaming error event. ``data: [DONE]`` is yielded last in both the
        normal and the error case.
        """
        num_choices = 1 if request.n is None else request.n
        num_slots = num_choices * num_prompts
        # Per-choice running state, indexed by the flattened choice index.
        previous_text_lens = [0] * num_slots
        previous_num_tokens = [0] * num_slots
        has_echoed = [False] * num_slots
        num_prompt_tokens = [0] * num_prompts

        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage
            include_continuous_usage = include_usage and \
                                       stream_options.continuous_usage_stats
        else:
            include_usage, include_continuous_usage = False, False

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs
                prompt_text = res.prompt

                # Prompt details are excluded from later streamed outputs
                if res.prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: Optional[GenericSequence[Optional[Dict[
                    int, Logprob]]]]

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids, *output.token_ids
                            ]
                            if request.logprobs is None:
                                # Logprobs are only consumed below when
                                # request.logprobs is set, so don't require
                                # prompt logprobs for a plain echo (this
                                # mirrors the non-streaming path).
                                out_logprobs = None
                            else:
                                assert prompt_logprobs is not None
                                out_logprobs = [
                                    *prompt_logprobs,
                                    *(output.logprobs or []),
                                ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        if not delta_text and not delta_token_ids \
                            and not previous_num_tokens[i]:
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                        )
                    else:
                        logprobs = None

                    # Advance by the text actually sent to the client (which
                    # includes the echoed prompt on the first chunk), so the
                    # text offsets of later chunks stay aligned with it.
                    previous_text_lens[i] += len(delta_text)
                    # Only generated tokens count towards completion usage.
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens)

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build a non-streaming ``CompletionResponse`` from final outputs.

        Side effects: ``-inf`` prompt logprobs in ``final_res_batch`` are
        overwritten in place with ``-9999.0``, and the aggregate usage is
        stored on ``request_metadata.final_usage_info``.
        """
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = final_res.prompt_logprobs
            if prompt_logprobs:
                # Replace -inf (not JSON-serializable) with the sentinel used
                # elsewhere in this class.
                for logprob_dict in prompt_logprobs:
                    if logprob_dict:
                        for logprob_values in logprob_dict.values():
                            if logprob_values.logprob == float('-inf'):
                                logprob_values.logprob = -9999.0
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: Optional[GenericSequence[Optional[Dict[int,
                                                                 Logprob]]]]

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo:
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        request_metadata.final_usage_info = usage

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: AnyTokenizer,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Tokens to report, aligned index-by-index with
                ``top_logprobs``.
            top_logprobs: Per-position candidate logprobs, or ``None`` where
                unavailable (reported as a ``None`` logprob).
            num_output_top_logprobs: Number of top alternatives requested;
                up to ``num_output_top_logprobs + 1`` entries are emitted.
            tokenizer: Used to decode token ids to text.
            initial_text_offset: Character offset of the first token.

        Returns:
            A ``CompletionLogProbs`` with text offsets, token logprobs, token
            strings and top logprobs. Logprobs are clamped to ``-9999.0``.
        """
        out_text_offset: List[int] = []
        out_token_logprobs: List[Optional[float]] = []
        out_tokens: List[str] = []
        out_top_logprobs: List[Optional[Dict[str, float]]] = []

        last_token_len = 0

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # Only decode when the text form is actually reported.
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                else:
                    token = tokenizer.decode(token_id)

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=self.return_tokens_as_token_ids,
                )
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    self._get_decoded_token(
                        top_lp[1],
                        top_lp[0],
                        tokenizer,
                        return_as_token_id=self.return_tokens_as_token_ids,
                    ): max(top_lp[1].logprob, -9999.0)
                    for i, top_lp in enumerate(step_top_logprobs.items())
                    if num_output_top_logprobs >= i
                })

            # Each offset is the previous offset plus the previous token's
            # decoded length.
            if len(out_text_offset) == 0:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )