serving_completion.py 18.5 KB
Newer Older
1
import time
2
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
3
4
5
                    Optional)
from typing import Sequence as GenericSequence
from typing import Tuple
6

7
from fastapi import Request
8

9
from vllm.config import ModelConfig
10
from vllm.engine.async_llm_engine import AsyncLLMEngine
11
# yapf conflicts with isort for this block
12
13
14
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
15
16
17
18
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
19
                                              UsageInfo)
20
21
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
22
from vllm.logger import init_logger
23
24
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
25
from vllm.outputs import RequestOutput
26
from vllm.sequence import Logprob
27
from vllm.utils import merge_async_iterators, random_uuid
28
29
30

# Module-level logger, named after this module.
logger = init_logger(__name__)

# Type aliases. None of them is referenced in the part of the file visible
# here; they may be used elsewhere in the project.

# A list of integers -- presumably token ids (inferred from the name).
TypeTokenIDs = List[int]
# A list whose entries are either None or an int -> float mapping
# (presumably token id -> log-probability for one position; TODO confirm).
TypeTopLogProbs = List[Optional[Dict[int, float]]]
# A callable taking (TypeTokenIDs, TypeTopLogProbs, Optional[int], int) and
# returning a CompletionLogProbs.
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
35

36

Simon Mo's avatar
Simon Mo committed
37
def parse_prompt_format(prompt) -> Tuple[bool, list]:
    """Normalize an OpenAI-style ``prompt`` value into a list of prompts.

    OpenAI supports "a string, array of strings, array of tokens, or array
    of token arrays". The shape of a list is decided from its first element
    only; later elements are not inspected.

    Args:
        prompt: The raw ``prompt`` value. Anything that is not a list is
            treated as a single prompt (case 1).

    Returns:
        A ``(prompt_is_tokens, prompts)`` pair. ``prompt_is_tokens`` is
        True when each entry of ``prompts`` is a list of token ids and
        False when each entry is the prompt itself. ``prompts`` always
        holds one entry per prompt.

    Raises:
        ValueError: If ``prompt`` is an empty list, if its first element is
            an empty list, or if its first element matches none of the
            supported shapes.
    """
    if not isinstance(prompt, list):
        # case 1: a string
        return False, [prompt]

    if len(prompt) == 0:
        raise ValueError("please provide at least one prompt")

    first = prompt[0]
    if isinstance(first, str):
        # case 2: array of strings
        return False, prompt
    if isinstance(first, int):
        # case 3: array of tokens
        return True, [prompt]
    if isinstance(first, list):
        # Guard before peeking at first[0]: an empty inner list used to
        # escape as IndexError, which callers catching ValueError missed.
        if len(first) == 0:
            raise ValueError("please provide at least one prompt")
        if isinstance(first[0], int):
            # case 4: array of token arrays
            return True, prompt

    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")


60
61
class OpenAIServingCompletion(OpenAIServing):

62
    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                 served_model_names: List[str],
                 lora_modules: Optional[List[LoRAModulePath]]):
        """Initialize the handler by delegating to the base class.

        All four arguments are forwarded unchanged, by keyword, to the
        parent initializer; this class adds no state of its own here.
        """
        base_kwargs = dict(
            engine=engine,
            model_config=model_config,
            served_model_names=served_model_names,
            lora_modules=lora_modules,
        )
        super().__init__(**base_kwargs)
69
70
71
72
73
74
75
76

    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        One engine request is submitted per prompt, using the id
        ``f"{request_id}-{i}"`` where ``i`` is the prompt's position. The
        per-prompt generators are merged into a single stream of
        ``(prompt_index, RequestOutput)`` pairs.

        Args:
            request: The parsed completion request.
            raw_request: The underlying HTTP request; polled to detect a
                client disconnect while results are collected.

        Returns:
            One of:
            - an error response (model check failure, unsupported
              ``suffix``, a ``ValueError`` raised during preparation or
              response building, or a client disconnect);
            - an async generator of server-sent-event strings when the
              request asks for streaming;
            - a ``CompletionResponse`` otherwise.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = self.served_model_names[0]
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncIterator[RequestOutput]] = []
        try:
            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            decoding_config = await self.engine.get_decoding_config()
            # A backend named on the request wins over the engine default.
            guided_decoding_backend = request.guided_decoding_backend \
                or decoding_config.guided_decoding_backend
            guided_decode_logit_processor = (
                await get_guided_decoding_logits_processor(
                    guided_decoding_backend, request, await
                    self.engine.get_tokenizer()))
            if guided_decode_logit_processor is not None:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logit_processor)
            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

            for i, prompt in enumerate(prompts):
                if prompt_is_tokens:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request,
                        prompt_ids=prompt,
                        truncate_prompt_tokens=sampling_params.
                        truncate_prompt_tokens)
                else:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request,
                        prompt=prompt,
                        truncate_prompt_tokens=sampling_params.
                        truncate_prompt_tokens)
                prompt_ids, prompt_text = prompt_formats

                generator = self.engine.generate(
                    {
                        "prompt": prompt_text,
                        "prompt_token_ids": prompt_ids
                    },
                    sampling_params,
                    f"{request_id}-{i}",
                    lora_request=lora_request,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        num_prompts = len(prompts)

        result_generator: AsyncIterator[Tuple[
            int, RequestOutput]] = merge_async_iterators(*generators)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when use
        # beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(request,
                                                    raw_request,
                                                    result_generator,
                                                    request_id,
                                                    created_time,
                                                    model_name,
                                                    num_prompts=num_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort every per-prompt sub-request, not only the one
                    # that produced this result; otherwise the siblings keep
                    # generating for a client that is gone.
                    # NOTE(review): relies on engine.abort tolerating ids
                    # that already finished (AsyncLLMEngine.abort is
                    # documented as a no-op then) -- confirm for the
                    # deployed engine version.
                    for j in range(num_prompts):
                        await self.engine.abort(f"{request_id}-{j}")
                    return self.create_error_response("Client disconnected")
                # Keep only the latest output per prompt.
                final_res_batch[i] = res
            response = self.request_output_to_completion_response(
                final_res_batch, request, request_id, created_time, model_name)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
190
191
192
193
194
195
196
197
198
199
200

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        raw_request: Request,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
    ) -> AsyncGenerator[str, None]:
        """Yield the streaming completion as server-sent-event strings.

        Each yielded item has the form ``"data: <json>\\n\\n"``. One chunk is
        emitted per output of every result pulled from ``result_generator``.
        If usage reporting is enabled in ``request.stream_options``, a final
        chunk with empty ``choices`` carries the usage. The stream ends with
        ``"data: [DONE]\\n\\n"``.

        Args:
            request: The completion request being served.
            raw_request: The HTTP request; polled for client disconnects.
            result_generator: Yields ``(prompt_index, RequestOutput)`` pairs.
            request_id: Base id; sub-requests are ``f"{request_id}-{k}"``.
            created_time: Timestamp placed in every chunk.
            model_name: Model name placed in every chunk.
            num_prompts: Number of prompts (and thus sub-requests).

        Yields:
            Server-sent-event strings. A ``ValueError`` raised while
            streaming is reported as an error event rather than propagated.
            On client disconnect the generator aborts all sub-requests and
            stops without emitting further events.
        """
        assert request.n is not None
        # Per-choice state, flattened as output.index + prompt_idx * n.
        previous_texts = [""] * request.n * num_prompts
        previous_num_tokens = [0] * request.n * num_prompts
        has_echoed = [False] * request.n * num_prompts
        # Pre-bound so the usage chunk below never reads an unbound local,
        # even if the result generator yields nothing.
        final_usage = None

        try:
            async for prompt_idx, res in result_generator:

                # Abort the request if the client disconnects.
                if await raw_request.is_disconnected():
                    # Abort every sub-request, not just the current one.
                    # NOTE(review): relies on engine.abort tolerating ids
                    # that already finished -- confirm for the deployed
                    # engine version.
                    for j in range(num_prompts):
                        await self.engine.abort(f"{request_id}-{j}")
                    # A plain return ends an async generator. Raising
                    # StopAsyncIteration here would be converted into a
                    # RuntimeError by the interpreter (PEP 525).
                    return

                for output in res.outputs:
                    i = output.index + prompt_idx * request.n
                    # TODO(simon): optimize the performance by avoiding full
                    # text O(n^2) sending.

                    assert request.max_tokens is not None
                    if request.echo and request.max_tokens == 0:
                        # only return the prompt
                        delta_text = res.prompt
                        delta_token_ids = res.prompt_token_ids
                        out_logprobs = res.prompt_logprobs
                        has_echoed[i] = True
                    elif (request.echo and request.max_tokens > 0
                          and not has_echoed[i]):
                        # echo the prompt and first token
                        delta_text = res.prompt + output.text
                        delta_token_ids = (res.prompt_token_ids +
                                           output.token_ids)
                        if res.prompt_logprobs is None:
                            # No prompt logprobs were produced; do not try
                            # to concatenate None with a list. If logprobs
                            # were requested, the assert below reports it.
                            out_logprobs = None
                        else:
                            out_logprobs = res.prompt_logprobs + (
                                output.logprobs or [])
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text[len(previous_texts[i]):]
                        delta_token_ids = output.token_ids[
                            previous_num_tokens[i]:]
                        out_logprobs = output.logprobs[previous_num_tokens[
                            i]:] if output.logprobs else None

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason
                    if output.finish_reason is not None:  # return final usage
                        prompt_tokens = len(res.prompt_token_ids)
                        completion_tokens = len(output.token_ids)
                        final_usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )
                    else:
                        final_usage = None

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if (request.stream_options
                            and request.stream_options.include_usage):
                        # Explicitly set so the field is not dropped by
                        # exclude_unset when the chunk is serialized.
                        chunk.usage = None

                    response_json = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {response_json}\n\n"

            if (request.stream_options
                    and request.stream_options.include_usage):
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
    ) -> CompletionResponse:
        """Build the non-streaming response from the final engine outputs.

        Every output of every entry in ``final_res_batch`` becomes one
        choice, indexed in the order encountered. Depending on
        ``request.echo`` and ``request.max_tokens`` a choice holds only the
        prompt, the prompt followed by the generated text, or only the
        generated text. Usage sums prompt tokens once per batch entry and
        generated tokens over all outputs.

        Args:
            final_res_batch: Final output per prompt; entries must not be
                None.
            request: The completion request being answered.
            request_id: Id placed in the response.
            created_time: Timestamp placed in the response.
            model_name: Model name placed in the response.

        Returns:
            The assembled ``CompletionResponse``.
        """
        choices: List[CompletionResponseChoice] = []
        prompt_token_total = 0
        generated_token_total = 0

        for request_output in final_res_batch:
            assert request_output is not None
            prompt_ids = request_output.prompt_token_ids
            prompt_lps = request_output.prompt_logprobs
            prompt_str = request_output.prompt

            for completion in request_output.outputs:
                assert request.max_tokens is not None

                if not request.echo or request.max_tokens < 0:
                    # No echo (a negative max_tokens also lands here):
                    # report the generated part only.
                    token_ids = completion.token_ids
                    out_logprobs = completion.logprobs
                    output_text = completion.text
                elif request.max_tokens == 0:
                    # Echo with nothing generated: report the prompt only.
                    token_ids = prompt_ids
                    out_logprobs = prompt_lps
                    output_text = prompt_str
                else:
                    # Echo with generation: prompt followed by the output.
                    token_ids = prompt_ids + completion.token_ids
                    if request.logprobs is not None:
                        out_logprobs = prompt_lps + completion.logprobs
                    else:
                        out_logprobs = None
                    output_text = prompt_str + completion.text

                logprobs = None
                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        num_output_top_logprobs=request.logprobs,
                    )

                choices.append(
                    CompletionResponseChoice(
                        index=len(choices),
                        text=output_text,
                        logprobs=logprobs,
                        finish_reason=completion.finish_reason,
                        stop_reason=completion.stop_reason,
                    ))
                generated_token_total += len(completion.token_ids)

            prompt_token_total += len(prompt_ids)

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=UsageInfo(
                prompt_tokens=prompt_token_total,
                completion_tokens=generated_token_total,
                total_tokens=prompt_token_total + generated_token_total,
            ),
        )
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Walks ``token_ids`` and ``top_logprobs`` in lockstep and produces
        four parallel lists: decoded tokens, the logprob of each token,
        the per-position top-logprob mapping, and each token's text offset.

        Args:
            token_ids: Token ids to report, in order.
            top_logprobs: Per-position mapping of token id to ``Logprob``,
                or None when no logprobs exist for that position.
            num_output_top_logprobs: Requested number of top logprobs; up to
                this many plus one entries are kept per position.
            initial_text_offset: Text offset assigned to the first token.

        Returns:
            A ``CompletionLogProbs`` holding the four lists.
        """
        text_offsets: List[int] = []
        token_logprobs: List[Optional[float]] = []
        tokens: List[str] = []
        top_logprobs_out: List[Optional[Dict[str, float]]] = []

        # Offset of the token about to be appended; advanced by the decoded
        # length of each token after it is recorded.
        next_offset = initial_text_offset

        for position, token_id in enumerate(token_ids):
            candidates = top_logprobs[position]

            if candidates is None:
                token = self.tokenizer.decode(token_id)
                tokens.append(token)
                token_logprobs.append(None)
                top_logprobs_out.append(None)
            else:
                chosen = candidates[token_id]
                token = self._get_decoded_token(chosen, token_id)
                token_logprob = max(candidates[token_id].logprob, -9999.0)
                tokens.append(token)
                token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                kept: Dict[str, float] = {}
                for rank, (cand_id, cand_lp) in enumerate(candidates.items()):
                    if num_output_top_logprobs >= rank:
                        decoded = self._get_decoded_token(cand_lp, cand_id)
                        # Convert float("-inf") to the
                        # JSON-serializable float that OpenAI uses
                        kept[decoded] = max(cand_lp.logprob, -9999.0)
                top_logprobs_out.append(kept)

            text_offsets.append(next_offset)
            next_offset += len(token)

        return CompletionLogProbs(
            text_offset=text_offsets,
            token_logprobs=token_logprobs,
            tokens=tokens,
            top_logprobs=top_logprobs_out,
        )