serving_completion.py 22.4 KB
Newer Older
1
import asyncio
2
import time
3
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
4
5
                    Optional)
from typing import Sequence as GenericSequence
6
from typing import Tuple, Union, cast
7

8
from fastapi import Request
9

10
from vllm.config import ModelConfig
11
from vllm.engine.protocol import EngineClient
12
from vllm.entrypoints.logger import RequestLogger
13
# yapf conflicts with isort for this block
14
15
16
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
17
18
19
20
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
21
22
23
                                              ErrorResponse,
                                              RequestResponseMetadata,
                                              UsageInfo)
24
# yapf: enable
25
26
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                    LoRAModulePath,
27
28
                                                    OpenAIServing,
                                                    PromptAdapterPath)
29
30
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
31
from vllm.sampling_params import BeamSearchParams, SamplingParams
32
from vllm.sequence import Logprob
33
34
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
35
from vllm.transformers_utils.tokenizer import AnyTokenizer
36
from vllm.utils import merge_async_iterators, random_uuid
37
38
39

logger = init_logger(__name__)

# Type aliases for a callable that builds ``CompletionLogProbs``.
# NOTE(review): none of these aliases is referenced in this module's
# visible code; the meaning of the trailing ``Optional[int]`` / ``int``
# arguments (presumably the requested number of top logprobs and an
# initial text offset) should be confirmed against callers.
TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]


class OpenAIServingCompletion(OpenAIServing):
    """Serve OpenAI-style completion requests on top of an engine client.

    Each tokenized prompt of a request is submitted to the engine as its own
    sub-request (``<request_id>-<prompt index>``). The merged engine outputs
    are turned into either a single ``CompletionResponse`` or a stream of
    server-sent-event ``data:`` lines.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request,
    ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Returns:
            An ``ErrorResponse`` for model-check or validation failures, an
            async generator of SSE ``data:`` lines when the request is
            streamed (or asked for streaming but cannot be streamed), and a
            ``CompletionResponse`` otherwise.

        Raises:
            The engine client's ``dead_error`` when the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = self.base_model_paths[0].name
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            prompts = list(
                self._tokenize_prompt_input_or_inputs(
                    request,
                    tokenizer,
                    request.prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                ))

            # Tracing state and the incoming headers are identical for every
            # prompt of this request, so resolve them once, not per prompt.
            is_tracing_enabled = (await
                                  self.engine_client.is_tracing_enabled())
            trace_headers = None
            if is_tracing_enabled:
                trace_headers = extract_trace_headers(raw_request.headers)
            elif contains_trace_headers(raw_request.headers):
                log_tracing_disabled_warning()

            for i, prompt_inputs in enumerate(prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Without an explicit max_tokens, allow generation up to the
                # remaining model context.
                default_max_tokens = self.max_model_len - len(
                    prompt_inputs["prompt_token_ids"])
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens)

                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt_inputs["prompt_token_ids"],
                        request_id_item,
                        sampling_params,
                    )
                else:
                    # Only token IDs are sent to the engine, so engine
                    # outputs generally carry no prompt text; the text is
                    # re-attached below (non-streaming) or passed to the
                    # stream generator (streaming).
                    generator = self.engine_client.generate(
                        {
                            "prompt_token_ids":
                            prompt_inputs["prompt_token_ids"]
                        },
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        prompt_adapter_request=prompt_adapter_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(
            *generators, is_cancelled=raw_request.is_disconnected)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when use
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=len(prompts),
                tokenizer=tokenizer,
                request_metadata=request_metadata,
                prompt_texts=[
                    prompt_inputs.get("prompt") for prompt_inputs in prompts
                ])

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    final_res.prompt = prompts[i]["prompt"]

            final_res_batch_checked = cast(List[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
        prompt_texts: Optional[GenericSequence[Optional[str]]] = None,
    ) -> AsyncGenerator[str, None]:
        """Yield SSE ``data:`` lines for a streamed completion request.

        Args:
            request: The originating completion request.
            result_generator: Yields ``(prompt_index, RequestOutput)`` pairs.
            request_id: ID reported in every chunk.
            created_time: Unix timestamp reported in every chunk.
            model_name: Model name reported in every chunk.
            num_prompts: Number of prompts in the request.
            tokenizer: Used to decode tokens for logprobs.
            request_metadata: Receives the aggregate usage at the end.
            prompt_texts: Optional prompt text per prompt index, used for
                ``echo`` when the engine output carries no prompt text.

        Each choice is addressed by a flat index
        ``output.index + prompt_idx * n``. A ``ValueError`` raised while
        streaming is reported as an error event. The stream always ends with
        ``data: [DONE]``.
        """
        num_choices = 1 if request.n is None else request.n
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts

        stream_options = request.stream_options
        include_usage = bool(stream_options is not None
                             and stream_options.include_usage)
        include_continuous_usage = bool(
            include_usage and stream_options is not None
            and stream_options.continuous_usage_stats)

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs
                prompt_text = res.prompt
                # The engine only receives token IDs, so its outputs may lack
                # the prompt text; fall back to the caller-supplied copy so
                # that `echo` works when streaming.
                if prompt_text is None and prompt_texts is not None:
                    prompt_text = prompt_texts[prompt_idx]

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: Optional[GenericSequence[Optional[Dict[
                    int, Logprob]]]]

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices
                    # TODO(simon): optimize the performance by avoiding full
                    # text O(n^2) sending.

                    assert request.max_tokens is not None
                    if request.echo and request.max_tokens == 0:
                        assert prompt_token_ids is not None
                        assert prompt_text is not None
                        # only return the prompt
                        delta_text = prompt_text
                        delta_token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        has_echoed[i] = True
                    elif (request.echo and request.max_tokens > 0
                          and not has_echoed[i]):
                        assert prompt_token_ids is not None
                        assert prompt_text is not None
                        # echo the prompt and first token
                        delta_text = prompt_text + output.text
                        delta_token_ids = [
                            *prompt_token_ids, *output.token_ids
                        ]
                        # Mirror the non-streaming path: prompt logprobs
                        # are only required when the client asked for
                        # logprobs, and may be absent otherwise.
                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *(output.logprobs or []),
                            ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                        )
                    else:
                        logprobs = None

                    previous_text_lens[i] += len(output.text)
                    previous_num_tokens[i] += len(output.token_ids)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=output.finish_reason,
                                stop_reason=output.stop_reason,
                            )
                        ])
                    if include_continuous_usage:
                        # Per-choice running usage attached to every chunk.
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            # Aggregate usage across all prompts and choices. It is used for
            # the optional final usage chunk and reported to the FastAPI
            # middleware.
            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens)

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Convert the final engine outputs of all prompts into one response.

        Every output of every prompt becomes a ``CompletionResponseChoice``
        whose index is its position in the flattened choice list. With
        ``echo``, the prompt text and tokens are prepended (or used alone
        when ``max_tokens == 0``). Token usage is summed over all prompts and
        outputs and also stored on ``request_metadata.final_usage_info``.
        """
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = final_res.prompt_logprobs
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: Optional[GenericSequence[Optional[Dict[int,
                                                                 Logprob]]]]

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo and request.max_tokens == 0:
                    assert prompt_text is not None
                    token_ids = prompt_token_ids
                    out_logprobs = prompt_logprobs
                    output_text = prompt_text
                elif request.echo and request.max_tokens > 0:
                    assert prompt_text is not None
                    token_ids = [*prompt_token_ids, *output.token_ids]

                    if request.logprobs is None:
                        out_logprobs = None
                    else:
                        assert prompt_logprobs is not None
                        assert output.logprobs is not None
                        out_logprobs = [
                            *prompt_logprobs,
                            *output.logprobs,
                        ]

                    output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        request_metadata.final_usage_info = usage

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: AnyTokenizer,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Tokens to report, one per position.
            top_logprobs: Per-position mapping of token ID to ``Logprob``,
                or ``None`` when no logprobs exist for that position.
            num_output_top_logprobs: Requested number of top logprobs; the
                first ``num_output_top_logprobs + 1`` entries of each
                position's mapping are emitted.
            tokenizer: Used to decode token IDs into strings.
            initial_text_offset: Character offset of the first token.

        Logprobs are clamped to ``-9999.0`` so ``-inf`` stays
        JSON-serializable. ``text_offset`` accumulates the lengths of the
        decoded token strings, starting from ``initial_text_offset``.
        """
        out_text_offset: List[int] = []
        out_token_logprobs: List[Optional[float]] = []
        out_tokens: List[str] = []
        out_top_logprobs: List[Optional[Dict[str, float]]] = []

        last_token_len = 0

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=self.return_tokens_as_token_ids,
                )
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                # NOTE(review): this keeps the first entries in dict order,
                # which assumes the mapping is ordered by rank — confirm.
                out_top_logprobs.append({
                    self._get_decoded_token(
                        top_lp,
                        top_token_id,
                        tokenizer,
                        return_as_token_id=self.return_tokens_as_token_ids):
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    max(top_lp.logprob, -9999.0)
                    for rank, (top_token_id, top_lp) in enumerate(
                        step_top_logprobs.items())
                    if num_output_top_logprobs >= rank
                })

            if len(out_text_offset) == 0:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )