sampler.py 48.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""A layer that samples the next tokens from the model's outputs."""
3
import itertools
4
from collections.abc import Iterator
5
from dataclasses import dataclass
6
from importlib.util import find_spec
7
from math import inf
8
from typing import Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
import msgspec
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
13
import torch
import torch.nn as nn

14
import vllm.envs as envs
15
from vllm.model_executor.layers.utils import apply_penalties
16
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
17
18
19
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
from vllm.sampling_params import SamplingType
20
21
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, Logprob,
22
                           PromptLogprobs, SampleLogprobs, SequenceOutput)
23
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
Woosuk Kwon's avatar
Woosuk Kwon committed
24

25
26
27
28
29
30
31
32
33
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

    # yapf: enable
else:
    flashinfer_top_k_top_p_sampling = None

34
35
36
37
from vllm.logger import init_logger

logger = init_logger(__name__)

Joe Runde's avatar
Joe Runde committed
38
39
40
41
42
43
44
45
46

def get_sampler() -> torch.nn.Module:
    """Build the sampler layer matching the active engine version.

    Returns the V1 sampler when `envs.VLLM_USE_V1` is truthy, otherwise the
    `Sampler` defined in this module.
    """
    if not envs.VLLM_USE_V1:
        return Sampler()

    # Lazy import: the v1 package isn't distributed
    from vllm.v1.sample.sampler import Sampler as V1Sampler
    return V1Sampler()


47
# One (next_token_ids, parent_ids) pair per sequence group.
SampleResultType = list[tuple[list[int], list[int]]]

# Types of temporary data structures used for
# computing sample_result
#
# Per sampling type: (indices of the sequence groups, the groups themselves).
SampleMetadataType = dict[SamplingType, tuple[list[int],
                                              list[SequenceGroupToSample]]]
# Per sampling type: tensor of sampled token ids.
MultinomialSamplesType = dict[SamplingType, torch.Tensor]
# Sequence group index -> (next_token_ids, parent_ids).
SampleResultsDictType = dict[int, tuple[list[int], list[int]]]
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81


# Encapsulates temporary data structures for computing
# sample_result.
#
# * For multi-step scheduling: must be returned
#   by `Sampler.forward()` and used later to compute the pythonized
#   sample_result
#
# * For single-step scheduling: consumed immediately
#   inside `Sampler.forward()` to compute pythonized sample_result.
@dataclass
class SampleResultArgsType:
    """Inputs needed to turn GPU-side sampling results into Python lists.

    An instance is built by `_sample_with_torch` and consumed by
    `get_pythonized_sample_results`.
    """
    # Per sampling type: (sequence group indices, sequence groups).
    sample_metadata: SampleMetadataType
    # Per random sampling type: tensor of sampled token ids.
    multinomial_samples: MultinomialSamplesType
    # Sequence group index -> (next_token_ids, parent_ids); filled in place
    # during Pythonization.
    sample_results_dict: SampleResultsDictType
    # Metadata of the batch the samples were drawn for.
    sampling_metadata: SamplingMetadata
    # Argmax token ids for greedy sequence groups; None when the batch had no
    # greedy samples.
    greedy_samples: Optional[torch.Tensor]


# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
# sample result types
MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]

# Abbreviation of the _sample() return type:
# (sample results or deferred args, optional on-device sampled token ids)
SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
83
84
85
86
87
88
89
90
91
92
93
94
95


class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For each sequence group, we generate a list of SequenceOutput object,
    each of which contains one possible candidate for the next token.

    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors.

    NOTE: the struct is declared with `array_like=True`, so the order of the
    fields below is significant; do not reorder them.
    """

    outputs: list[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # Holds either (1) the pythonized sampler result (single-step scheduling)
    # or (2) what will be arguments for later deferred pythonization of the
    # sampler result (multi-step scheduling)
    deferred_sample_results_args: Optional[SampleResultArgsType] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # On-device tensor containing the sampled token embeddings (embeddings
    # corresponding to the sampled token ids). Used when prompt embeddings are
    # specified in lieu of prompt token ids or text.
    sampled_token_embeds: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers
    model_forward_time: Optional[float] = None

    # Time taken in the model execute function. This will include model forward,
    # block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

    def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
        """Return the output of the sequence group at position `idx`."""
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        """Replace the output of the sequence group at position `idx`."""
        self.outputs[idx] = value

    def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
        """Iterate over the per-sequence-group outputs."""
        return iter(self.outputs)

    def __len__(self):
        """Return the number of per-sequence-group outputs."""
        return len(self.outputs)

    def __eq__(self, other: object):
        """Compare by `outputs` only; the tensor fields are ignored."""
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs

    def __repr__(self) -> str:
        """Show the shape of a tensor instead of its values to reduce noise.
        """
        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
                                    else self.sampled_token_probs.shape)
        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
                                  self.sampled_token_ids.shape)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={sampled_token_probs_repr}, "
            f"sampled_token_ids={sampled_token_ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")

167

Woosuk Kwon's avatar
Woosuk Kwon committed
168
class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row in
    logits for the next token to be sampled; however, for a seq_group with a
    prompt request with the prompt_logprobs sampling parameter, there are rows
    in logits for each token in the input prompt.
    """

    def __init__(self):
        super().__init__()

        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding and when prompt embeddings are specified.
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False

    def _init_sampling_tensors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ):
        """The goal here is to reuse sampling tensors between similar decode
        runs. This is possible because sampling logic does not change between
        decodes of the same sequences.

        Rebuilds the cached sampling tensors and the `do_*` flags from
        `sampling_metadata` and stores them on the instance.
        """
        _, vocab_size = logits.shape

        # First free any existing stored sampling tensors.
        # This is necessary because some sampling tensors may
        # have pinned memory.
        self._sampling_tensors = None

        # Initialize new sampling tensors
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        self._sampling_tensors = sampling_tensors
        self._do_penalties = do_penalties
        self._do_top_p_top_k = do_top_p_top_k
        self._do_min_p = do_min_p

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Single-step scheduling:
            * Perform GPU-side sampling computation & compute
            GPU-side logprobs tensor
            * Pythonize sampling result & logprobs tensor

        Multi-step scheduling:
            * Perform GPU-side sampling computation & compute
            GPU-side logprobs tensor
            * Defer Pythonization of sampling result & logprobs
            tensor
            * Encapsulate arguments required for deferred Pythonization
            in the
            [`SamplerOutput`][vllm.model_executor.layers.sampler.SamplerOutput]
            structure

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        # Always truncate on the logits when top-k/top-p is requested.
        # `_sample_with_torch` draws random samples with the PyTorch-native
        # `_multinomial` even when the FlashInfer kernel is importable, so
        # skipping this step in that case would silently drop the requested
        # top-k/top-p truncation.
        if do_top_p_top_k:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization,
            # preserve GPU-side tensors in support of later
            # deferred pythonization of logprobs
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability distribution
        of greedily-sampled tokens such that multinomial sampling would sample
        the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        return self.should_modify_greedy_probs_inplace
348
349


350
351
352
353
def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
        have not been generated yet

    Args:
        logits: Logits tensor, indexed as (row, token_id). Modified in place.
        sampling_metadata: Metadata whose `seq_groups` describe which rows of
            `logits` belong to which sequence.

    Returns:
        The same `logits` tensor, with the penalized entries set to -inf.
    """
    # list of indices in logits that will be set to -inf
    logits_to_penalize: list[tuple[int, int]] = []
    logits_applied = 0
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params

        sample_indices = seq_group.sample_indices
        # Count every row owned by this group so the final assert can verify
        # that the groups cover the whole logits tensor.
        logits_applied += len(sample_indices) + len(
            seq_group.prompt_logprob_indices)
        if not seq_group.do_sample:
            continue

        start_idx = sample_indices[0]
        min_tokens = sampling_params.min_tokens
        token_ids_to_penalize = sampling_params.all_stop_token_ids
        if min_tokens > 0 and token_ids_to_penalize:
            seqs_to_penalize: list[int] = []
            for j, seq_id in enumerate(seq_ids):
                seq_data = seq_group.seq_data[seq_id]
                if len(seq_data.output_token_ids_array) < min_tokens:
                    seqs_to_penalize.append(j)

            if seqs_to_penalize:
                # convert to the index into logits
                seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
                # itertools.product pairs each seq index with every token id
                logits_to_penalize.extend(
                    itertools.product(seqs_to_penalize, token_ids_to_penalize))

    if logits_to_penalize:
        # use zip and * to group indices along each dimension
        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
        logits[tuple(zip(*logits_to_penalize))] = -float("inf")

    # verifies that no rows in logits were missed unexpectedly
    assert logits_applied == logits.shape[0]
    return logits


397
def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Mask out logits that fall outside the per-row top-k / top-p sets.

    Args:
        logits: 2D logits tensor, one row per token to sample.
        p: Per-row top-p thresholds (one value per row of `logits`).
        k: Per-row top-k counts (one value per row of `logits`).

    Returns:
        A new tensor of the same shape as `logits`, in the original column
        order, where removed entries are -inf.
    """
    # Ascending sort: the largest logits end up in the last columns.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    # at least one
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = torch.empty_like(logits_sort).scatter_(dim=-1,
                                                    index=logits_idx,
                                                    src=logits_sort)
    return logits
424
425


Roy's avatar
Roy committed
426
427
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """Mask out tokens whose probability is below `min_p` times the row max.

    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17

    Args:
        logits: 2D logits tensor, one row per token to sample. Modified in
            place.
        min_p: Per-row min-p values (one value per row of `logits`). Left
            untouched.

    Returns:
        The same `logits` tensor, with the removed entries set to -inf.
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    # Use the out-of-place unsqueeze: `min_p` belongs to the caller's sampling
    # tensors, which may be reused across decode steps. Reshaping it in place
    # would add another dimension on every call and break broadcasting.
    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits


443
def _greedy_sample(
    selected_seq_groups: list[SequenceGroupToSample],
    samples: torch.Tensor,
) -> SampleResultType:
    """Run greedy sampling on a given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        samples: (num_selected_samples,) A tensor of samples. The length of
            samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    samples_lst = samples.tolist()
    # Position in `samples_lst` of the next group that actually samples.
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples_lst[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: list[SequenceGroupToSample],
    random_samples: torch.Tensor,
) -> SampleResultType:
    """Run random sampling on a given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        random_samples: (num_selected_samples,) A tensor of samples. The
            length of samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    # Move the samples to the CPU once, before slicing them per group.
    random_samples = random_samples.cpu()
    # Row in `random_samples` of the next group that actually samples.
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        is_prompt = seq_group.is_prompt
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            # All `n` samples come from a single row and share parent 0.
            parent_ids = [0] * sampling_params.n
            next_token_ids = random_samples[
                sample_idx, :sampling_params.n].tolist()
        else:
            # Generation phase.
            # One row per parent sequence; only the first column is used.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


522
523
524
525
526
527
528
529
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[list[SequenceGroupToSample]] = None,
) -> torch.Tensor:
    """Draw `num_samples` token ids per row of `probs` without a CPU sync.

    Each row of `probs` is divided by exponential noise and the argmax of the
    result is taken as the sample.

    Args:
        probs: 2D probabilities tensor, one row per token to sample. When
            `num_samples` is 1 it is modified in place.
        num_samples: Number of samples to draw per row.
        seq_groups: When given, the noise for each group's rows is drawn from
            that group's own `generator`; otherwise the default generator is
            used for all rows.

    Returns:
        Tensor of sampled token ids viewed as (-1, num_samples).
    """
    if num_samples > 1:
        probs = probs.repeat_interleave(num_samples, dim=0)
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for seq_group in seq_groups:
            seq_ids = seq_group.seq_ids
            # Rows owned by this group after the repeat_interleave above.
            stride = len(seq_ids) * num_samples
            assert seq_group.generator is not None
            q[sample_idx:sample_idx +
              stride].exponential_(generator=seq_group.generator)
            sample_idx += stride
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)


549
550
def _top_k_top_p_multinomial_with_flashinfer(
        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
        num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]):
    """Sample token ids with the FlashInfer top-k/top-p sampling function.

    Args:
        probs: Probabilities tensor, one row per token to sample.
        top_ks: Per-row top-k values.
        top_ps: Per-row top-p values.
        num_samples: Number of samples to draw per row.
        seq_groups: Not used by this function.

    Returns:
        The sampled token ids viewed as (-1, num_samples).
    """
    if num_samples > 1:
        # Repeat every row (and its parameters) once per requested sample.
        probs = probs.repeat_interleave(num_samples, dim=0)
        top_ks = top_ks.repeat_interleave(num_samples)
        top_ps = top_ps.repeat_interleave(num_samples)
    batch_next_token_ids = flashinfer_top_k_top_p_sampling(
        probs,
        top_ks,
        top_ps,
    )
    return batch_next_token_ids.view(-1, num_samples)


564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
def get_pythonized_sample_results(
        sample_result_args: SampleResultArgsType) -> SampleResultType:
    '''This function consumes GPU-side sampler results and computes
    Pythonized CPU-side sampler results (GPU -> CPU sync.)

    Single-step scheduling: this function is invoked at sampling-time
    for immediate Pythonization.

    Multi-step scheduling: Pythonization is deferred until after multiple
    GPU-side steps have been completed.

    Args:
      sample_result_args: GPU-side inputs to the Pythonization process.
        Its `sample_results_dict` is updated in place.

    Returns:
      Pythonized sampler results, one entry per sequence group in
      `sampling_metadata.seq_groups`; groups without a result get ([], []).

    Raises:
      ValueError: if `sample_metadata` contains a sampling type other than
        GREEDY, RANDOM or RANDOM_SEED.
    '''

    sample_metadata = sample_result_args.sample_metadata
    sampling_metadata = sample_result_args.sampling_metadata
    greedy_samples = sample_result_args.greedy_samples
    multinomial_samples = sample_result_args.multinomial_samples
    sample_results_dict = sample_result_args.sample_results_dict

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups,
                                            multinomial_samples[sampling_type])
        else:
            # Fail loudly instead of reusing the previous iteration's
            # `sample_results` (or hitting an unbound local).
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
        sample_results_dict.update(zip(seq_group_id, sample_results))

    return [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]


613
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    '''Torch-oriented _sample() implementation.

    Single-step scheduling:
    * Perform GPU-side sampling computation
    * Immediately Pythonize sampling result

    Multi-step scheduling:
    * Perform GPU-side sampling computation
    * Defer Pythonization & preserve GPU-side
      tensors required for Pythonization

    Args:
        probs: Probabilities tensor, one row per query token.
        logprobs: Log-probabilities tensor, one row per query token.
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Not read by this function.
        include_gpu_probs_tensor: If True, also return a tensor holding the
            sampled token id of every row.
        modify_greedy_probs: If True, rewrite `probs`/`logprobs` of greedy
            rows through `_modify_greedy_probs_inplace`.

    Returns:
        A pair: the Pythonized sample results (or, when
        `sampling_metadata.skip_sampler_cpu_output` is set, the deferred
        Pythonization arguments), and the sampled token ids tensor (None
        unless `include_gpu_probs_tensor` is True).

    Raises:
        ValueError: if a sampling type with samples in the batch is neither
            GREEDY, RANDOM nor RANDOM_SEED.
    '''

    # Group the sequence group indices by their sampling type.
    categorized_seq_group_ids: dict[SamplingType, list[int]] = {
        t: []
        for t in SamplingType
    }
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: SampleResultsDictType = {}
    sample_metadata: SampleMetadataType = {}
    multinomial_samples: MultinomialSamplesType = {}
    greedy_samples: Optional[torch.Tensor] = None

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
                                              VLLM_INVALID_TOKEN_ID,
                                              dtype=torch.long,
                                              device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Prompt-phase groups may request n > 1 samples from one row.
            max_n_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_n_in_batch = max(max_n_in_batch, sampling_params.n)
            # Only seeded sampling passes the groups (and their generators).
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)

            if flashinfer_top_k_top_p_sampling is not None:
                logger.warning("FlashInfer 0.2.3+ does not support "
                               "per-request generators. Falling back to "
                               "PyTorch-native implementation.")

            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices],
                max_n_in_batch,
                seq_groups=seq_groups_arg)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)

        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # Encapsulate arguments for computing Pythonized sampler
    # results, whether deferred or otherwise.
    maybe_deferred_args = SampleResultArgsType(
        sampling_metadata=sampling_metadata,
        sample_metadata=sample_metadata,
        multinomial_samples=multinomial_samples,
        greedy_samples=greedy_samples,
        sample_results_dict=sample_results_dict)

    if not sampling_metadata.skip_sampler_cpu_output:
        # GPU<->CPU sync happens here.
        # This also converts the sampler output to a Python object.
        # Return Pythonized sampler result & sampled token ids
        return get_pythonized_sample_results(
            maybe_deferred_args), sampled_token_ids_tensor
    else:
        # Defer sampler result Pythonization; return deferred
        # Pythonization args & sampled token ids
        return (
            maybe_deferred_args,
            sampled_token_ids_tensor,
        )
735
736


737
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    """Sample next tokens for a batch.

    This is a thin wrapper: every argument is forwarded unchanged to
    `_sample_with_torch`, whose result is returned as-is.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.
        include_gpu_probs_tensor: Forwarded to `_sample_with_torch`.
        modify_greedy_probs: Forwarded to `_sample_with_torch`.

    Returns:
        A 2-tuple. The first element is (next_token_ids, parent_seq_ids) for
        each seq group in a batch (([], []) if sampling is skipped), or, when
        `sampling_metadata.skip_sampler_cpu_output` is set, the arguments
        needed to compute that result later. The second element is the
        sampled token ids tensor (may be None).
    """
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        sampling_tensors,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )
765
766


767
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the ranks of the chosen tokens in a logprob tensor.

    The rank is 1-based: it is one plus the number of vocabulary entries in
    the same row whose logprob is strictly greater than the chosen token's.
    Tokens tied with the chosen one therefore do not worsen its rank.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
                        where N is the no. of tokens and M is the vocab dim.
        indices (torch.Tensor): 1D tensor of N chosen token indices, one per
                        row of `x`.

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                    Each element in the returned tensor represents the rank
                    of the chosen token in the input logprob tensor.
    """
    # Pair each row with its chosen column to pick out the chosen logprob.
    row_ids = torch.arange(0, len(x), device=x.device, dtype=indices.dtype)
    vals = x[row_ids, indices]
    # (N, M) boolean mask of entries that beat the chosen token.
    result = (x > vals[:, None])
    # Release the gathered values before the reduction allocates its output.
    del vals
    return result.sum(1).add_(1)
786
787


788
def get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: SampleResultType,
) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]:
    """Return sample logprobs and prompt logprobs.

    The logic consists of 3 parts.
    - Select indices to compute logprob from, ranks of token ids, and
        the top k token ids from logprobs.
    - Compute prompt logprobs if required.
    - Compute sample logprobs if required.

    Args:
        logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
            logprob per vocab. Sequence groups' query tokens are batched in a
            single flattened tensor. For example, assuming there are N
            seq groups, it is sorted by prefill tokens for seq_group_1 (if
            prompt logprob is enabled), decode tokens for seq_group_1 (if
            sampling is required), prefill tokens for seq_group_2, ...
        sampling_metadata: The sampling metadata.
        sample_results: (num_seq_groups) The tuple of (next_token_ids,
            parent_ids) for each sequence group. When beam search is enabled,
            sample_results can contain different number of seq_ids from
            sampling_metadata.seq_groups. It is because beam search creates
            2 * BEAM_WIDTH number of samples (whereas there are only up to
            BEAM_WIDTH number of seq_ids).

    Returns:
        A tuple of prompt and sample logprobs per sequence group in a batch.
    """
    # The index of query token to calculate logprobs. It includes both
    # prompt and sample logprob indices.
    query_indices: list[int] = []
    # The next token ids to get the logprob value from.
    next_token_ids: list[int] = []
    # The largest requested number of logprobs. We find logprobs as many as the
    # largest num logprobs in this API. If every logprobs is None, it will be
    # set to -1.
    largest_num_logprobs = -1

    # Select indices to compute logprob from, ranks of token ids, and the top
    # k token ids from logprobs.
    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
                                          sample_results):
        sampling_params = seq_group.sampling_params

        # Update indices and tokens for prompt logprobs.
        if (seq_group.is_prompt
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
            query_indices.extend(seq_group.prompt_logprob_indices)
            next_token_ids.extend(next_prompt_tokens)

        # Update indices and next tokens for sample logprob.
        if seq_group.do_sample:
            token_ids, parent_seq_ids = sample_result
            # NOTE: We cannot directly use sample_indices because
            # sample_indices only contain parent seq_ids of a previous step.
            # The current step may have different number of seq_ids, and
            # we can obtain it from `sample_result[1]`.
            query_idx = seq_group.sample_indices[0]
            query_indices.extend(
                [query_idx + parent_id for parent_id in parent_seq_ids])
            next_token_ids.extend(token_ids)

            if sampling_params.logprobs is not None:
                largest_num_logprobs = max(largest_num_logprobs,
                                           sampling_params.logprobs)

        assert len(next_token_ids) == len(query_indices)

    if not query_indices:
        # Nothing to compute. Give every seq group its own empty list rather
        # than N references to one shared list, so a later in-place append on
        # one group's result cannot leak into the others.
        num_seq_groups = len(sampling_metadata.seq_groups)
        empty_prompt_logprobs: list[Optional[PromptLogprobs]] = (
            [None] * num_seq_groups)
        empty_sample_logprobs: list[SampleLogprobs] = [
            [] for _ in range(num_seq_groups)
        ]
        return empty_prompt_logprobs, empty_sample_logprobs

    selected_logprobs, ranks = None, None
    top_logprobs, top_token_ids = None, None

    # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
    # skip the whole logprob calculation.
    if largest_num_logprobs >= 0:
        query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
        next_token_ids_gpu = torch.tensor(next_token_ids,
                                          device=logprobs.device)

        # (num_selected_query_tokens, num_logprobs). Note that query_indices can
        # contain duplicates if beam search is enabled.
        # Index with an explicit (row, col) pair: passing a Python list of
        # index tensors relies on deprecated non-tuple sequence indexing.
        selected_logprobs = logprobs[query_indices_gpu, next_token_ids_gpu]
        ranks = _get_ranks(
            logprobs[query_indices_gpu],
            next_token_ids_gpu,
        )
        assert selected_logprobs.shape[0] == ranks.shape[0]

        # We need to compute top k only if there exists logprobs > 0.
        if largest_num_logprobs > 0:
            # Logprobs of topk tokens for a batch of sequence groups.
            # (num_query_tokens_across_batch).
            top_logprobs, top_token_ids = torch.topk(logprobs,
                                                     largest_num_logprobs,
                                                     dim=-1)
            top_logprobs = top_logprobs.to('cpu')
            top_token_ids = top_token_ids.to('cpu')

        selected_logprobs = selected_logprobs.to('cpu')
        ranks = ranks.to('cpu')

    # Find prompt/sample logprobs.
    prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = []
    sample_logprobs_per_seq_group: list[SampleLogprobs] = []
    # Running cursors into the top-k tensors and the selected-logprob tensors;
    # each helper call consumes its share and returns the advanced cursors.
    top_logprob_idx = 0
    selected_logprobs_idx = 0

    for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                        sample_results):
        (prompt_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_prompt_logprob_if_needed(
             seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
             selected_logprobs_idx, top_logprob_idx)
        prompt_logprobs_per_seq_group.append(prompt_logprobs)

        (sampled_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_sampled_logprob_if_needed(
             seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
             top_logprobs, selected_logprobs_idx, top_logprob_idx)
        sample_logprobs_per_seq_group.append(sampled_logprobs)

    return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group


def _get_prompt_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
) -> tuple[Optional[PromptLogprobs], int, int]:
    """Compute the prompt logprob from a sequence group if needed.

    Args:
        seq_group: The sequence group to compute prompt logprobs for.
        selected_logprobs: Logprobs of the selected (next) tokens, indexed
            from `selected_logprobs_idx`.
        ranks: Vocabulary ranks of the selected tokens, aligned with
            `selected_logprobs`.
        top_token_ids: Top-k token ids per query token, indexed by row from
            `top_logprob_idx`. Only read when prompt_logprobs > 0.
        top_logprobs: Top-k logprobs, aligned with `top_token_ids`.
        selected_logprobs_idx: Cursor into `selected_logprobs` / `ranks`.
        top_logprob_idx: Row cursor into `top_token_ids` / `top_logprobs`.

    Returns:
        (prompt_logprobs, top_logprob_idx, selected_logprobs_idx), where
        prompt_logprobs is None unless the group is a prompt with
        prompt_logprobs requested, and both cursors are advanced past the
        entries consumed here.
    """
    sampling_params = seq_group.sampling_params
    is_prompt = seq_group.is_prompt

    # Find prompt logprobs
    prompt_logprobs: Optional[PromptLogprobs] = None
    if is_prompt and sampling_params.prompt_logprobs is not None:
        prompt_logprobs = []
        num_logprobs = sampling_params.prompt_logprobs
        next_prompt_tokens = _get_next_prompt_tokens(seq_group)
        num_prompt_tokens = len(next_prompt_tokens)

        # Pre-select indexes and create a list. It is faster than calling .item
        # repetitively.
        selected_end = selected_logprobs_idx + num_prompt_tokens
        selected_logprob_items = selected_logprobs[
            selected_logprobs_idx:selected_end].tolist()
        rank_items = ranks[selected_logprobs_idx:selected_end].tolist()

        # Likewise convert all top-k rows for this prompt in one go instead
        # of one .tolist() per token inside the loop.
        top_id_rows: list[list[int]] = []
        top_prob_rows: list[list[float]] = []
        if num_logprobs > 0:
            top_end = top_logprob_idx + num_prompt_tokens
            top_id_rows = top_token_ids[top_logprob_idx:top_end, :
                                        num_logprobs].tolist()
            top_prob_rows = top_logprobs[top_logprob_idx:top_end, :
                                         num_logprobs].tolist()
        # Top K is already sorted by rank, so we can use 1 ~
        # num_logprobs + 1 for rank.
        top_ranks = range(1, num_logprobs + 1)

        for idx, token_id in enumerate(next_prompt_tokens):
            # Calculate the prompt logprob of the real prompt tokens.
            # {token_id: (logprob, rank_from_vocab)}
            prompt_logprobs_dict: dict[int, tuple[float, int]] = {
                token_id: (selected_logprob_items[idx], rank_items[idx])
            }

            # Add top K prompt logprobs along with its rank.
            if num_logprobs > 0:
                prompt_logprobs_dict.update({
                    top_id: (top_prob, rank)
                    for top_id, top_prob, rank in zip(
                        top_id_rows[idx], top_prob_rows[idx], top_ranks)
                })

            prompt_logprobs.append({
                token_id: Logprob(*logprob_and_rank)
                for token_id, logprob_and_rank in prompt_logprobs_dict.items()
            })

        # One top-k row is consumed per prompt token.
        top_logprob_idx += num_prompt_tokens
        # + len(next_prompt_tokens) to go to the next prompt.
        selected_logprobs_idx += num_prompt_tokens
    return prompt_logprobs, top_logprob_idx, selected_logprobs_idx


def _get_sampled_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    sample_result: tuple[list[int], list[int]],
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
) -> tuple[SampleLogprobs, int, int]:
    """Compute the sample logprob if needed.

    Args:
        seq_group: The sequence group that was (possibly) sampled.
        sample_result: (next_token_ids, parent_seq_ids) for this group.
        selected_logprobs: Logprobs of the selected (next) tokens, indexed
            from `selected_logprobs_idx`.
        ranks: Vocabulary ranks of the selected tokens, aligned with
            `selected_logprobs`.
        top_token_ids: Top-k token ids per query token; rows for this group
            start at `top_logprob_idx` and are offset by parent id. Only
            read when logprobs > 0.
        top_logprobs: Top-k logprobs, aligned with `top_token_ids`.
        selected_logprobs_idx: Cursor into `selected_logprobs` / `ranks`.
        top_logprob_idx: Row cursor into `top_token_ids` / `top_logprobs`.

    Returns:
        (sampled_logprobs, top_logprob_idx, selected_logprobs_idx). The list
        is empty and the cursors are unchanged when the group does not
        sample; otherwise it holds one dict per sampled token and both
        cursors are advanced past this group.
    """
    seq_ids = seq_group.seq_ids
    num_logprobs = seq_group.sampling_params.logprobs
    sampled_logprobs: SampleLogprobs = []
    next_token_ids, parent_seq_ids = sample_result

    if not seq_group.do_sample:
        return sampled_logprobs, top_logprob_idx, selected_logprobs_idx

    assert len(next_token_ids) > 0
    if num_logprobs is None:
        for next_token_id in next_token_ids:
            # Use a dummy logprob
            sampled_logprobs.append({next_token_id: Logprob(inf)})
    else:
        # Pre-select items from tensor. tolist() is faster than repetitive
        # `.item()` calls.
        selected_end = selected_logprobs_idx + len(next_token_ids)
        selected_logprob_items = selected_logprobs[
            selected_logprobs_idx:selected_end].tolist()
        rank_items = ranks[selected_logprobs_idx:selected_end].tolist()
        # Top K is already sorted by rank, so we can use 1 ~
        # num_logprobs + 1 for rank.
        top_ranks = range(1, num_logprobs + 1)

        for idx, (next_token_id, parent_id) in enumerate(
                zip(next_token_ids, parent_seq_ids)):
            # Get the logprob of a sampled token.
            sampled_logprobs_dict: dict[int, tuple[float, int]] = {
                next_token_id: (selected_logprob_items[idx], rank_items[idx])
            }
            if num_logprobs > 0:
                # Get top K logprobs from the parent sequence's row.
                top_row = top_logprob_idx + parent_id
                top_ids = top_token_ids[top_row, :num_logprobs].tolist()
                top_probs = top_logprobs[top_row, :num_logprobs].tolist()
                sampled_logprobs_dict.update({
                    top_id: (top_prob, rank)
                    for top_id, top_prob, rank in zip(top_ids, top_probs,
                                                      top_ranks)
                })

            sampled_logprobs.append({
                token_id: Logprob(*logprob_and_rank)
                for token_id, logprob_and_rank in
                sampled_logprobs_dict.items()
            })

    # NOTE: This part of code is not intuitive. `selected_logprobs` include
    # logprobs for the current step, which has len(next_token_ids) tokens
    # per sequence group. `logprobs` includes logprobs from the previous
    # steps, which has len(seq_ids) tokens per sequence group.

    # Advance past this group's selected (current-step) entries.
    selected_logprobs_idx += len(next_token_ids)
    # Advance past this group's top-k (previous-step) rows.
    top_logprob_idx += len(seq_ids)
    return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
1055
1056


1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                 sample_indices: torch.Tensor,
                                 greedy_samples: torch.Tensor) -> None:
    """Modify the probability distributions of the greedily-sampled tokens such
    that each sampled token has a "probability" of 1.0. This is required by
    speculative decoding, which depends on the sampling method being encoded
    within the probability distribution for correctness.

    Args:
        logprobs: Log-probability tensor. Accepted for interface symmetry but
            left untouched so it can be returned to the user.
        probs: Probability tensor, modified in place.
        sample_indices: Row indices in `probs` of the greedily-sampled
            positions.
        greedy_samples: The greedily chosen token id for each entry of
            `sample_indices`.

    # Why do we only need to do this for greedy sampling?

    vLLM's sampler performs the following steps for greedy or multinomial
    (random) sampling:
        1. Get logits from model.
        2. Modify logits according to per-sequence sampling parameters.
            - Scale by temperature, top-k and top-p masking, penalize tokens
                according to their frequency, etc.
        3. Sample a token.
            - Random sampling simply samples from the modified probability
                distribution.
            - Greedy sampling performs `argmax` to obtain the token with the
                highest likelihood.

    Ignoring greedy sampling for a moment, we find that the computed probability
    distribution has the following property: we can sample from it independently
    and find that the token sampled by the Sampler has a frequency corresponding
    to how often we see it in our sampling. In other words, for tokens sampled
    with vLLM's random SamplingType, the computed probability distribution
    encodes the sampling methodology completely.

    Greedy sampling does not normally have this property. vLLM modifies logits
    according to sampling params, then performs `argmax`, then returns the
    sampled token and the computed probability distribution. If we sample from
    the distribution, we'll find the likelihood of the greedily-sampled token
    is not always 1.0.

    Since lossless speculative decoding requires that the sampling methodology
    be encoded within the probability distribution, we are motivated to modify
    the probability distribution such that the sampled token has probability 1
    when speculative decoding is used.

    NOTE: Alternatively, we could use an extremely low temperature to achieve
    greedy sampling using multinomial computation and unite the codepaths. This
    has implications on the overall design of the sampler, e.g. how to record
    accurate logprobs for the user, so this improvement is deferred to later.
    """
    # NOTE: logprobs are not modified so they can be returned to the user.
    # Zero the selected rows first, then put all mass on the chosen token.
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0


1107
def _build_sampler_output(
    maybe_deferred_sample_results: MaybeDeferredSampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: Optional[list[Optional[PromptLogprobs]]],
    sample_logprobs: Optional[list[SampleLogprobs]],
    on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
    skip_sampler_cpu_output: bool = False,
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        maybe_deferred_sample_results: Either the per-seq-group
            (next_token_ids, parent_ids) results, or, when
            `skip_sampler_cpu_output` is True, a `SampleResultArgsType`
            holding the arguments needed to compute them later.
        sampling_metadata: The sampling metadata for the batch; its
            `seq_groups` are paired with the sample results.
        prompt_logprobs: Prompt logprobs per seq group. Must not be None
            unless `skip_sampler_cpu_output` is True.
        sample_logprobs: Sample logprobs per seq group. Must not be None
            unless `skip_sampler_cpu_output` is True.
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g. in
            speculative decoding rejection sampling.
        skip_sampler_cpu_output: If True, no per-group outputs are built;
            the deferred args are attached to the returned `SamplerOutput`
            instead.

    Returns:
        A `SamplerOutput` carrying the per-group outputs (empty when
        deferred), the on-device tensors (or None), and the deferred
        sample-result args (or None).
    """
    sampler_output: list[CompletionSequenceGroupOutput] = []

    if skip_sampler_cpu_output:
        assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
        deferred_sample_results_args = maybe_deferred_sample_results
    else:
        assert prompt_logprobs is not None
        assert sample_logprobs is not None
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        assert len(sampling_metadata.seq_groups) \
            == len(maybe_deferred_sample_results) \
            == len(prompt_logprobs) \
            == len(sample_logprobs)
        deferred_sample_results_args = None

        for (seq_group, sample_result, group_prompt_logprobs,
             group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                           maybe_deferred_sample_results,
                                           prompt_logprobs, sample_logprobs):
            seq_ids = seq_group.seq_ids
            next_token_ids, parent_ids = sample_result
            # parent_id indexes into this group's seq_ids to recover the
            # sequence each sampled token extends.
            seq_outputs: list[SequenceOutput] = [
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)
                for parent_id, next_token_id, logprobs in zip(
                    parent_ids, next_token_ids, group_sample_logprobs)
            ]
            sampler_output.append(
                CompletionSequenceGroupOutput(seq_outputs,
                                              group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors
    else:
        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                                   None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
        deferred_sample_results_args=deferred_sample_results_args)
1170
1171


1172
1173
def _get_next_prompt_tokens(
        seq_group: SequenceGroupToSample) -> tuple[int, ...]:
    """Get the next prompt tokens to compute logprob from a
        given sequence group.

    It is used to compute prompt logprob. Imagine you have logprob for each
    query token. Query token needs to know the next prompt token id to compute
    prompt logprob. This is a helper to obtain next prompt token ids.

    This API has to be used only when the caller knows seq_group is in prefill
    stage.

    Args:
        seq_group: A sequence group in the prefill stage. It must contain
            exactly one sequence and have a non-None `query_len`.

    Returns:
        The slice of the sequence's prompt token ids covering positions
        [computed_len + 1, computed_len + query_len + 1), clipped to the
        prompt length. It may be shorter than `query_len` when the query
        reaches the end of the prompt.
    """
    assert seq_group.is_prompt, (
        "Caller should ensure the sequence group is in a prefill stage.")
    seq_ids = seq_group.seq_ids
    query_len = seq_group.query_len
    assert query_len is not None
    # prompt has only 1 seq id.
    assert len(seq_ids) == 1
    seq_data = seq_group.seq_data[seq_ids[0]]
    computed_len = seq_data.get_num_computed_tokens()
    prompt_tokens = seq_data.prompt_token_ids
    # +1 because we are looking for a next prompt token.
    next_token_index_start = computed_len + 1
    # Clip to the prompt: the last query token has no next prompt token.
    next_token_index_end = min(computed_len + query_len + 1,
                               len(prompt_tokens))
    return prompt_tokens[next_token_index_start:next_token_index_end]