serving.py 27.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
7
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
8
from typing import cast
9

10
import jinja2
11
from fastapi import Request
12

13
from vllm.engine.protocol import EngineClient
14
from vllm.entrypoints.logger import RequestLogger
15
from vllm.entrypoints.openai.completion.protocol import (
16
17
18
19
20
21
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
22
23
)
from vllm.entrypoints.openai.engine.protocol import (
24
25
26
27
28
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
29
from vllm.entrypoints.openai.engine.serving import (
30
31
32
33
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
34
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
35
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
36
from vllm.exceptions import VLLMValidationError
37
from vllm.logger import init_logger
38
from vllm.logprobs import Logprob
39
from vllm.outputs import RequestOutput
40
from vllm.renderers.inputs import TokPrompt
41
from vllm.sampling_params import BeamSearchParams, SamplingParams
42
from vllm.tokenizers import TokenizerLike
43
44
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
45
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
46
47
48
49
50

# Module-level logger, named after this module via init_logger(__name__).
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
    """Serve the OpenAI-compatible Completion API on top of an engine client.

    Incoming CompletionRequest objects are rendered into engine prompts,
    one engine request is scheduled per prompt, and the resulting
    RequestOutput objects are converted into either a single
    CompletionResponse or a stream of ``data: ...`` server-sent-event lines.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        # set up logits processors; used to validate each request's
        # sampling parameters in create_completion().
        self.logits_processors = self.model_config.logits_processors

        # When set, usage reports include cached prompt-token details.
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        # Passed to should_include_usage() when streaming.
        self.enable_force_include_usage = enable_force_include_usage

        # Model-level sampling defaults, passed to get_max_tokens() and to the
        # request's to_sampling_params()/to_beam_search_params().
        self.default_sampling_params = self.model_config.get_diff_sampling_param()

    async def render_completion_request(
        self,
        request: CompletionRequest,
    ) -> list[TokPrompt] | ErrorResponse:
        """Render a completion request by validating and preprocessing inputs.

        Returns:
            A list of engine_prompts on success,
            or an ErrorResponse on failure.

        Raises:
            The engine client's ``dead_error`` if the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and request.prompt_embeds is not None:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        try:
            engine_prompts = await self._preprocess_completion(
                request,
                prompt_input=request.prompt,
                prompt_embeds=request.prompt_embeds,
            )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(e)

        return engine_prompts

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Returns:
            An async generator of ``data: ...`` lines when ``request.stream``
            is set, otherwise a CompletionResponse. An ErrorResponse is
            returned when validation, preprocessing, scheduling or generation
            fails.
        """
        result = await self.render_completion_request(request)
        if isinstance(result, ErrorResponse):
            return result

        engine_prompts = result

        request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)
        except (ValueError, TypeError, RuntimeError) as e:
            logger.exception("Error preparing request components")
            return self.create_error_response(e)

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        max_model_len = self.model_config.max_model_len
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            # Trace headers and tokenization kwargs depend only on the request,
            # not on the individual prompt, so compute them once up front.
            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )
            tokenization_kwargs = (
                None
                if request.use_beam_search
                else request.build_tok_params(self.model_config).get_encode_kwargs()
            )

            for i, engine_prompt in enumerate(engine_prompts):
                prompt_text = self._extract_prompt_text(engine_prompt)

                max_tokens = get_max_tokens(
                    max_model_len,
                    request.max_tokens,
                    self._extract_prompt_len(engine_prompt),
                    self.default_sampling_params,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params,
                    )
                    validate_logits_processors_parameters(
                        self.logits_processors,
                        sampling_params,
                    )

                # Each prompt is scheduled under its own engine request ID.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                if isinstance(sampling_params, BeamSearchParams):
                    # Use the per-prompt ID (not the shared request_id) so the
                    # concurrent beam searches of a multi-prompt request do not
                    # share one request ID.
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id_item,
                        params=sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                    )
                else:
                    engine_request = self.input_processor.process_inputs(
                        request_id_item,
                        engine_prompt,
                        sampling_params,
                        lora_request=lora_request,
                        tokenization_kwargs=tokenization_kwargs,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        data_parallel_rank=data_parallel_rank,
                    )

                    generator = self.engine_client.generate(
                        engine_request,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        prompt_text=prompt_text,
                        tokenization_kwargs=tokenization_kwargs,
                        data_parallel_rank=data_parallel_rank,
                    )

                generators.append(generator)
        except ValueError as e:
            return self.create_error_response(e)

        # Yields (prompt index, RequestOutput) pairs as results arrive.
        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # We do not stream the results when using beam search.
        stream = request.stream and not request.use_beam_search

        # Streaming response
        tokenizer = self.renderer.tokenizer

        if stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    engine_prompt = engine_prompts[i]
                    final_res.prompt = self._extract_prompt_text(engine_prompt)

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except GenerationError as e:
            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            return self.create_error_response(e)

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_prompts: list[TokPrompt],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Yield ``data: ...`` lines for a streaming completion request.

        Every (prompt index, RequestOutput) pair from ``result_generator``
        produces one CompletionStreamResponse chunk per output choice, with
        choice indices flattened as ``output.index + prompt_idx * n``. Once
        the results are exhausted, a usage-only chunk is sent if usage
        reporting is enabled, and the aggregate usage is stored on
        ``request_metadata``. GenerationError and other ``Exception``s are
        reported in-band as an error chunk; in those cases and on normal
        completion the stream ends with ``data: [DONE]``.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice state, indexed by the flattened choice index.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        # Per-prompt prompt length, used for usage accounting.
        num_prompt_tokens = [0] * num_prompts
        num_cached_tokens = None
        first_iteration = True

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                # Only the first received output's cached-token count is
                # reported in the final usage.
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    first_iteration = False

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_prompt = engine_prompts[prompt_idx]
                    prompt_text = self._extract_prompt_text(engine_prompt)

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Useful when request.return_token_ids is True
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if not has_echoed[i] and request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # Advance by the text actually emitted in this chunk (which
                    # includes the echoed prompt), so later text_offset values
                    # stay relative to the start of the returned text.
                    previous_text_lens[i] += len(delta_text)
                    # Completion-token count only (prompt tokens not included).
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    self._raise_if_error(finish_reason, request_id)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            if self.enable_prompt_tokens_details and num_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=num_cached_tokens
                )

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(e)
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build a non-streaming CompletionResponse from finished outputs.

        One choice is produced per output of every RequestOutput, numbered
        consecutively. Usage is summed over all prompts and outputs and is
        also stored on ``request_metadata``. ``kv_transfer_params`` is taken
        from the first result, if any.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
        kv_transfer_params = None
        last_final_res = None
        for final_res in final_res_batch:
            last_final_res = final_res
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            # Clamped copy/view of the prompt logprobs, used both for echo
            # logprobs and for the choice's prompt_logprobs field.
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                self._raise_if_error(output.finish_reason, request_id)

                assert request.max_tokens is not None
                if request.echo:
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    # Use the clamped value computed above rather than the raw
                    # engine output.
                    prompt_logprobs=prompt_logprobs,
                    prompt_token_ids=(
                        prompt_token_ids if request.return_token_ids else None
                    ),
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        # Only the last result's cached-token count is reported.
        if (
            self.enable_prompt_tokens_details
            and last_final_res
            and last_final_res.num_cached_tokens
        ):
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=last_final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage
        if final_res_batch:
            kv_transfer_params = final_res_batch[0].kv_transfer_params
        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: TokenizerLike | None,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Tokens to describe, one per position.
            top_logprobs: Per-position logprob dicts aligned with
                ``token_ids``. A ``None`` entry means no logprobs exist at
                that position; the token is still reported, with ``None``
                logprob and ``None`` top logprobs.
            num_output_top_logprobs: Requested number of alternatives; up to
                ``num_output_top_logprobs + 1`` entries are emitted per
                position.
            tokenizer: Used to decode tokens.
            initial_text_offset: Character offset assigned to the first token.
            return_as_token_id: Overrides ``self.return_tokens_as_token_ids``
                when not None.

        Raises:
            VLLMValidationError: if a position without logprobs must be
                decoded as text but ``tokenizer`` is None.
        """
        out_text_offset: list[int] = []
        out_token_logprobs: list[float | None] = []
        out_tokens: list[str] = []
        out_top_logprobs: list[dict[str, float] | None] = []

        last_token_len = 0

        should_return_as_token_id = (
            return_as_token_id
            if return_as_token_id is not None
            else self.return_tokens_as_token_ids
        )
        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                if should_return_as_token_id:
                    token = f"token_id:{token_id}"
                else:
                    if tokenizer is None:
                        raise VLLMValidationError(
                            "Unable to get tokenizer because "
                            "`skip_tokenizer_init=True`",
                            parameter="skip_tokenizer_init",
                            value=True,
                        )

                    token = tokenizer.decode(token_id)

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=should_return_as_token_id,
                )
                # Convert float("-inf") to the JSON-serializable float that
                # OpenAI uses.
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append(
                    {
                        self._get_decoded_token(
                            top_logprob,
                            top_token_id,
                            tokenizer,
                            return_as_token_id=should_return_as_token_id,
                        ): max(top_logprob.logprob, -9999.0)
                        for rank, (top_token_id, top_logprob) in enumerate(
                            step_top_logprobs.items()
                        )
                        if num_output_top_logprobs >= rank
                    }
                )

            if not out_text_offset:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )