serving_completion.py 19.7 KB
Newer Older
1
import time
2
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
3
4
                    Optional)
from typing import Sequence as GenericSequence
5
from typing import Tuple, cast
6

7
from fastapi import Request
8
from transformers import PreTrainedTokenizer
9

10
from vllm.config import ModelConfig
11
from vllm.engine.async_llm_engine import AsyncLLMEngine
12
from vllm.entrypoints.logger import RequestLogger
13
# yapf conflicts with isort for this block
14
15
16
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
17
18
19
20
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
21
                                              UsageInfo)
22
# yapf: enable
23
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
24
25
                                                    OpenAIServing,
                                                    PromptAdapterPath)
26
from vllm.logger import init_logger
27
28
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
29
from vllm.outputs import RequestOutput
30
from vllm.sequence import Logprob
31
32
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
33
from vllm.utils import merge_async_iterators, random_uuid
34
35
36

# Module-level logger, named after this module.
logger = init_logger(__name__)

# A flat list of integer token ids.
TypeTokenIDs = List[int]

# One optional ``{int: float}`` mapping per list entry.
# NOTE(review): presumably keyed by token id with logprob values - confirm
# against the users of this alias.
TypeTopLogProbs = List[Optional[Dict[int, float]]]

# Callable taking (token ids, top logprobs, optional int, int) and returning
# a ``CompletionLogProbs``.
TypeCreateLogProbsFn = Callable[
    [
        TypeTokenIDs,
        TypeTopLogProbs,
        Optional[int],
        int,
    ],
    CompletionLogProbs,
]
41

42
43
44

class OpenAIServingCompletion(OpenAIServing):

45
46
47
48
49
50
51
52
53
54
    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
    ):
        """Initialise the completion handler.

        Every argument is forwarded unchanged, by keyword, to the parent
        class initialiser; this class adds no state of its own here.
        """
        super().__init__(
            engine=engine,
            model_config=model_config,
            served_model_names=served_model_names,
            lora_modules=lora_modules,
            prompt_adapters=prompt_adapters,
            request_logger=request_logger,
        )
61
62
63
64
65
66
67
68

    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Args:
            request: The parsed completion request.
            raw_request: The underlying HTTP request; used to read trace
                headers and to detect client disconnects.

        Returns:
            One of:
            - an error response (from ``_check_model`` or
              ``create_error_response``) when validation fails, an
              unsupported feature is requested or the client disconnects;
            - an async generator of server-sent-event strings when
              ``request.stream`` is set;
            - a ``CompletionResponse`` otherwise.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = self.served_model_names[0]
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncIterator[RequestOutput]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine.get_tokenizer(lora_request)

            sampling_params = request.to_sampling_params()
            decoding_config = await self.engine.get_decoding_config()
            # A backend named in the request wins over the engine default.
            guided_decoding_backend = request.guided_decoding_backend \
                or decoding_config.guided_decoding_backend
            guided_decode_logit_processor = (
                await
                get_guided_decoding_logits_processor(guided_decoding_backend,
                                                     request, tokenizer))
            if guided_decode_logit_processor is not None:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logit_processor)

            prompts = list(
                self._tokenize_prompt_input_or_inputs(
                    request,
                    tokenizer,
                    request.prompt,
                    truncate_prompt_tokens=sampling_params.
                    truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                ))

            # The tracing decision depends only on the engine and on the
            # HTTP headers, not on the individual prompt, so it is resolved
            # once for the whole batch instead of once per prompt.
            trace_headers = None
            if prompts:
                is_tracing_enabled = await self.engine.is_tracing_enabled()
                if is_tracing_enabled:
                    trace_headers = extract_trace_headers(raw_request.headers)
                elif contains_trace_headers(raw_request.headers):
                    log_tracing_disabled_warning()

            for i, prompt_inputs in enumerate(prompts):
                # Each prompt of the batch becomes its own engine request.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                # Only the token ids are handed to the engine; the prompt
                # text is re-attached to the outputs further below.
                generator = self.engine.generate(
                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                    sampling_params,
                    request_id_item,
                    lora_request=lora_request,
                    prompt_adapter_request=prompt_adapter_request,
                    trace_headers=trace_headers,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, RequestOutput]] = merge_async_iterators(*generators)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when use
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(request,
                                                    raw_request,
                                                    result_generator,
                                                    request_id,
                                                    created_time,
                                                    model_name,
                                                    num_prompts=len(prompts),
                                                    tokenizer=tokenizer)

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # The client is gone, so nothing of this batch is needed
                    # any more: abort every sub-request, not only the one
                    # that happened to produce the current item.
                    # NOTE(review): relies on engine.abort() tolerating ids
                    # of requests that already finished - confirm.
                    for prompt_idx in range(len(prompts)):
                        await self.engine.abort(f"{request_id}-{prompt_idx}")
                    return self.create_error_response("Client disconnected")
                # Keep only the most recent output seen for each prompt.
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    final_res.prompt = prompts[i]["prompt"]

            final_res_batch_checked = cast(List[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
            )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
218
219
220
221
222
223
224
225
226
227

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        raw_request: Request,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: PreTrainedTokenizer,
    ) -> AsyncGenerator[str, None]:
        """Stream completion results as server-sent-event strings.

        Each engine output is turned into one ``CompletionStreamResponse``
        chunk holding only the text produced since the previous chunk of
        the same choice, and yielded as ``"data: <json>\\n\\n"``.

        Args:
            request: The completion request being served.
            raw_request: The HTTP request, polled for client disconnects.
            result_generator: Yields ``(prompt_index, RequestOutput)`` pairs.
            request_id: Base id; sub-request ``i`` is ``f"{request_id}-{i}"``.
            created_time: Creation timestamp copied into every chunk.
            model_name: Model name copied into every chunk.
            num_prompts: Number of prompts in the batch.
            tokenizer: Tokenizer used to decode tokens for logprobs and, as
                a fallback, the prompt text for ``echo``.

        Yields:
            One event per choice update, optionally a final usage-only event
            when ``stream_options.include_usage`` is set, then
            ``"data: [DONE]\\n\\n"``. A ``ValueError`` raised while
            streaming is reported as an error event followed by ``[DONE]``.
            If the client disconnects, all sub-requests are aborted and the
            stream ends without further events.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice state, indexed by `output.index + prompt_idx * n`.
        previous_texts = [""] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        # Per-prompt state used for echoing and for the final usage chunk.
        num_prompt_tokens = [0] * num_prompts
        prompt_texts: List[Optional[str]] = [None] * num_prompts

        include_usage = bool(request.stream_options
                             and request.stream_options.include_usage)

        try:
            async for prompt_idx, res in result_generator:

                # Abort the request if the client disconnects.
                if await raw_request.is_disconnected():
                    # Abort every sub-request of the batch; the remaining
                    # prompts have no consumer any more either.
                    # NOTE(review): relies on engine.abort() tolerating ids
                    # of requests that already finished - confirm.
                    for idx in range(num_prompts):
                        await self.engine.abort(f"{request_id}-{idx}")
                    # A plain `return` ends an async generator. Raising
                    # StopAsyncIteration inside one is turned into a
                    # RuntimeError by the interpreter.
                    return

                num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)

                # create_completion hands only token ids to the engine, so
                # the output may carry no prompt text. Decode it once per
                # prompt when it is needed for echoing.
                # NOTE(review): the decoded text may differ from the text
                # the client sent (e.g. special tokens) - confirm this is
                # acceptable as a fallback.
                prompt_text = res.prompt
                if prompt_text is None and request.echo:
                    if prompt_texts[prompt_idx] is None:
                        prompt_texts[prompt_idx] = tokenizer.decode(
                            res.prompt_token_ids)
                    prompt_text = prompt_texts[prompt_idx]

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices
                    # TODO(simon): optimize the performance by avoiding full
                    # text O(n^2) sending.

                    assert request.max_tokens is not None
                    if request.echo and request.max_tokens == 0:
                        # only return the prompt
                        delta_text = prompt_text
                        delta_token_ids = res.prompt_token_ids
                        out_logprobs = res.prompt_logprobs
                        has_echoed[i] = True
                    elif (request.echo and request.max_tokens > 0
                          and not has_echoed[i]):
                        # echo the prompt and first token
                        delta_text = prompt_text + output.text
                        # Build new lists so that a non-list token_ids
                        # sequence or missing logprobs (None when logprobs
                        # were not requested) cannot raise a TypeError.
                        delta_token_ids = [
                            *res.prompt_token_ids,
                            *output.token_ids,
                        ]
                        out_logprobs = [
                            *(res.prompt_logprobs or []),
                            *(output.logprobs or []),
                        ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text[len(previous_texts[i]):]
                        delta_token_ids = output.token_ids[
                            previous_num_tokens[i]:]
                        out_logprobs = output.logprobs[previous_num_tokens[
                            i]:] if output.logprobs else None

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    # output.text / output.token_ids are cumulative, so the
                    # stored values mark where the next delta starts.
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if include_usage:
                        if request.stream_options.continuous_usage_stats:
                            # Running usage of this choice on every chunk.
                            prompt_tokens = len(res.prompt_token_ids)
                            completion_tokens = len(output.token_ids)
                            chunk.usage = UsageInfo(
                                prompt_tokens=prompt_tokens,
                                completion_tokens=completion_tokens,
                                total_tokens=prompt_tokens + completion_tokens,
                            )
                        else:
                            chunk.usage = None

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            if include_usage:
                # The closing usage chunk covers the whole request: all
                # prompts and all choices, not just the last finished one.
                total_prompt_tokens = sum(num_prompt_tokens)
                total_completion_tokens = sum(previous_num_tokens)
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=UsageInfo(
                        prompt_tokens=total_prompt_tokens,
                        completion_tokens=total_completion_tokens,
                        total_tokens=(total_prompt_tokens +
                                      total_completion_tokens),
                    ),
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: PreTrainedTokenizer,
    ) -> CompletionResponse:
        """Assemble the non-streaming ``CompletionResponse``.

        Every output of every entry in ``final_res_batch`` becomes one
        choice, numbered in iteration order. With ``request.echo`` the
        prompt text (and, when logprobs are requested, the prompt logprobs)
        is prepended to the generated part; with ``max_tokens == 0`` only
        the prompt is returned. Usage sums prompt tokens over all entries
        and generated tokens over all outputs.
        """
        wants_logprobs = request.logprobs is not None

        choices: List[CompletionResponseChoice] = []
        prompt_token_total = 0
        generated_token_total = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            prompt_logprobs = final_res.prompt_logprobs
            prompt_text = final_res.prompt

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo and request.max_tokens == 0:
                    # Prompt only.
                    token_ids = prompt_token_ids
                    out_logprobs = prompt_logprobs
                    output_text = prompt_text
                elif request.echo and request.max_tokens > 0:
                    # Prompt followed by the generated continuation.
                    token_ids = prompt_token_ids + list(output.token_ids)
                    if wants_logprobs:
                        out_logprobs = prompt_logprobs + output.logprobs
                    else:
                        out_logprobs = None
                    output_text = prompt_text + output.text
                else:
                    # Generated continuation only.
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                logprobs = None
                if wants_logprobs:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                    )

                choices.append(
                    CompletionResponseChoice(
                        index=len(choices),
                        text=output_text,
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                        stop_reason=output.stop_reason,
                    ))

            prompt_token_total += len(prompt_token_ids)
            for output in final_res.outputs:
                generated_token_total += len(output.token_ids)

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=UsageInfo(
                prompt_tokens=prompt_token_total,
                completion_tokens=generated_token_total,
                total_tokens=prompt_token_total + generated_token_total,
            ),
        )
412
413
414
415
416
417

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: PreTrainedTokenizer,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        ``token_ids`` and ``top_logprobs`` are walked in lockstep. For each
        position the decoded token, its logprob (``None`` when the position
        has no logprob entry) and a mapping of decoded candidate tokens to
        logprobs are recorded, together with the character offset of the
        token, counted from ``initial_text_offset``. Logprobs are floored
        at ``-9999.0``.
        """
        text_offsets: List[int] = []
        token_logprobs: List[Optional[float]] = []
        tokens: List[str] = []
        top_logprob_maps: List[Optional[Dict[str, float]]] = []

        # Offset at which the next token starts.
        next_offset = initial_text_offset

        for position, token_id in enumerate(token_ids):
            candidates = top_logprobs[position]

            if candidates is None:
                token = tokenizer.decode(token_id)
                tokens.append(token)
                token_logprobs.append(None)
                top_logprob_maps.append(None)
            else:
                chosen = candidates[token_id]
                token = self._get_decoded_token(chosen, token_id, tokenizer)
                tokens.append(token)
                token_logprobs.append(max(chosen.logprob, -9999.0))

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                top_map: Dict[str, float] = {}
                for rank, (cand_id, cand) in enumerate(candidates.items()):
                    if num_output_top_logprobs < rank:
                        continue
                    decoded = self._get_decoded_token(cand, cand_id,
                                                      tokenizer)
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    top_map[decoded] = max(cand.logprob, -9999.0)
                top_logprob_maps.append(top_map)

            text_offsets.append(next_offset)
            next_offset += len(token)

        return CompletionLogProbs(
            text_offset=text_offsets,
            token_logprobs=token_logprobs,
            tokens=tokens,
            top_logprobs=top_logprob_maps,
        )