serving_completion.py 17.8 KB
Newer Older
1
import time
2
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
3
4
5
                    Optional)
from typing import Sequence as GenericSequence
from typing import Tuple
6

7
from fastapi import Request
8

9
from vllm.config import ModelConfig
10
from vllm.engine.async_llm_engine import AsyncLLMEngine
11
12
13
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
14
15
16
17
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
18
19
                                              UsageInfo)
# yapf: enable
20
21
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
22
from vllm.logger import init_logger
23
24
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
25
from vllm.outputs import RequestOutput
26
from vllm.sequence import Logprob
27
from vllm.utils import merge_async_iterators, random_uuid
28
29
30

# Logger for this module, created through vLLM's logging helper.
logger = init_logger(__name__)

# A single prompt or completion expressed as token ids.
TypeTokenIDs = List[int]
# One entry per token position: a token-id -> logprob mapping, or None.
TypeTopLogProbs = List[Optional[Dict[int, float]]]
# Signature of a callable that turns token ids and their top logprobs
# into a CompletionLogProbs object.
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int],
    CompletionLogProbs,
]
35

36

Simon Mo's avatar
Simon Mo committed
37
def parse_prompt_format(prompt) -> Tuple[bool, list]:
    """Normalize an OpenAI-style ``prompt`` value into a list of prompts.

    OpenAI supports "a string, array of strings, array of tokens, or array
    of token arrays". The shape of a list is inferred from its first
    element only (and, for nested lists, from that element's first item).

    Args:
        prompt: The raw ``prompt`` value taken from the request.

    Returns:
        A ``(prompt_is_tokens, prompts)`` pair. ``prompt_is_tokens`` is True
        when every entry of ``prompts`` is a list of token ids and False
        when the entries are text. A non-list ``prompt`` is returned wrapped
        in a one-element list.

    Raises:
        ValueError: If ``prompt`` is an empty list, or a list whose first
            element is not a string, an int, or a non-empty list that
            starts with an int.
    """
    if not isinstance(prompt, list):
        # case 1: a string
        return False, [prompt]

    if len(prompt) == 0:
        raise ValueError("please provide at least one prompt")

    first = prompt[0]
    if isinstance(first, str):
        # case 2: array of strings
        return False, prompt
    if isinstance(first, int):
        # case 3: array of tokens
        return True, [prompt]
    # The emptiness check matters: indexing an empty inner list would raise
    # IndexError, which callers that only handle ValueError would not catch.
    if isinstance(first, list) and first and isinstance(first[0], int):
        # case 4: array of token arrays
        return True, prompt

    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")


60
61
class OpenAIServingCompletion(OpenAIServing):

62
    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                 served_model_names: List[str],
                 lora_modules: Optional[List[LoRAModulePath]]):
        """Initialize the handler by delegating to the base class.

        All four arguments are forwarded unchanged, by keyword, to the
        parent initializer; this class adds no state of its own here.
        """
        super().__init__(
            engine=engine,
            model_config=model_config,
            served_model_names=served_model_names,
            lora_modules=lora_modules,
        )
69
70
71
72
73
74
75
76

    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
        """Serve a completion request in the style of OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        The request's prompt is split into individual prompts, one engine
        request is scheduled per prompt, and their outputs are merged into a
        single stream. Depending on the request, the result is an async
        generator of server-sent-event strings, a complete response object,
        or an error response.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = self.served_model_names[0]
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncIterator[RequestOutput]] = []
        try:
            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            decoding_config = await self.engine.get_decoding_config()
            # A backend named on the request wins over the engine default.
            guided_decoding_backend = (
                request.guided_decoding_backend
                or decoding_config.guided_decoding_backend)
            tokenizer = await self.engine.get_tokenizer()
            guided_processor = await get_guided_decoding_logits_processor(
                guided_decoding_backend, request, tokenizer)
            if guided_processor is not None:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(guided_processor)

            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
            # Token prompts and text prompts go through the same validator
            # and differ only in the keyword they are passed under.
            prompt_kwarg = "prompt_ids" if prompt_is_tokens else "prompt"

            for prompt_idx, prompt in enumerate(prompts):
                prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
                    request,
                    **{prompt_kwarg: prompt},
                    truncate_prompt_tokens=sampling_params.
                    truncate_prompt_tokens)

                engine_inputs = {
                    "prompt": prompt_text,
                    "prompt_token_ids": prompt_ids
                }
                generators.append(
                    self.engine.generate(
                        engine_inputs,
                        sampling_params,
                        f"{request_id}-{prompt_idx}",
                        lora_request=lora_request,
                    ))
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[int, RequestOutput]] = (
            merge_async_iterators(*generators))

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when use
        # beam search.
        if (request.stream
                and (request.best_of is None or request.n == request.best_of)
                and not request.use_beam_search):
            # Streaming response
            return self.completion_stream_generator(request,
                                                    raw_request,
                                                    result_generator,
                                                    request_id,
                                                    created_time,
                                                    model_name,
                                                    num_prompts=len(prompts))

        # Non-streaming response: keep the latest output seen per prompt.
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        try:
            async for prompt_idx, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects.
                    await self.engine.abort(f"{request_id}-{prompt_idx}")
                    return self.create_error_response("Client disconnected")
                final_res_batch[prompt_idx] = res
            response = self.request_output_to_completion_response(
                final_res_batch, request, request_id, created_time, model_name)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        if not request.stream:
            return response

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.model_dump_json()

        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        return fake_stream_generator()
190
191
192
193
194
195
196
197
198
199
200

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        raw_request: Request,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
    ) -> AsyncGenerator[str, None]:
        """Stream completion results as server-sent-event ``data:`` strings.

        Every engine output is turned into one ``CompletionStreamResponse``
        chunk per choice. A chunk carries only the text and tokens produced
        since the previous chunk for that choice; when ``request.echo`` is
        set, the first chunk of a choice is prefixed with the prompt.

        Args:
            request: The completion request being served.
            raw_request: The HTTP request, polled for client disconnects.
            result_generator: Yields ``(prompt_idx, RequestOutput)`` pairs.
            request_id: Base id; the engine request for a prompt is addressed
                as ``f"{request_id}-{prompt_idx}"``.
            created_time: Value for each chunk's ``created`` field.
            model_name: Value for each chunk's ``model`` field.
            num_prompts: Number of prompts covered by ``result_generator``.

        Yields:
            One ``data:`` event per choice update, followed by a final
            ``[DONE]`` event. A ``ValueError`` raised while streaming is
            emitted as an error event before ``[DONE]``. If the client
            disconnects, the engine request is aborted and the stream ends
            without a ``[DONE]`` event.
        """
        assert request.n is not None
        # State is tracked per choice; choices are numbered across prompts.
        num_choices = request.n * num_prompts
        previous_texts = [""] * num_choices
        previous_num_tokens = [0] * num_choices
        has_echoed = [False] * num_choices

        try:
            async for prompt_idx, res in result_generator:

                # Abort the request if the client disconnects.
                if await raw_request.is_disconnected():
                    await self.engine.abort(f"{request_id}-{prompt_idx}")
                    # A plain return ends an async generator cleanly.
                    # Raising StopAsyncIteration here would instead surface
                    # as a RuntimeError (PEP 525).
                    return

                for output in res.outputs:
                    i = output.index + prompt_idx * request.n
                    # TODO(simon): optimize the performance by avoiding full
                    # text O(n^2) sending.

                    assert request.max_tokens is not None
                    if request.echo and request.max_tokens == 0:
                        # only return the prompt
                        delta_text = res.prompt
                        delta_token_ids = res.prompt_token_ids
                        top_logprobs = res.prompt_logprobs
                        has_echoed[i] = True
                    elif (request.echo and request.max_tokens > 0
                          and not has_echoed[i]):
                        # echo the prompt and first token
                        delta_text = res.prompt + output.text
                        delta_token_ids = [
                            *res.prompt_token_ids,
                            *output.token_ids,
                        ]
                        # Either side may be None when logprobs were not
                        # requested; treat that as "no entries" rather than
                        # failing on None + list.
                        top_logprobs = [
                            *(res.prompt_logprobs or []),
                            *(output.logprobs or []),
                        ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text[len(previous_texts[i]):]
                        delta_token_ids = output.token_ids[
                            previous_num_tokens[i]:]
                        top_logprobs = output.logprobs[previous_num_tokens[
                            i]:] if output.logprobs else None

                    if request.logprobs is not None:
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=top_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason
                    if output.finish_reason is not None:  # return final usage
                        prompt_tokens = len(res.prompt_token_ids)
                        completion_tokens = len(output.token_ids)
                        final_usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )
                    else:
                        final_usage = None
                    response_json = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ],
                        usage=final_usage,
                    ).model_dump_json(exclude_unset=True)
                    yield f"data: {response_json}\n\n"
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: List[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
    ) -> CompletionResponse:
        """Build the non-streaming response from the final engine outputs.

        Every output of every entry in ``final_res_batch`` becomes one
        choice, indexed in iteration order. When ``request.echo`` is set the
        prompt text, token ids and (if requested) logprobs are prepended to
        each choice.

        Args:
            final_res_batch: Final engine output per prompt; every entry
                must be non-None.
            request: The completion request being served.
            request_id: Value for the response's ``id`` field.
            created_time: Value for the response's ``created`` field.
            model_name: Value for the response's ``model`` field.

        Returns:
            A ``CompletionResponse`` holding all choices plus token usage
            summed over every prompt and every output.
        """
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
        # Use one condition everywhere so that logprobs=0 (a valid, falsy
        # value) is not mistaken for "logprobs not requested".
        wants_logprobs = request.logprobs is not None

        for final_res in final_res_batch:
            assert final_res is not None
            prompt_token_ids = final_res.prompt_token_ids
            prompt_logprobs = final_res.prompt_logprobs
            prompt_text = final_res.prompt

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo and request.max_tokens == 0:
                    # Only the prompt is returned.
                    token_ids = prompt_token_ids
                    top_logprobs = prompt_logprobs
                    output_text = prompt_text
                elif request.echo and request.max_tokens > 0:
                    # Prompt followed by the generated continuation.
                    token_ids = prompt_token_ids + output.token_ids
                    top_logprobs = (prompt_logprobs + output.logprobs
                                    if wants_logprobs else None)
                    output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    top_logprobs = output.logprobs
                    output_text = output.text

                if wants_logprobs:
                    assert top_logprobs is not None, (
                        "top_logprobs must be provided when logprobs "
                        "is requested")
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=top_logprobs,
                        num_output_top_logprobs=request.logprobs,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                )
                choices.append(choice_data)

            num_prompt_tokens += len(prompt_token_ids)
            num_generated_tokens += sum(
                len(output.token_ids) for output in final_res.outputs)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        num_output_top_logprobs: int,
        initial_text_offset: int = 0,
    ) -> CompletionLogProbs:
        """Assemble the logprobs payload for an OpenAI-style completion.

        Args:
            token_ids: Token ids to report, in order.
            top_logprobs: For each position in ``token_ids``, either a
                mapping from token id to ``Logprob`` or None.
            num_output_top_logprobs: Number of alternatives requested; one
                extra entry is kept per position (see comment below).
            initial_text_offset: Text offset assigned to the first token.

        Returns:
            A ``CompletionLogProbs`` with parallel lists of tokens, their
            logprobs, their text offsets and their top alternatives.
        """
        text_offsets: List[int] = []
        token_logprobs: List[Optional[float]] = []
        tokens: List[str] = []
        top_alternatives: List[Optional[Dict[str, float]]] = []

        # Offset of the token about to be appended; advances by the length
        # of each decoded token.
        next_offset = initial_text_offset

        for position, token_id in enumerate(token_ids):
            candidates = top_logprobs[position]

            if candidates is None:
                token = self.tokenizer.decode(token_id)
                tokens.append(token)
                token_logprobs.append(None)
                top_alternatives.append(None)
            else:
                chosen = candidates[token_id]
                token = self._get_decoded_token(chosen, token_id)
                tokens.append(token)
                token_logprobs.append(max(chosen.logprob, -9999.0))

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                top_alternatives.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    self._get_decoded_token(candidate, candidate_id):
                    max(candidate.logprob, -9999.0)
                    for rank, (candidate_id,
                               candidate) in enumerate(candidates.items())
                    if rank <= num_output_top_logprobs
                })

            text_offsets.append(next_offset)
            next_offset += len(token)

        return CompletionLogProbs(
            text_offset=text_offsets,
            token_logprobs=token_logprobs,
            tokens=tokens,
            top_logprobs=top_alternatives,
        )