serving.py 25.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
7
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
8
from typing import TYPE_CHECKING, cast
9

10
from fastapi import Request
11

12
from vllm.engine.protocol import EngineClient
13
from vllm.entrypoints.logger import RequestLogger
14
from vllm.entrypoints.openai.completion.protocol import (
15
16
17
18
19
20
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
21
22
)
from vllm.entrypoints.openai.engine.protocol import (
23
24
25
26
27
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
28
from vllm.entrypoints.openai.engine.serving import (
29
30
31
32
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
33
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
34
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
35
from vllm.exceptions import VLLMValidationError
36
from vllm.inputs import EngineInput
37
from vllm.logger import init_logger
38
from vllm.logprobs import Logprob
39
from vllm.outputs import RequestOutput
40
from vllm.sampling_params import BeamSearchParams, SamplingParams
41
from vllm.tokenizers import TokenizerLike
42
43
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
44

45
46
47
if TYPE_CHECKING:
    from vllm.entrypoints.serve.render.serving import OpenAIServingRender

48
49
50
51
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
52
53
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        openai_serving_render: "OpenAIServingRender",
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
    ):
        """Set up the completions handler on top of the shared OpenAI base.

        Args:
            engine_client: Client used to submit generation requests.
            models: Served-model registry (also used to resolve LoRA names).
            openai_serving_render: Renderer that turns a CompletionRequest
                into engine inputs.
            request_logger: Optional request logger, forwarded to the base.
            return_tokens_as_token_ids: Forwarded to the base class; used as
                the fallback for rendering logprob tokens as
                ``token_id:<id>`` strings.
            enable_prompt_tokens_details: Whether usage reports may include
                ``prompt_tokens_details`` (cached-token counts).
            enable_force_include_usage: Forwarded to ``should_include_usage``
                when deciding whether streamed responses carry usage.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
        )

        self.openai_serving_render = openai_serving_render
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage

        # Sampling defaults taken from the model config.
        self.default_sampling_params = self.model_config.get_diff_sampling_param()

        model_config = self.model_config
        if model_config.generation_config in ("auto", "vllm"):
            # Only an explicit `max_new_tokens` override caps output length.
            override_cfg = getattr(model_config, "override_generation_config", {})
            self.override_max_tokens = override_cfg.get("max_new_tokens")
        else:
            # Any other generation_config: use its max_tokens, if present.
            self.override_max_tokens = self.default_sampling_params.get("max_tokens")

82
    async def render_completion_request(
        self,
        request: CompletionRequest,
    ) -> list[EngineInput] | ErrorResponse:
        """Check a completion request against the engine, then render it.

        Prompt preprocessing itself lives in OpenAIServingRender; this wrapper
        adds the checks that need the engine: model (including LoRA)
        validation and engine liveness.

        Returns:
            The engine inputs produced by the renderer, or an ErrorResponse
            when model validation fails (or the renderer returns one).

        Raises:
            The engine client's ``dead_error`` if the engine has errored.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        # A dead engine has to surface as an exception here: for streaming
        # requests a success status is sent before any text is generated.
        engine = self.engine_client
        if engine.errored:
            raise engine.dead_error

        return await self.openai_serving_render.render_completion(request)

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        Returns an SSE string generator when ``request.stream`` is set and a
        full CompletionResponse otherwise. An ErrorResponse is returned for
        rejected requests (streaming together with beam search, model or
        render errors) and when the client disconnects while a non-streaming
        request is being collected.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)
        """
        if request.stream and request.use_beam_search:
            return self.create_error_response(
                "Streaming is not currently supported with beam search"
            )

        result = await self.render_completion_request(request)
        if isinstance(result, ErrorResponse):
            return result
        engine_inputs = result

        request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        lora_request = self._maybe_get_adapters(request)

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Trace headers depend only on the HTTP request, so resolve them once
        # instead of once per prompt.
        trace_headers = (
            None
            if raw_request is None
            else await self._get_trace_headers(raw_request.headers)
        )

        # Schedule the request and get the result generator.
        max_model_len = self.model_config.max_model_len
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        for i, engine_input in enumerate(engine_inputs):
            max_tokens = get_max_tokens(
                max_model_len,
                request.max_tokens,
                self._extract_prompt_len(engine_input),
                self.default_sampling_params,
                self.override_max_tokens,
            )

            sampling_params: SamplingParams | BeamSearchParams
            if request.use_beam_search:
                sampling_params = request.to_beam_search_params(
                    max_tokens, self.default_sampling_params
                )
            else:
                sampling_params = request.to_sampling_params(
                    max_tokens,
                    self.default_sampling_params,
                )

            # Per-prompt id, used for logging and for the engine request.
            request_id_item = f"{request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_input,
                params=sampling_params,
                lora_request=lora_request,
            )

            if isinstance(sampling_params, BeamSearchParams):
                # Use the per-prompt id (as the sampling path does and as was
                # logged above) so beam searches of different prompts in the
                # same request never share an id.
                generator = self.beam_search(
                    prompt=engine_input,
                    request_id=request_id_item,
                    params=sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                )
            else:
                generator = self.engine_client.generate(
                    engine_input,
                    sampling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                    data_parallel_rank=data_parallel_rank,
                )

            generators.append(generator)

        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_inputs)
        tokenizer = self.renderer.tokenizer

        # Streaming response
        if request.stream:
            return self.completion_stream_generator(
                request,
                engine_inputs,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    final_res.prompt = self._extract_prompt_text(engine_inputs[i])

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        # Streaming requests have already returned above (and streaming with
        # beam search is rejected up front), so no single-event "fake stream"
        # wrapper is needed on this path.
        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_inputs: list[EngineInput],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Stream completion results as server-sent ``data:`` events.

        Yields one CompletionStreamResponse chunk per non-empty output delta,
        then (when usage is requested) a final usage-only chunk, and always a
        terminating ``data: [DONE]`` event. Generation errors and unexpected
        exceptions are reported as an in-stream error event instead of being
        raised. The aggregate usage is stored on ``request_metadata``.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice state, indexed by output.index + prompt_idx * num_choices.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        # Per-prompt state.
        num_prompt_tokens = [0] * num_prompts
        # First cached-token count reported for each prompt; summed for usage
        # so it aggregates across prompts like prompt_tokens does.
        num_cached_tokens: list[int | None] = [None] * num_prompts

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                if num_cached_tokens[prompt_idx] is None:
                    num_cached_tokens[prompt_idx] = res.num_cached_tokens

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_input = engine_inputs[prompt_idx]
                    prompt_text = self._extract_prompt_text(engine_input)

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Useful when request.return_token_ids is True
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if not has_echoed[i] and request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # Advance by the text actually emitted in this chunk,
                    # including any echoed prompt, so the next chunk's
                    # logprob text offsets continue from the right position.
                    previous_text_lens[i] += len(delta_text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    self._raise_if_error(finish_reason, request_id)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            total_cached_tokens = sum(n or 0 for n in num_cached_tokens)
            if self.enable_prompt_tokens_details and total_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=total_cached_tokens
                )

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(e)
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build a non-streaming CompletionResponse from finished outputs.

        One choice is emitted per output of each prompt, in batch order. With
        ``request.echo`` the prompt text, token ids and logprobs are prepended
        to each output (only the prompt when ``max_tokens == 0``). Usage sums
        prompt, generated and, when ``enable_prompt_tokens_details`` is on,
        cached tokens across all prompts. It is also stored on
        ``request_metadata``.

        Raises:
            Whatever ``self._raise_if_error`` raises for an errored output.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
        num_cached_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                self._raise_if_error(output.finish_reason, request_id)

                assert request.max_tokens is not None
                if request.echo:
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                    prompt_token_ids=(
                        prompt_token_ids if request.return_token_ids else None
                    ),
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)
            # Aggregate cached tokens across prompts, like prompt_tokens,
            # rather than reporting only the last prompt's count.
            num_cached_tokens += final_res.num_cached_tokens or 0

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        if self.enable_prompt_tokens_details and num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        kv_transfer_params = (
            final_res_batch[0].kv_transfer_params if final_res_batch else None
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: TokenizerLike | None,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Tokens to describe, one per position.
            top_logprobs: Per-position candidate logprobs, indexed in step
                with ``token_ids``. ``None`` at a position means no logprob
                data is available for that token.
            num_output_top_logprobs: Requested number of alternatives; up to
                ``num_output_top_logprobs + 1`` candidates are kept per
                position.
            tokenizer: Used to decode tokens; may be None only when tokens
                are rendered as ``token_id:<id>`` strings.
            initial_text_offset: Character offset assigned to the first token.
            return_as_token_id: Per-call override of
                ``self.return_tokens_as_token_ids``.

        Raises:
            VLLMValidationError: a position has no logprob data, the token
                must be decoded to text, and no tokenizer is available.
        """
        as_token_id = (
            self.return_tokens_as_token_ids
            if return_as_token_id is None
            else return_as_token_id
        )

        tokens: list[str] = []
        token_logprobs: list[float | None] = []
        top_logprob_maps: list[dict[str, float] | None] = []
        text_offsets: list[int] = []
        # Each token starts where the previous decoded token ended.
        offset = initial_text_offset

        for pos, token_id in enumerate(token_ids):
            candidates = top_logprobs[pos]

            if candidates is not None:
                chosen = candidates[token_id]
                token = self._get_decoded_token(
                    chosen,
                    token_id,
                    tokenizer,
                    return_as_token_id=as_token_id,
                )
                tokens.append(token)
                # Convert float("-inf") to the JSON-serializable float that
                # OpenAI uses.
                token_logprobs.append(max(chosen.logprob, -9999.0))

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                alternatives: dict[str, float] = {}
                for rank, (alt_id, alt) in enumerate(candidates.items()):
                    if rank > num_output_top_logprobs:
                        break
                    alt_token = self._get_decoded_token(
                        alt,
                        alt_id,
                        tokenizer,
                        return_as_token_id=as_token_id,
                    )
                    alternatives[alt_token] = max(alt.logprob, -9999.0)
                top_logprob_maps.append(alternatives)
            else:
                if as_token_id:
                    token = f"token_id:{token_id}"
                elif tokenizer is None:
                    raise VLLMValidationError(
                        "Unable to get tokenizer because "
                        "`skip_tokenizer_init=True`",
                        parameter="skip_tokenizer_init",
                        value=True,
                    )
                else:
                    token = tokenizer.decode(token_id)

                tokens.append(token)
                token_logprobs.append(None)
                top_logprob_maps.append(None)

            text_offsets.append(offset)
            offset += len(token)

        return CompletionLogProbs(
            text_offset=text_offsets,
            token_logprobs=token_logprobs,
            tokens=tokens,
            top_logprobs=top_logprob_maps,
        )