serving.py 25.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
7
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
8
from typing import cast
9

10
from fastapi import Request
11

12
from vllm.engine.protocol import EngineClient
13
from vllm.entrypoints.logger import RequestLogger
14
from vllm.entrypoints.openai.completion.protocol import (
15
16
17
18
19
20
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
21
22
)
from vllm.entrypoints.openai.engine.protocol import (
23
24
25
26
27
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
28
from vllm.entrypoints.openai.engine.serving import (
29
30
31
32
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
33
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
34
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
35
from vllm.exceptions import VLLMValidationError
36
from vllm.inputs.data import ProcessorInputs
37
from vllm.logger import init_logger
38
from vllm.logprobs import Logprob
39
from vllm.outputs import RequestOutput
40
from vllm.sampling_params import BeamSearchParams, SamplingParams
41
from vllm.tokenizers import TokenizerLike
42
43
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
44
45
46
47
48

# Module-level logger created through vLLM's ``init_logger`` for this module.
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
49
50
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
    ):
        """Initialise the completion handler on top of ``OpenAIServing``.

        Args:
            engine_client: Forwarded to the base class.
            models: Forwarded to the base class.
            request_logger: Forwarded to the base class; may be ``None``.
            return_tokens_as_token_ids: Forwarded to the base class.
            enable_prompt_tokens_details: Stored on the instance; the
                response builders use it to decide whether cached-token
                counts are attached to the usage info.
            enable_force_include_usage: Stored on the instance; passed to
                ``should_include_usage`` when streaming.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
        )

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage

        model_config = self.model_config
        sampling_defaults = model_config.get_diff_sampling_param()
        self.default_sampling_params = sampling_defaults

        # Pick the source of the max-token override: the generation-config
        # overrides when generation_config is "auto"/"vllm", otherwise the
        # default sampling parameters.
        if model_config.generation_config in ("auto", "vllm"):
            config_overrides = getattr(model_config, "override_generation_config", {})
            self.override_max_tokens = config_overrides.get("max_new_tokens")
        else:
            self.override_max_tokens = sampling_defaults.get("max_tokens")
76

77
    async def render_completion_request(
        self,
        request: CompletionRequest,
    ) -> list[ProcessorInputs] | ErrorResponse:
        """Validate a completion request and preprocess its prompts.

        Args:
            request: The completion request to check and render.

        Returns:
            The list of engine prompts on success, or an ``ErrorResponse``
            when the model check fails or the request uses an unsupported
            combination of options.

        Raises:
            The engine client's ``dead_error`` when the engine client
            reports that it has errored.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        # A dead engine must surface as an exception here: in the streaming
        # case a success status is sent before any text is generated.
        engine_client = self.engine_client
        if engine_client.errored:
            raise engine_client.dead_error

        # Reject features that are not supported.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        uses_prompt_embeds = request.prompt_embeds is not None

        if request.echo and uses_prompt_embeds:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and uses_prompt_embeds:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        return await self._preprocess_completion(
            request,
            prompt_input=request.prompt,
            prompt_embeds=request.prompt_embeds,
        )

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Args:
            request: The completion request to serve.
            raw_request: The incoming HTTP request, if any. When provided,
                the request metadata is attached to its ``state`` and its
                headers are used to resolve the trace headers.

        Returns:
            An ``ErrorResponse`` when rendering the request fails or when
            collecting a non-streaming result is cancelled; an async
            generator of server-sent-event strings when ``request.stream``
            is set; otherwise a ``CompletionResponse``.
        """
        result = await self.render_completion_request(request)
        if isinstance(result, ErrorResponse):
            return result

        engine_prompts = result

        request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        lora_request = self._maybe_get_adapters(request)

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # The trace headers depend only on the HTTP request, so resolve them
        # once up front instead of awaiting the same lookup for every prompt.
        trace_headers = (
            None
            if raw_request is None
            else await self._get_trace_headers(raw_request.headers)
        )

        # Schedule the request and get the result generator.
        max_model_len = self.model_config.max_model_len
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        for i, engine_prompt in enumerate(engine_prompts):
            max_tokens = get_max_tokens(
                max_model_len,
                request.max_tokens,
                self._extract_prompt_len(engine_prompt),
                self.default_sampling_params,
                self.override_max_tokens,
            )

            sampling_params: SamplingParams | BeamSearchParams
            if request.use_beam_search:
                sampling_params = request.to_beam_search_params(
                    max_tokens, self.default_sampling_params
                )
            else:
                sampling_params = request.to_sampling_params(
                    max_tokens,
                    self.default_sampling_params,
                )

            # Each prompt of the request gets its own engine-level id.
            request_id_item = f"{request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )

            if isinstance(sampling_params, BeamSearchParams):
                # NOTE(review): beam search receives the base request_id
                # rather than request_id_item used for logging above;
                # confirm this is intended for multi-prompt requests.
                generator = self.beam_search(
                    prompt=engine_prompt,
                    request_id=request_id,
                    params=sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                )
            else:
                generator = self.engine_client.generate(
                    engine_prompt,
                    sampling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                    data_parallel_rank=data_parallel_rank,
                )

            generators.append(generator)

        # Merged stream yields (prompt index, output) pairs.
        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # We do not stream the results when using beam search.
        stream = request.stream and not request.use_beam_search

        tokenizer = self.renderer.tokenizer

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            # Keep only the latest output seen for each prompt.
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    engine_prompt = engine_prompts[i]
                    final_res.prompt = self._extract_prompt_text(engine_prompt)

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
275
276
277
278

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_prompts: list[ProcessorInputs],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Stream completion chunks as server-sent-event strings.

        Every yielded string is one ``data: ...`` event. Chunks are emitted
        per output of each engine result; an optional usage-only chunk
        follows, and the stream always ends with a ``[DONE]`` event.

        Args:
            request: The completion request being served.
            engine_prompts: Rendered prompts; used to recover the prompt
                text when an engine result does not carry it.
            result_generator: Async iterator of ``(prompt index, output)``
                pairs.
            request_id: Id written into every chunk.
            created_time: Creation timestamp written into every chunk.
            model_name: Model name written into every chunk.
            num_prompts: Number of prompts in the request.
            tokenizer: Tokenizer forwarded to the logprobs builder.
            request_metadata: Receives the aggregated usage in
                ``final_usage_info`` once the stream completes.

        Yields:
            Server-sent-event strings. A ``GenerationError`` or any other
            exception raised while streaming is converted into an error
            event instead of propagating.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice state, indexed by output.index + prompt_idx * num_choices.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts
        num_cached_tokens = None
        first_iteration = True

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                # The cached-token count is taken from the first result only.
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    first_iteration = False

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_prompt = engine_prompts[prompt_idx]
                    prompt_text = self._extract_prompt_text(engine_prompt)

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Useful when request.return_token_ids is True
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if not has_echoed[i] and request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # Advance the text offset by the text actually emitted in
                    # this chunk. With echo the first chunk also contains the
                    # prompt, so counting only output.text would make the
                    # text offsets of later chunks restart inside the prompt.
                    previous_text_lens[i] += len(delta_text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    self._raise_if_error(finish_reason, request_id)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            if self.enable_prompt_tokens_details and num_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=num_cached_tokens
                )

            if include_usage:
                # Usage-only chunk: no choices, just the aggregated usage.
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(e)
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build the non-streaming completion response from final outputs.

        Args:
            final_res_batch: One final engine output per prompt.
            request: The completion request being served.
            request_id: Id written into the response.
            created_time: Creation timestamp written into the response.
            model_name: Model name written into the response.
            tokenizer: Tokenizer forwarded to the logprobs builder.
            request_metadata: Receives the aggregated usage in
                ``final_usage_info``.

        Returns:
            A ``CompletionResponse`` with one choice per output of every
            prompt, the aggregated usage, and the ``kv_transfer_params`` of
            the first result (``None`` for an empty batch).
        """
        choices: list[CompletionResponseChoice] = []
        prompt_token_total = 0
        generated_token_total = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                self._raise_if_error(output.finish_reason, request_id)

                assert request.max_tokens is not None
                echo_prompt = request.echo
                if echo_prompt:
                    # When token ids are returned, the prompt text is
                    # replaced by an empty string for the echoed output.
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None

                if not echo_prompt:
                    # Plain completion: only the generated part.
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text
                elif request.max_tokens == 0:
                    # Echo with nothing generated: prompt only.
                    token_ids = prompt_token_ids
                    out_logprobs = prompt_logprobs
                    output_text = prompt_text
                else:
                    # Echo followed by the generated continuation.
                    token_ids = [*prompt_token_ids, *output.token_ids]
                    if request.logprobs is None:
                        out_logprobs = None
                    else:
                        assert prompt_logprobs is not None
                        assert output.logprobs is not None
                        out_logprobs = [*prompt_logprobs, *output.logprobs]
                    output_text = prompt_text + output.text

                logprobs = None
                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )

                with_token_ids = request.return_token_ids
                choices.append(
                    CompletionResponseChoice(
                        # Choices are numbered in the order they are built.
                        index=len(choices),
                        text=output_text,
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                        stop_reason=output.stop_reason,
                        prompt_logprobs=final_res.prompt_logprobs,
                        prompt_token_ids=prompt_token_ids if with_token_ids else None,
                        token_ids=as_list(output.token_ids) if with_token_ids else None,
                    )
                )

                generated_token_total += len(output.token_ids)

            prompt_token_total += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=prompt_token_total,
            completion_tokens=generated_token_total,
            total_tokens=prompt_token_total + generated_token_total,
        )

        # Cached-token details come from the last result of the batch.
        last_final_res = final_res_batch[-1] if final_res_batch else None
        if (
            self.enable_prompt_tokens_details
            and last_final_res
            and last_final_res.num_cached_tokens
        ):
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=last_final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        kv_transfer_params = (
            final_res_batch[0].kv_transfer_params if final_res_batch else None
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )
577
578
579
580

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: TokenizerLike | None,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Token ids to report, in order.
            top_logprobs: Per-position logprob mapping, indexed in step
                with ``token_ids``; ``None`` marks a position without
                logprobs.
            num_output_top_logprobs: Number of top logprobs requested; one
                extra entry is kept per position.
            tokenizer: Used to decode tokens; may be ``None``.
            initial_text_offset: Text offset assigned to the first token.
            return_as_token_id: Overrides the instance-level
                ``return_tokens_as_token_ids`` setting when not ``None``.

        Returns:
            A ``CompletionLogProbs`` with parallel lists of text offsets,
            token logprobs, tokens and top logprobs.

        Raises:
            VLLMValidationError: If a position has no logprobs, tokens are
                not returned as ids, and no tokenizer is available.
        """
        use_token_id_form = (
            self.return_tokens_as_token_ids
            if return_as_token_id is None
            else return_as_token_id
        )

        text_offsets: list[int] = []
        token_logprob_values: list[float | None] = []
        tokens: list[str] = []
        top_entries: list[dict[str, float] | None] = []

        # Offset of the next token: starts at the caller-provided offset and
        # grows by the length of every token emitted so far.
        running_offset = initial_text_offset

        for position, token_id in enumerate(token_ids):
            candidates = top_logprobs[position]

            token_logprob: float | None
            top_entry: dict[str, float] | None

            if candidates is None:
                # No logprobs for this position: report the token alone.
                if use_token_id_form:
                    token = f"token_id:{token_id}"
                else:
                    if tokenizer is None:
                        raise VLLMValidationError(
                            "Unable to get tokenizer because "
                            "`skip_tokenizer_init=True`",
                            parameter="skip_tokenizer_init",
                            value=True,
                        )
                    token = tokenizer.decode(token_id)
                token_logprob = None
                top_entry = None
            else:
                chosen = candidates[token_id]
                token = self._get_decoded_token(
                    chosen,
                    token_id,
                    tokenizer,
                    return_as_token_id=use_token_id_form,
                )
                # Floor very small values (including float("-inf")) at
                # -9999.0 so they stay JSON-serializable, as OpenAI does.
                token_logprob = max(chosen.logprob, -9999.0)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                top_entry = {}
                for rank, (candidate_id, candidate) in enumerate(candidates.items()):
                    if rank > num_output_top_logprobs:
                        break
                    decoded = self._get_decoded_token(
                        candidate,
                        candidate_id,
                        tokenizer,
                        return_as_token_id=use_token_id_form,
                    )
                    top_entry[decoded] = max(candidate.logprob, -9999.0)

            tokens.append(token)
            token_logprob_values.append(token_logprob)
            top_entries.append(top_entry)

            text_offsets.append(running_offset)
            running_offset += len(token)

        return CompletionLogProbs(
            text_offset=text_offsets,
            token_logprobs=token_logprob_values,
            tokens=tokens,
            top_logprobs=top_entries,
        )