serving_completion.py 20.4 KB
Newer Older
1
import time
2
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
3
4
5
                    Optional)
from typing import Sequence as GenericSequence
from typing import Tuple
6

7
from fastapi import Request
8

9
from vllm.config import ModelConfig
10
from vllm.engine.async_llm_engine import AsyncLLMEngine
11
# yapf conflicts with isort for this block
12
13
14
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
15
16
17
18
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
19
20
21
22
23
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              TokenizeRequest,
                                              TokenizeResponse, UsageInfo)
# yapf: enable
24
25
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
26
from vllm.logger import init_logger
27
28
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
29
from vllm.outputs import RequestOutput
30
from vllm.sequence import Logprob
31
32
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
33
from vllm.utils import merge_async_iterators, random_uuid
34
35
36

# Module-level logger, created with this module's ``__name__``.
logger = init_logger(__name__)

# Alias for a plain list of ints.
# NOTE(review): judging by the name, a sequence of token ids -- confirm.
TypeTokenIDs = List[int]
# Alias for a list whose entries are either None or a dict of int -> float.
# NOTE(review): presumably one "token id -> logprob" mapping per position,
# with None where no logprobs are available -- confirm against callers.
TypeTopLogProbs = List[Optional[Dict[int, float]]]
# Signature of a callable taking
# (TypeTokenIDs, TypeTopLogProbs, Optional[int], int) and returning
# CompletionLogProbs.
# NOTE(review): not referenced elsewhere in the visible part of this module.
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
41

42

Simon Mo's avatar
Simon Mo committed
43
def parse_prompt_format(prompt) -> Tuple[bool, list]:
    """Normalize an OpenAI-style ``prompt`` value into a batch of prompts.

    OpenAI supports "a string, array of strings, array of tokens, or array
    of token arrays". For list inputs the shape is decided by looking only
    at the first element (and, for nested lists, at its first element).

    Args:
        prompt: The raw ``prompt`` value taken from the request.

    Returns:
        A ``(prompt_is_tokens, prompts)`` tuple. ``prompt_is_tokens`` is
        True when every entry of ``prompts`` is a list of token ids and
        False when every entry is text. A non-list input or a single token
        array is wrapped in a one-element list.

    Raises:
        ValueError: If ``prompt`` is an empty list, or a list whose first
            element is not a string, an int, or a non-empty list that
            starts with an int.
    """
    if not isinstance(prompt, list):
        # case 1: a string
        return False, [prompt]

    if not prompt:
        raise ValueError("please provide at least one prompt")

    first = prompt[0]
    if isinstance(first, str):
        # case 2: array of strings
        return False, prompt
    if isinstance(first, int):
        # case 3: array of tokens
        return True, [prompt]
    # The emptiness check matters: indexing ``first[0]`` on ``[[]]`` would
    # raise IndexError, which callers that only handle ValueError would let
    # escape instead of reporting a validation error.
    if isinstance(first, list) and first and isinstance(first[0], int):
        # case 4: array of token arrays
        return True, prompt

    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")


66
67
class OpenAIServingCompletion(OpenAIServing):

68
    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        lora_modules: Optional[List[LoRAModulePath]],
    ):
        """Initialize the completion handler.

        Every argument is forwarded, by keyword, to the ``OpenAIServing``
        initializer; this subclass sets no additional state here.
        """
        base_kwargs = {
            "engine": engine,
            "model_config": model_config,
            "served_model_names": served_model_names,
            "lora_modules": lora_modules,
        }
        super().__init__(**base_kwargs)
75
76
77
78
79
80
81
82

    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        One engine request is scheduled per prompt found in
        ``request.prompt``; the i-th one uses the id ``"{request_id}-{i}"``
        and their outputs are merged into a single result iterator.

        Args:
            request: The parsed completion request.
            raw_request: The underlying FastAPI request; used to read trace
                headers and to detect a client disconnect.

        Returns:
            One of:
            - the non-None value returned by ``self._check_model``;
            - the value of ``self.create_error_response(...)`` when a
              ``suffix`` is given, a ``ValueError`` is raised while
              preparing or collecting results, or the client disconnects;
            - the async generator from ``completion_stream_generator`` when
              results are streamed;
            - an async generator yielding a single ``data:`` event followed
              by ``data: [DONE]`` when ``request.stream`` is set but
              streaming is not possible (``n != best_of`` or beam search);
            - otherwise the ``CompletionResponse`` itself.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = self.served_model_names[0]
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncIterator[RequestOutput]] = []
        try:
            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            decoding_config = await self.engine.get_decoding_config()
            # A backend named on the request wins over the engine default.
            guided_decoding_backend = (
                request.guided_decoding_backend
                or decoding_config.guided_decoding_backend)
            guided_decode_logit_processor = (
                await get_guided_decoding_logits_processor(
                    guided_decoding_backend, request, await
                    self.engine.get_tokenizer()))
            if guided_decode_logit_processor is not None:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logit_processor)
            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

            # The tracing state and the request headers are the same for
            # every prompt of this request, so resolve them once instead of
            # awaiting the engine on every loop iteration.
            is_tracing_enabled = await self.engine.is_tracing_enabled()
            trace_headers = None
            if is_tracing_enabled:
                trace_headers = extract_trace_headers(raw_request.headers)
            elif contains_trace_headers(raw_request.headers):
                log_tracing_disabled_warning()

            for i, prompt in enumerate(prompts):
                if prompt_is_tokens:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request,
                        prompt_ids=prompt,
                        truncate_prompt_tokens=sampling_params.
                        truncate_prompt_tokens)
                else:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request,
                        prompt=prompt,
                        truncate_prompt_tokens=sampling_params.
                        truncate_prompt_tokens)
                prompt_ids, prompt_text = prompt_formats

                generator = self.engine.generate(
                    {
                        "prompt": prompt_text,
                        "prompt_token_ids": prompt_ids
                    },
                    sampling_params,
                    f"{request_id}-{i}",
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, RequestOutput]] = merge_async_iterators(*generators)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when use
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(request,
                                                    raw_request,
                                                    result_generator,
                                                    request_id,
                                                    created_time,
                                                    model_name,
                                                    num_prompts=len(prompts))

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # The client is gone, so nobody will read any of the
                    # results: abort every per-prompt sub-request, not only
                    # the one that happened to produce this output.
                    for prompt_idx in range(len(prompts)):
                        await self.engine.abort(f"{request_id}-{prompt_idx}")
                    return self.create_error_response("Client disconnected")
                # Each new output for prompt i replaces the previous one, so
                # the batch ends up holding the last output per prompt.
                final_res_batch[i] = res
            response = self.request_output_to_completion_response(
                final_res_batch, request, request_id, created_time, model_name)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
205
206
207
208
209
210
211
212
213
214
215

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        raw_request: Request,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
    ) -> AsyncGenerator[str, None]:
        """Stream completion deltas as server-sent ``data:`` events.

        Each ``(prompt_idx, RequestOutput)`` pair from ``result_generator``
        is turned into one ``CompletionStreamResponse`` chunk per output,
        carrying only the text/tokens produced since the previous chunk for
        that choice. Choice indices are
        ``output.index + prompt_idx * request.n``.

        Args:
            request: The completion request being served.
            raw_request: The FastAPI request, polled for client disconnects.
            result_generator: Merged stream of per-prompt engine outputs.
            request_id: Id placed in every chunk; sub-requests are named
                ``"{request_id}-{prompt_idx}"``.
            created_time: Timestamp placed in every chunk.
            model_name: Model name placed in every chunk.
            num_prompts: Number of prompts (sub-requests) in this request.

        Yields:
            ``"data: <json>\\n\\n"`` strings: one per chunk, an optional
            final usage-only chunk when ``stream_options.include_usage`` is
            set, an error payload if a ``ValueError`` is raised, and finally
            ``"data: [DONE]\\n\\n"``. If the client disconnects the
            generator ends immediately without further events.
        """
        assert request.n is not None
        num_choices = request.n * num_prompts
        previous_texts = [""] * num_choices
        previous_num_tokens = [0] * num_choices
        has_echoed = [False] * num_choices
        # Prompt length per prompt, recorded only when usage is requested,
        # so the final usage chunk can report totals over the whole request.
        num_prompt_tokens = [0] * num_prompts
        include_usage = bool(request.stream_options
                             and request.stream_options.include_usage)

        try:
            async for prompt_idx, res in result_generator:

                # Abort the request if the client disconnects.
                if await raw_request.is_disconnected():
                    # Nobody is listening any more: stop every sub-request.
                    for idx in range(num_prompts):
                        await self.engine.abort(f"{request_id}-{idx}")
                    # A bare return is how an async generator finishes;
                    # raising StopAsyncIteration inside one is converted
                    # into a RuntimeError by the interpreter (PEP 525).
                    return

                if include_usage:
                    num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)

                for output in res.outputs:
                    i = output.index + prompt_idx * request.n
                    # TODO(simon): optimize the performance by avoiding full
                    # text O(n^2) sending.

                    assert request.max_tokens is not None
                    if request.echo and request.max_tokens == 0:
                        # only return the prompt
                        delta_text = res.prompt
                        delta_token_ids = res.prompt_token_ids
                        out_logprobs = res.prompt_logprobs
                        has_echoed[i] = True
                    elif (request.echo and request.max_tokens > 0
                          and not has_echoed[i]):
                        # echo the prompt and first token
                        delta_text = res.prompt + output.text
                        # list(...) mirrors the non-streaming path and keeps
                        # the concatenation valid for any token-id sequence.
                        delta_token_ids = (res.prompt_token_ids +
                                           list(output.token_ids))
                        # Only combine logprobs when they were requested;
                        # otherwise either side may be None and "+" would
                        # raise TypeError.
                        if request.logprobs is not None:
                            out_logprobs = ((res.prompt_logprobs or []) +
                                            (output.logprobs or []))
                        else:
                            out_logprobs = None
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text[len(previous_texts[i]):]
                        delta_token_ids = output.token_ids[
                            previous_num_tokens[i]:]
                        out_logprobs = output.logprobs[previous_num_tokens[
                            i]:] if output.logprobs else None

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if include_usage:
                        if request.stream_options.continuous_usage_stats:
                            # Per-chunk usage covers this choice only.
                            prompt_tokens = len(res.prompt_token_ids)
                            completion_tokens = len(output.token_ids)
                            chunk.usage = UsageInfo(
                                prompt_tokens=prompt_tokens,
                                completion_tokens=completion_tokens,
                                total_tokens=prompt_tokens + completion_tokens,
                            )
                        else:
                            chunk.usage = None

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            if include_usage:
                # Totals over all prompts and all choices, counted the same
                # way as in the non-streaming response: each prompt once,
                # plus the final token count of every choice. Building it
                # from the running counters also means it is always defined,
                # even if no output reported a finish reason.
                total_prompt_tokens = sum(num_prompt_tokens)
                total_completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(
                    prompt_tokens=total_prompt_tokens,
                    completion_tokens=total_completion_tokens,
                    total_tokens=(total_prompt_tokens +
                                  total_completion_tokens),
                )
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
    ) -> CompletionResponse:
        """Build the non-streaming response from the final engine outputs.

        Every output of every entry in ``final_res_batch`` becomes one
        choice, numbered in encounter order. Depending on ``request.echo``
        and ``request.max_tokens`` a choice holds the prompt only, the
        prompt followed by the completion, or the completion only. Usage
        counts each prompt once and sums the token counts of all outputs.

        Args:
            final_res_batch: Final ``RequestOutput`` per prompt; entries
                must not be None.
            request: The completion request being answered.
            request_id: Id placed in the response.
            created_time: Timestamp placed in the response.
            model_name: Model name placed in the response.

        Returns:
            The assembled ``CompletionResponse``.
        """
        choices: List[CompletionResponseChoice] = []
        prompt_token_total = 0
        completion_token_total = 0

        for request_output in final_res_batch:
            assert request_output is not None
            prompt_token_ids = request_output.prompt_token_ids
            prompt_logprobs = request_output.prompt_logprobs
            prompt_text = request_output.prompt

            for completion in request_output.outputs:
                assert request.max_tokens is not None
                echo_prompt_only = request.echo and request.max_tokens == 0
                echo_with_completion = (request.echo
                                        and request.max_tokens > 0)

                if echo_prompt_only:
                    token_ids = prompt_token_ids
                    out_logprobs = prompt_logprobs
                    output_text = prompt_text
                elif echo_with_completion:
                    token_ids = prompt_token_ids + list(completion.token_ids)
                    if request.logprobs is not None:
                        out_logprobs = prompt_logprobs + completion.logprobs
                    else:
                        out_logprobs = None
                    output_text = prompt_text + completion.text
                else:
                    token_ids = completion.token_ids
                    out_logprobs = completion.logprobs
                    output_text = completion.text

                logprobs = None
                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        num_output_top_logprobs=request.logprobs,
                    )

                # The index is the running count of choices emitted so far.
                choices.append(
                    CompletionResponseChoice(
                        index=len(choices),
                        text=output_text,
                        logprobs=logprobs,
                        finish_reason=completion.finish_reason,
                        stop_reason=completion.stop_reason,
                    ))

            prompt_token_total += len(prompt_token_ids)
            for completion in request_output.outputs:
                completion_token_total += len(completion.token_ids)

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=UsageInfo(
                prompt_tokens=prompt_token_total,
                completion_tokens=completion_token_total,
                total_tokens=prompt_token_total + completion_token_total,
            ),
        )
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Convert per-token logprob data into a ``CompletionLogProbs``.

        Args:
            token_ids: Token ids to report, in order.
            top_logprobs: For each position in ``token_ids``, either None or
                a mapping from token id to ``Logprob`` that contains the
                token at that position.
            num_output_top_logprobs: Each position keeps the first
                ``num_output_top_logprobs + 1`` entries of its mapping, in
                the mapping's iteration order.
            initial_text_offset: Text offset assigned to the first token.

        Returns:
            A ``CompletionLogProbs`` with parallel lists of decoded tokens,
            their logprobs, their text offsets and the kept top logprobs.
            Positions without logprob data get None for the logprob and the
            top-logprob entry. Logprobs are floored at -9999.0.
        """
        tokens: List[str] = []
        token_logprobs: List[Optional[float]] = []
        top_logprob_dicts: List[Optional[Dict[str, float]]] = []
        text_offsets: List[int] = []

        # Offset of the token about to be appended; advanced by the length
        # of each decoded token.
        next_offset = initial_text_offset

        for position, token_id in enumerate(token_ids):
            candidates = top_logprobs[position]

            if candidates is None:
                token = self.tokenizer.decode(token_id)
                token_logprobs.append(None)
                top_logprob_dicts.append(None)
            else:
                chosen = candidates[token_id]
                token = self._get_decoded_token(chosen, token_id)
                token_logprobs.append(max(chosen.logprob, -9999.0))

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                kept: Dict[str, float] = {}
                for rank, (cand_id, cand) in enumerate(candidates.items()):
                    if rank > num_output_top_logprobs:
                        break
                    decoded = self._get_decoded_token(cand, cand_id)
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    kept[decoded] = max(cand.logprob, -9999.0)
                top_logprob_dicts.append(kept)

            tokens.append(token)
            text_offsets.append(next_offset)
            next_offset += len(token)

        return CompletionLogProbs(
            text_offset=text_offsets,
            token_logprobs=token_logprobs,
            tokens=tokens,
            top_logprobs=top_logprob_dicts,
        )
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476

    async def create_tokenize(self,
                              request: TokenizeRequest) -> TokenizeResponse:
        """Tokenize ``request.prompt`` and report the resulting token ids.

        Returns the non-None result of ``self._check_model`` if there is
        one; otherwise a ``TokenizeResponse`` carrying the token ids, their
        count and ``self.max_model_len``.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        # Only the ids are needed; the text half of the pair is discarded.
        token_ids, _ = self._validate_prompt_and_tokenize(
            request,
            prompt=request.prompt,
            add_special_tokens=request.add_special_tokens,
        )

        return TokenizeResponse(
            tokens=token_ids,
            count=len(token_ids),
            max_model_len=self.max_model_len,
        )

    async def create_detokenize(
            self, request: DetokenizeRequest) -> DetokenizeResponse:
        """Turn ``request.tokens`` back into text.

        Returns the non-None result of ``self._check_model`` if there is
        one; otherwise a ``DetokenizeResponse`` whose ``prompt`` is the text
        returned by ``self._validate_prompt_and_tokenize``.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        # Only the text is needed; the ids half of the pair is discarded.
        _, text = self._validate_prompt_and_tokenize(
            request,
            prompt_ids=request.tokens,
        )

        return DetokenizeResponse(prompt=text)