sampler.py 52.9 KB
Newer Older
1
"""A layer that samples the next tokens from the model's outputs."""
2
import itertools
3
import warnings
4
from dataclasses import dataclass
5
from importlib.util import find_spec
6
from math import inf
7
from typing import Dict, Iterator, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
import msgspec
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
12
import torch
import torch.nn as nn

13
import vllm.envs as envs
14
from vllm.model_executor.layers.utils import apply_penalties
15
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
16
17
18
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
from vllm.sampling_params import SamplingType
19
20
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, Logprob,
21
                           PromptLogprobs, SampleLogprobs, SequenceOutput)
22
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
25
26
27
28
29
30
31
32
33
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

    # yapf: enable
else:
    flashinfer_top_k_top_p_sampling = None

Joe Runde's avatar
Joe Runde committed
34
35
36
37
38
39
40
41
42

def get_sampler() -> torch.nn.Module:
    """Build the sampler module matching the active engine version.

    If ``envs.VLLM_USE_V1`` is truthy, the V1 sampler class is imported
    lazily and instantiated; otherwise this module's ``Sampler`` is used.
    """
    if not envs.VLLM_USE_V1:
        return Sampler()
    # Lazy import: the v1 package isn't distributed
    from vllm.v1.sample.sampler import Sampler as V1Sampler
    return V1Sampler()


43
44
45
# Per sequence group: (next_token_ids, parent_ids), two parallel lists.
# A group with do_sample=False is represented as ([], []).
SampleResultType = List[Tuple[List[int], List[int]]]

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Types of temporary data structures used for
# computing sample_result
#
# SamplingType -> (indices into sampling_metadata.seq_groups, the matching
# seq groups), i.e. every seq group that uses that sampling type.
SampleMetadataType = Dict[SamplingType, Tuple[List[int],
                                              List[SequenceGroupToSample]]]
# SamplingType -> tensor of sampled token ids; read for the RANDOM and
# RANDOM_SEED types by `get_pythonized_sample_results`.
MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
# Seq group index -> (next_token_ids, parent_ids) for that group.
SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]


# Encapsulates temporary data structures for computing
# sample_result.
#
# * For multi-step scheduling: must be returned
#   by `Sampler.forward()` and used later to compute the pythonized
#   sample_result
#
# * For single-step scheduling: consumed immediately
#   inside `Sampler.forward()` to compute pythonized sample_result.
#
# NOTE: field order defines the generated __init__ signature; keep it stable.
@dataclass
class SampleResultArgsType:
    # Per sampling type: (seq group indices, seq groups) to pythonize.
    sample_metadata: SampleMetadataType
    # Per sampling type: sampled token ids for the multinomial (random) types.
    multinomial_samples: MultinomialSamplesType
    # Results keyed by seq group index; `get_pythonized_sample_results`
    # updates this dict in place.
    sample_results_dict: SampleResultsDictType
    # Used to know how many seq groups the final result list must cover.
    sampling_metadata: SamplingMetadata
    # Token ids for GREEDY groups (passed to `_greedy_sample`); may be None.
    greedy_samples: Optional[torch.Tensor]
    # Logprobs for BEAM groups (passed to `_beam_search_sample`); may be None.
    beam_search_logprobs: Optional[torch.Tensor]


# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
# sample result types
MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]

# Abbreviation of the _sample() return type
SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]


class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For each sequence group, we generate a list of SequenceOutput object,
    each of which contains one possible candidate for the next token.

    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors.

    NOTE: with ``array_like=True`` msgspec encodes fields positionally in
    declaration order, so fields must not be reordered.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # Holds either (1) the pythonized sampler result (single-step scheduling)
    # or (2) what will be arguments for later deferred pythonization of the
    # sampler result (multi-step scheduling)
    deferred_sample_results_args: Optional[SampleResultArgsType] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers
    model_forward_time: Optional[float] = None

    # Time taken in the model execute function. This will include model forward,
    # block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

    # Optional logits tensor.
    # NOTE(review): Sampler.forward hands over logits *after* penalties,
    # temperature and truncation, not raw lm_head output — confirm intent.
    logits: Optional[torch.Tensor] = None

    # tree-style cartesian candidates
    cart_candidates: Optional[torch.Tensor] = None

    # tree-style attention masks
    tree_attn_masks: Optional[torch.Tensor] = None

    def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int,
                    value: CompletionSequenceGroupOutput) -> None:
        self.outputs[idx] = value

    def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
        return iter(self.outputs)

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Only `outputs` participates in equality; tensor fields are ignored.
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs

    def __repr__(self) -> str:
        """Show the shape of a tensor instead of its values to reduce noise.
        """

        def _shape_repr(tensor: Optional[torch.Tensor]):
            # Logits can be (num_tokens, vocab_size); never print values.
            return "None" if tensor is None else tensor.shape

        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={_shape_repr(self.sampled_token_probs)}, "
            f"sampled_token_ids={_shape_repr(self.sampled_token_ids)}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics}, "
            f"logits={_shape_repr(self.logits)}, "
            f"tree_attn_masks={_shape_repr(self.tree_attn_masks)})")
170

171

Woosuk Kwon's avatar
Woosuk Kwon committed
172
class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    Given a logits tensor, this layer does the following:
    1. Apply the min_tokens penalty (stop tokens are masked for sequences that
       have not yet generated ``min_tokens`` output tokens).
    2. Apply presence, frequency and repetition penalties.
    3. Apply temperature scaling.
    4. Apply top-p/top-k truncation (skipped here when the FlashInfer
       top-k/top-p sampler is available; presumably applied inside that
       sampling path instead) and min-p truncation.
    5. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row in
    logits for the next token to be sampled; however, for a seq_group with a
    prompt request with the prompt_logprobs sampling parameter, there are rows
    in logits for each token in the input prompt.
    """

    def __init__(self):
        super().__init__()

        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding.
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False

        # Cached sampling tensors and flags, (re)built by
        # `_init_sampling_tensors`. Initialized here so that a first
        # `forward` call with `reuse_sampling_tensors=True` builds them
        # instead of failing with AttributeError.
        self._sampling_tensors: Optional[SamplingTensors] = None
        self._do_penalties = False
        self._do_top_p_top_k = False
        self._do_min_p = False

    def _init_sampling_tensors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ):
        """The goal here is to reuse sampling tensors between similar decode
        runs. This is possible because sampling logic does not change between
        decodes of the same sequences.
        """
        _, vocab_size = logits.shape

        # First free any existing stored sampling tensors.
        # This is necessary because some sampling tensors may
        # have pinned memory.
        self._sampling_tensors = None

        # Initialize new sampling tensors
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        self._sampling_tensors = sampling_tensors
        self._do_penalties = do_penalties
        self._do_top_p_top_k = do_top_p_top_k
        self._do_min_p = do_min_p

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Single-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Pythonize sampling result & logprobs tensor

        Multi-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Defer Pythonization of sampling result & logprobs
          tensor
        * Encapsulate arguments required for deferred Pythonization
          in the :class:`SamplerOutput` structure

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None

        # Prepare sampling tensors with pinned memory to avoid blocking.
        # Cached tensors are reused only when the caller allows it, a cache
        # exists, and no penalties are active: penalties depend on each
        # sequence's "output_tokens", which change between decode runs.
        if (not sampling_metadata.reuse_sampling_tensors
                or self._sampling_tensors is None or self._do_penalties):
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence, frequency and repetition penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization,
            # preserve GPU-side tensors in support of later
            # deferred pythonization of logprobs
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
            logits=logits)

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability distribution
        of greedily-sampled tokens such that multinomial sampling would sample
        the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        return self.should_modify_greedy_probs_inplace
351
352


353
354
355
356
def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Forbid stop tokens for sequences that are still below ``min_tokens``.

    For every sequence group sampled this step whose ``min_tokens`` is
    positive and whose ``all_stop_token_ids`` is non-empty, each sequence that
    has produced fewer than ``min_tokens`` output tokens gets every stop token
    id set to -inf in its logits row.

    ``logits`` is modified in place and returned. An assertion checks that the
    seq groups account for exactly every row of ``logits``.
    """
    row_ids: List[int] = []
    col_ids: List[int] = []
    rows_covered = 0
    for group in sampling_metadata.seq_groups:
        params = group.sampling_params
        indices = group.sample_indices
        rows_covered += len(indices) + len(group.prompt_logprob_indices)
        if not group.do_sample:
            continue

        # Rows of a sampled group are contiguous, starting at its first index.
        first_row = indices[0]
        stop_ids = params.all_stop_token_ids
        if params.min_tokens <= 0 or not stop_ids:
            continue
        for offset, seq_id in enumerate(group.seq_ids):
            produced = len(group.seq_data[seq_id].output_token_ids_array)
            if produced >= params.min_tokens:
                continue
            for token_id in stop_ids:
                row_ids.append(first_row + offset)
                col_ids.append(token_id)

    if row_ids:
        logits[row_ids, col_ids] = -float("inf")

    # verifies that no rows in logits were missed unexpectedly
    assert rows_covered == logits.shape[0]
    return logits


400
def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Mask logits outside each row's top-k set and top-p nucleus with -inf.

    Args:
        logits: (num_rows, vocab_size) logits.
        p: (num_rows,) top-p thresholds.
        k: (num_rows,) top-k values.
    Returns:
        A new tensor shaped like ``logits`` with filtered entries set to -inf.
        The highest-probability token of each row is always kept.
    """
    sorted_logits, order = logits.sort(dim=-1, descending=False)
    width = sorted_logits.size(1)

    # Top-k: in ascending order, the k-th largest value sits at width - k.
    kth_position = (width - k.to(torch.long)).unsqueeze(dim=1)
    kth_value = sorted_logits.gather(1, kth_position)
    sorted_logits.masked_fill_(sorted_logits < kth_value, -float("inf"))

    # Top-p: drop the low tail whose cumulative mass is within 1 - p.
    cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    drop = cumulative <= 1 - p.unsqueeze(dim=1)
    drop[:, -1] = False  # never drop the most likely token
    sorted_logits.masked_fill_(drop, -float("inf"))

    # Scatter the filtered values back to their original vocabulary slots.
    restored = torch.empty_like(sorted_logits)
    return restored.scatter_(dim=-1, index=order, src=sorted_logits)
427
428


Roy's avatar
Roy committed
429
430
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17

    Set to -inf every token whose probability is below ``min_p`` times the
    row's maximum probability.

    Args:
        logits: (num_rows, vocab_size) logits; modified in place.
        min_p: (num_rows,) per-row min-p thresholds; left unchanged.
    Returns:
        ``logits`` with the filtered entries set to -inf.
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    # Out-of-place unsqueeze: the in-place `unsqueeze_` reshaped the caller's
    # tensor (e.g. cached sampling tensors reused across decode steps), so a
    # second call broadcast to the wrong shape.
    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    return logits.masked_fill_(tokens_to_remove, -float("inf"))


446
def _greedy_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    samples: torch.Tensor,
) -> SampleResultType:
    """Convert greedily chosen token ids into per-group sample results.

    Args:
        selected_seq_groups: The batched sequence groups, in order.
        samples: (num_selected_samples,) greedily chosen token ids, one per
            sampled sequence. Groups with ``do_sample=False`` contribute no
            entry, so this can be shorter than ``selected_seq_groups``.
    Returns:
        One (next_token_ids, parent_ids) tuple per entry of
        ``selected_seq_groups``; ([], []) for groups with do_sample=False.
    """
    token_ids = samples.tolist()
    cursor = 0
    results: SampleResultType = []
    for group in selected_seq_groups:
        if group.do_sample:
            n_parents = len(group.seq_ids)
            assert n_parents == 1, (
                "Greedy sampling should have only one seq.")
            parents = list(range(n_parents))
            results.append(([token_ids[cursor]], parents))
            cursor += n_parents
        else:
            results.append(([], []))
    return results


def _random_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    random_samples: torch.Tensor,
) -> SampleResultType:
    """Convert multinomially sampled token ids into per-group results.

    Args:
        selected_seq_groups: The batched sequence groups, in order.
        random_samples: (num_selected_samples, max_n) sampled token ids, one
            row per sampled sequence. Groups with ``do_sample=False``
            contribute no row, so this can be shorter than
            ``selected_seq_groups``.
    Returns:
        One (next_token_ids, parent_ids) tuple per entry of
        ``selected_seq_groups``; ([], []) for groups with do_sample=False.
    """
    # Move the samples to the CPU once; the per-group slicing is host-side.
    samples_cpu = random_samples.cpu()
    row = 0
    results: SampleResultType = []
    for group in selected_seq_groups:
        if not group.do_sample:
            results.append(([], []))
            continue

        n_parents = len(group.seq_ids)
        n = group.sampling_params.n
        if group.is_prompt:
            # Prompt: the single row holds `n` samples, all from parent 0.
            parents = [0] * n
            token_ids = samples_cpu[row, :n].tolist()
        else:
            # Decode: column 0 of each parent's row is its next token.
            parents = list(range(n_parents))
            token_ids = samples_cpu[row:row + n_parents, 0].tolist()
        results.append((token_ids, parents))
        row += n_parents
    return results


def _beam_search_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    logprobs: torch.Tensor,
) -> SampleResultType:
    """Pick beam-search candidates for each sequence group.

    Args:
        selected_seq_groups: The batched sequence groups, in order.
        logprobs: (num_selected_samples, vocab_size) logprobs, one row per
            sampled sequence.
    Returns:
        One (next_token_ids, parent_ids) tuple per entry of
        ``selected_seq_groups``; ([], []) for groups with do_sample=False.
        Sampled groups yield ``2 * n`` candidates.
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    row = 0
    results: SampleResultType = []
    for group in selected_seq_groups:
        if not group.do_sample:
            results.append(([], []))
            continue

        seq_ids = group.seq_ids
        n_parents = len(seq_ids)
        num_candidates = 2 * group.sampling_params.n
        group_logprobs = logprobs[row:row + n_parents]

        if group.is_prompt:
            # Prompt: expand the single sequence into the top candidates.
            assert n_parents == 1, (
                "Prompt input should have only one seq.")
            parents = [0] * num_candidates
            token_ids = torch.topk(group_logprobs[0],
                                   num_candidates)[1].tolist()
        else:
            # Decode: rank (parent, token) pairs by cumulative logprob.
            running = torch.tensor(
                [group.seq_data[sid].cumulative_logprob for sid in seq_ids],
                dtype=torch.float,
                device=group_logprobs.device)
            scores = group_logprobs + running.unsqueeze(dim=1)
            flat_ids = torch.topk(scores.flatten(),
                                  num_candidates)[1].tolist()
            vocab_size = scores.size(-1)
            parents = []
            token_ids = []
            for flat_id in flat_ids:
                parent, token = divmod(flat_id, vocab_size)
                parents.append(parent)
                token_ids.append(token)
        results.append((token_ids, parents))
        row += n_parents
    assert row == logprobs.size(0)
    return results
591
592


593
594
595
596
597
598
599
600
# Sampling helper that avoids torch.multinomial, which forces a GPU<->CPU
# sync. It draws exponential noise and takes argmax(probs / noise) per row,
# which samples proportionally to probs, always with replacement.
# `probs` may be modified in place (when num_samples == 1 no copy is made);
# callers are expected to pass a tensor they no longer need.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
    """Draw ``num_samples`` token ids per row of ``probs``.

    When ``seq_groups`` is given, each group's rows use that group's own
    ``generator`` for the noise, so seeded requests are reproducible.

    Returns:
        (num_rows, num_samples) tensor of sampled token ids.
    """
    if num_samples > 1:
        probs = probs.repeat_interleave(num_samples, dim=0)
    noise = torch.empty_like(probs)
    if seq_groups is not None:
        offset = 0
        for group in seq_groups:
            span = num_samples * len(group.seq_ids)
            assert group.generator is not None
            noise[offset:offset + span].exponential_(
                generator=group.generator)
            offset += span
    else:
        noise.exponential_()
    probs.div_(noise)
    return probs.argmax(dim=1).view(-1, num_samples)


620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
def _top_k_top_p_multinomial_with_flashinfer(
        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
    """Sample token ids with top-k/top-p filtering through FlashInfer.

    Rows are repeated ``num_samples`` times. When ``seq_groups`` is given,
    each group's columns of uniform noise come from that group's own
    ``generator``. If the fused FlashInfer sampler reports failure for any
    row, a warning is emitted and the code falls back to explicit top-k and
    top-p renormalization followed by plain sampling.

    Returns:
        Sampled token ids viewed as (-1, num_samples).
    """
    # Rows of uniform noise handed to the FlashInfer kernel (presumably one
    # per rejection-sampling round — confirm against the FlashInfer API).
    max_top_k_round = 32
    if num_samples > 1:
        probs, top_ks, top_ps = (
            probs.repeat_interleave(num_samples, dim=0),
            top_ks.repeat_interleave(num_samples),
            top_ps.repeat_interleave(num_samples),
        )
    uniform_samples = torch.empty((max_top_k_round, probs.shape[0]),
                                  device=probs.device)
    if seq_groups is not None:
        col = 0
        for group in seq_groups:
            width = num_samples * len(group.seq_ids)
            assert group.generator is not None
            uniform_samples[:, col:col + width].uniform_(
                generator=group.generator)
            col += width
    else:
        uniform_samples.uniform_()

    next_token_ids, success = flashinfer_top_k_top_p_sampling(
        probs,
        uniform_samples,
        top_ks,
        top_ps,
    )
    if not success.all():
        warnings.warn("FlashInfer rejection sampling failed, fallback.",
                      stacklevel=1)
        renormed = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
        renormed = flashinfer.sampling.top_p_renorm_prob(renormed, top_ps)
        next_token_ids = flashinfer.sampling.sampling_from_probs(
            renormed, uniform_samples[0])
    return next_token_ids.view(-1, num_samples)


658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
def get_pythonized_sample_results(
        sample_result_args: SampleResultArgsType) -> SampleResultType:
    """Turn GPU-side sampler outputs into Python lists (GPU -> CPU sync).

    Single-step scheduling calls this right at sampling time; multi-step
    scheduling defers the call until several GPU-side steps have completed.

    Args:
      sample_result_args: GPU-side inputs to the Pythonization process.
        Its ``sample_results_dict`` is updated in place.

    Returns:
      One (next_token_ids, parent_ids) tuple per seq group in
      ``sampling_metadata.seq_groups``; ([], []) for groups without a result.
    """
    args = sample_result_args
    results_by_group = args.sample_results_dict
    pending = args.sample_metadata

    for sampling_type in SamplingType:
        if sampling_type not in pending:
            continue
        group_ids, seq_groups = pending[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, args.greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, args.multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 args.beam_search_logprobs)
        results_by_group.update(zip(group_ids, sample_results))

    num_groups = len(args.sampling_metadata.seq_groups)
    return [results_by_group.get(i, ([], [])) for i in range(num_groups)]


def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    '''Torch-oriented _sample() implementation.

    Single-step scheduling:
    * Perform GPU-side sampling computation
    * Immediately Pythonize sampling result

    Multi-step scheduling:
    * Perform GPU-side sampling computation
    * Defer Pythonization & preserve GPU-side
      tensors required for Pythonization

    Args:
        probs: Per-row probabilities; rows are selected through
            ``sampling_metadata.categorized_sample_indices``.
        logprobs: Per-row log-probabilities, indexed like ``probs``.
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Per-row sampling tensors; ``top_ks`` / ``top_ps``
            are read when the FlashInfer sampler is available.
        include_gpu_probs_tensor: If True, also build a ``(num_rows, 1)``
            long tensor of sampled token ids on ``logprobs.device``. Rows
            that are not sampled keep ``VLLM_INVALID_TOKEN_ID``.
        modify_greedy_probs: If True, rewrite ``probs`` in place so that
            greedily sampled rows are one-hot on the argmax token.

    Returns:
        ``(results, sampled_token_ids_tensor)``. ``results`` holds the
        Pythonized sample results. When
        ``sampling_metadata.skip_sampler_cpu_output`` is set, it instead holds
        the ``SampleResultArgsType`` needed to compute them later.
        ``sampled_token_ids_tensor`` is None unless
        ``include_gpu_probs_tensor`` is True.
    '''
    # Bucket seq group indices by their sampling type.
    categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_type = seq_group.sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: SampleResultsDictType = {}
    sample_metadata: SampleMetadataType = {}
    multinomial_samples: MultinomialSamplesType = {}
    greedy_samples: Optional[torch.Tensor] = None
    beam_search_logprobs: Optional[torch.Tensor] = None

    # Create output tensor for sampled token ids. Rows that no sampling type
    # writes to keep VLLM_INVALID_TOKEN_ID.
    sampled_token_ids_tensor: Optional[torch.Tensor] = None
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
                                              VLLM_INVALID_TOKEN_ID,
                                              dtype=torch.long,
                                              device=logprobs.device)

    # Counterintuitively, splitting the work into two loops is faster. This
    # loop only launches GPU-side sampling and never waits on a GPU<->CPU
    # sync; the Pythonization loop in get_pythonized_sample_results does.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type]
        if len(sample_indices) == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()

        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Only prompt seq groups contribute their sampling_params.n.
            max_n_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    max_n_in_batch = max(max_n_in_batch,
                                         seq_group.sampling_params.n)

            # Only RANDOM_SEED hands the seq groups to the sampling helpers;
            # plain RANDOM passes None.
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)

            if flashinfer_top_k_top_p_sampling is not None:
                multinomial_samples[
                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
                        probs[long_sample_indices],
                        sampling_tensors.top_ks[long_sample_indices],
                        sampling_tensors.top_ps[long_sample_indices],
                        max_n_in_batch,
                        seq_groups_arg,
                    )
            else:
                multinomial_samples[sampling_type] = _multinomial(
                    probs[long_sample_indices],
                    max_n_in_batch,
                    seq_groups=seq_groups_arg)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)

        elif sampling_type == SamplingType.BEAM:
            # Use the same long indices as the other sampling paths.
            beam_search_logprobs = logprobs[long_sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # Encapsulate arguments for computing Pythonized sampler
    # results, whether deferred or otherwise.
    maybe_deferred_args = SampleResultArgsType(
        sampling_metadata=sampling_metadata,
        sample_metadata=sample_metadata,
        multinomial_samples=multinomial_samples,
        greedy_samples=greedy_samples,
        beam_search_logprobs=beam_search_logprobs,
        sample_results_dict=sample_results_dict)

    if not sampling_metadata.skip_sampler_cpu_output:
        # GPU<->CPU sync happens here.
        # This also converts the sampler output to a Python object.
        # Return Pythonized sampler result & sampled token ids
        return get_pythonized_sample_results(
            maybe_deferred_args), sampled_token_ids_tensor

    # Defer sampler result Pythonization; return deferred
    # Pythonization args & sampled token ids
    return (
        maybe_deferred_args,
        sampled_token_ids_tensor,
    )


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    """Sample next tokens for every sequence group in the batch.

    Currently a thin wrapper that forwards to ``_sample_with_torch``.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.
        include_gpu_probs_tensor: Whether to also return a GPU tensor of
            sampled token ids.
        modify_greedy_probs: Whether to rewrite ``probs`` in place so that
            greedily sampled rows put all probability on the argmax token.

    Returns:
        A tuple ``(results, sampled_token_ids_tensor)``:
            results: (next_token_ids, parent_seq_ids) for each seq group in a
                batch; ``([], [])`` for groups that were not sampled. When
                ``sampling_metadata.skip_sampler_cpu_output`` is set, this
                is instead the deferred arguments needed to compute them.
            sampled_token_ids_tensor: A tensor of sampled token ids, or None
                when ``include_gpu_probs_tensor`` is False.
    """
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        sampling_tensors,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )


875
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
876
877
878
879
880
881
    """
    This function calculates the ranks of the chosen tokens in a logprob tensor.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
                        where N is the no. of tokens and M is the vocab dim.
882
        indices (torch.Tensor): List of chosen token indices.
883
884
885

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
886
                    Each element in the returned tensor represents the rank
887
888
                    of the chosen token in the input logprob tensor.
    """
889
890
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
891
892
893
    result = (x > vals[:, None])
    del vals
    return result.sum(1).add_(1)
894
895


def get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
    """Return sample logprobs and prompt logprobs.

    The logic consists of 3 parts.
    - Select indices to compute logprob from, ranks of token ids, and
        the top k token ids from logprobs.
    - Compute prompt logprobs if required.
    - Compute sample logprobs if required.

    Args:
        logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
            logprob per vocab. Sequence groups' query tokens are batched in a
            single flattened tensor. For example, assuming there are N
            seq groups, it is sorted by prefill tokens for seq_group_1 (if
            prompt logprob is enabled), decode tokens for seq_group_1 (if
            sampling is required), prefill tokens for seq_group_2, ...
        sampling_metadata: The sampling metadata.
        sample_results: (num_seq_groups) The tuple of (next_token_ids,
            parent_ids) for each sequence group. When beam search is enabled,
            sample_results can contain different number of seq_ids from
            sampling_metadata.seq_groups. It is because beam search creates
            2 * BEAM_WIDTH number of samples (whereas there are only up to
            BEAM_WIDTH number of seq_ids).

    Returns:
        A tuple of prompt and sample logprobs per sequence group in a batch.
        When no query token needs a logprob, every sequence group in
        ``sampling_metadata.seq_groups`` gets ``None`` prompt logprobs and an
        empty list of sample logprobs.
    """
    # The index of query token to calculate logprobs. It includes both
    # prompt and sample logprob indices.
    query_indices: List[int] = []
    # The next token ids to get the logprob value from.
    next_token_ids: List[int] = []
    # The largest requested number of logprobs. We find as many logprobs as
    # the largest num logprobs in this API. If every logprobs is None, it
    # stays -1.
    largest_num_logprobs = -1

    # Select indices to compute logprob from, ranks of token ids, and the top
    # k token ids from logprobs.
    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
                                          sample_results):
        sampling_params = seq_group.sampling_params

        # Update indices and tokens for prompt logprobs.
        if (seq_group.is_prompt
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
            query_indices.extend(seq_group.prompt_logprob_indices)
            next_token_ids.extend(next_prompt_tokens)

        # Update indices and next tokens for sample logprob.
        if seq_group.do_sample:
            token_ids, parent_seq_ids = sample_result
            # NOTE: We cannot directly use sample_indices because
            # sample_indices only contain parent seq_ids of a previous step.
            # The current step may have different number of seq_ids, and
            # we can obtain it from `sample_result[1]`.
            query_idx = seq_group.sample_indices[0]
            query_indices.extend(
                [query_idx + parent_id for parent_id in parent_seq_ids])
            next_token_ids.extend(token_ids)

            if sampling_params.logprobs is not None:
                largest_num_logprobs = max(largest_num_logprobs,
                                           sampling_params.logprobs)

        assert len(next_token_ids) == len(query_indices)

    if len(query_indices) == 0:
        # Nothing to compute, but still return one entry per seq group:
        # callers zip these lists with `seq_groups`, and a single-element
        # list would silently drop every group after the first.
        num_seq_groups = len(sampling_metadata.seq_groups)
        empty_prompt_logprobs: List[Optional[PromptLogprobs]] = [
            None
        ] * num_seq_groups
        # Distinct list objects so no two groups alias the same list.
        empty_sample_logprobs: List[SampleLogprobs] = [
            [] for _ in range(num_seq_groups)
        ]
        return empty_prompt_logprobs, empty_sample_logprobs

    selected_logprobs, ranks = None, None
    top_logprobs, top_token_ids = None, None

    # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
    # skip the whole logprob calculation.
    if largest_num_logprobs >= 0:
        query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
        next_token_ids_gpu = torch.tensor(next_token_ids,
                                          device=logprobs.device)

        # (num_selected_query_tokens,). Note that query_indices can contain
        # duplicates if beam search is enabled.
        selected_logprobs = logprobs[query_indices_gpu, next_token_ids_gpu]
        ranks = _get_ranks(
            logprobs[query_indices_gpu],
            next_token_ids_gpu,
        )
        assert selected_logprobs.shape[0] == ranks.shape[0]

        # We need to compute top k only if there exists logprobs > 0.
        if largest_num_logprobs > 0:
            # Logprobs of topk tokens for every row of `logprobs`.
            # (num_query_tokens_across_batch, largest_num_logprobs).
            top_logprobs, top_token_ids = torch.topk(logprobs,
                                                     largest_num_logprobs,
                                                     dim=-1)
            top_logprobs = top_logprobs.to('cpu')
            top_token_ids = top_token_ids.to('cpu')

        selected_logprobs = selected_logprobs.to('cpu')
        ranks = ranks.to('cpu')

    # Find prompt/sample logprobs.
    prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
    sample_logprobs_per_seq_group: List[SampleLogprobs] = []
    top_logprob_idx = 0
    selected_logprobs_idx = 0

    for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                        sample_results):
        (prompt_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_prompt_logprob_if_needed(
             seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
             selected_logprobs_idx, top_logprob_idx)
        prompt_logprobs_per_seq_group.append(prompt_logprobs)

        (sampled_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_sampled_logprob_if_needed(
             seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
             top_logprobs, selected_logprobs_idx, top_logprob_idx)
        sample_logprobs_per_seq_group.append(sampled_logprobs)

    return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group


def _get_prompt_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
) -> Tuple[Optional[PromptLogprobs], int, int]:
    """Compute the prompt logprob from a sequence group if needed.

    Args:
        seq_group: The sequence group to process.
        selected_logprobs: Logprob of each query token's next token, in the
            order built by ``get_logprobs``.
        ranks: Rank of each entry of ``selected_logprobs``.
        top_token_ids: Top-k token ids per logprobs row (None when
            ``get_logprobs`` computed no top-k).
        top_logprobs: Top-k logprobs matching ``top_token_ids``.
        selected_logprobs_idx: Cursor into ``selected_logprobs`` / ``ranks``.
        top_logprob_idx: Cursor into the rows of ``top_token_ids`` /
            ``top_logprobs``.

    Returns:
        ``(prompt_logprobs, top_logprob_idx, selected_logprobs_idx)``.
        ``prompt_logprobs`` is None unless the group is a prompt that
        requested prompt logprobs. In that case both cursors are advanced by
        the number of next prompt tokens consumed; otherwise they are
        returned unchanged.
    """
    num_logprobs = seq_group.sampling_params.prompt_logprobs
    if not seq_group.is_prompt or num_logprobs is None:
        return None, top_logprob_idx, selected_logprobs_idx

    next_prompt_tokens = _get_next_prompt_tokens(seq_group)
    num_tokens = len(next_prompt_tokens)
    selected_end = selected_logprobs_idx + num_tokens

    # Pre-select rows and convert each tensor with a single .tolist(). This
    # is faster than calling .item() repetitively.
    selected_logprob_items = selected_logprobs[
        selected_logprobs_idx:selected_end].tolist()
    rank_items = ranks[selected_logprobs_idx:selected_end].tolist()
    top_id_rows: List[List[int]] = []
    top_prob_rows: List[List[float]] = []
    if num_logprobs > 0:
        # One top-k row per next prompt token, starting at top_logprob_idx.
        top_rows = slice(top_logprob_idx, top_logprob_idx + num_tokens)
        top_id_rows = top_token_ids[top_rows, :num_logprobs].tolist()
        top_prob_rows = top_logprobs[top_rows, :num_logprobs].tolist()

    prompt_logprobs: PromptLogprobs = []
    for idx, token_id in enumerate(next_prompt_tokens):
        # Calculate the prompt logprob of the real prompt tokens.
        # {token_id: (logprob, rank_from_vocab)}
        prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
            token_id: (selected_logprob_items[idx], rank_items[idx])
        }

        # Add top K prompt logprobs along with its rank. Top K is already
        # sorted by rank, so ranks run from 1 to num_logprobs.
        if num_logprobs > 0:
            prompt_logprobs_dict.update({
                top_id: (top_prob, rank)
                for rank, (top_id, top_prob) in enumerate(
                    zip(top_id_rows[idx], top_prob_rows[idx]), start=1)
            })

        prompt_logprobs.append({
            tok_id: Logprob(*logprob_and_rank)
            for tok_id, logprob_and_rank in prompt_logprobs_dict.items()
        })

    # Each next prompt token consumed one top-k row and one selected entry.
    return prompt_logprobs, top_logprob_idx + num_tokens, selected_end


def _get_sampled_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    sample_result: Tuple[List[int], List[int]],
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
) -> Tuple[SampleLogprobs, int, int]:
    """Compute the sample logprob if needed.

    Args:
        seq_group: The sequence group to process.
        sample_result: ``(next_token_ids, parent_seq_ids)`` for this group.
        selected_logprobs: Logprob of each query token's next token, in the
            order built by ``get_logprobs``.
        ranks: Rank of each entry of ``selected_logprobs``.
        top_token_ids: Top-k token ids per logprobs row (None when
            ``get_logprobs`` computed no top-k).
        top_logprobs: Top-k logprobs matching ``top_token_ids``.
        selected_logprobs_idx: Cursor into ``selected_logprobs`` / ``ranks``.
        top_logprob_idx: Cursor into the rows of ``top_token_ids`` /
            ``top_logprobs``.

    Returns:
        ``(sampled_logprobs, top_logprob_idx, selected_logprobs_idx)``.
        ``sampled_logprobs`` is empty when the group does not sample. When no
        logprobs were requested, each sampled token gets a dummy
        ``Logprob(inf)``.
    """
    seq_ids = seq_group.seq_ids
    num_logprobs = seq_group.sampling_params.logprobs
    sampled_logprobs: SampleLogprobs = []
    next_token_ids, parent_seq_ids = sample_result

    if not seq_group.do_sample:
        return sampled_logprobs, top_logprob_idx, selected_logprobs_idx

    assert len(next_token_ids) > 0
    if num_logprobs is None:
        for next_token_id in next_token_ids:
            # Use a dummy logprob
            sampled_logprobs.append({next_token_id: Logprob(inf)})
    else:
        # Pre-select items from tensor. tolist() is faster than repetitive
        # `.item()` calls.
        selected_end = selected_logprobs_idx + len(next_token_ids)
        selected_logprob_items = selected_logprobs[
            selected_logprobs_idx:selected_end].tolist()
        rank_items = ranks[selected_logprobs_idx:selected_end].tolist()
        for idx, (next_token_id, parent_id) in enumerate(
                zip(next_token_ids, parent_seq_ids)):
            # Get the logprob of a sampled token.
            sampled_logprobs_dict = {
                next_token_id: (selected_logprob_items[idx], rank_items[idx])
            }
            if num_logprobs > 0:
                # Get top K logprobs from the parent sequence's row. Top K is
                # already sorted by rank, so ranks run from 1 to num_logprobs.
                row = top_logprob_idx + parent_id
                top_ids = top_token_ids[row, :num_logprobs].tolist()
                top_probs = top_logprobs[row, :num_logprobs].tolist()
                sampled_logprobs_dict.update({
                    top_id: (top_prob, rank)
                    for rank, (top_id, top_prob) in enumerate(
                        zip(top_ids, top_probs), start=1)
                })

            sampled_logprobs.append({
                token_id: Logprob(*logprob_and_rank)
                for token_id, logprob_and_rank in
                sampled_logprobs_dict.items()
            })

    # NOTE: This part of code is not intuitive. `selected_logprobs` include
    # logprobs for the current step, which has len(next_token_ids) tokens
    # per sequence group. `logprobs` includes logprobs from the previous
    # steps, which has len(seq_ids) tokens per sequence group.

    # Iterate to the next sequence group in a batch.
    selected_logprobs_idx += len(next_token_ids)
    top_logprob_idx += len(seq_ids)
    return sampled_logprobs, top_logprob_idx, selected_logprobs_idx


1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                 sample_indices: torch.Tensor,
                                 greedy_samples: torch.Tensor) -> None:
    """Modify the probability distributions of the greedily-sampled tokens such
    that each sampled token has a "probability" of 1.0. This is required by
    speculative decoding, which depends on the sampling method being encoded
    within the probability distribution for correctness.

    # Why do we only need to do this for greedy sampling?

    vLLM's sampler performs the following steps for greedy or multinomial
    (random) sampling:
        1. Get logits from model.
        2. Modify logits according to per-sequence sampling parameters.
            - Multiply by temperature, top-k and top-p masking, penalize tokens
                according to their frequency, etc.
        3. Sample a token.
            - Random sampling simply samples from the modified probability
                distribution.
            - Greedy sampling performs `argmax` to obtain the token with the
                highest likelihood.
1184

1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
    Ignoring greedy sampling for a moment, we find that the computed probability
    distribution has the following property: we can sample from it independently
    and find that the token sampled by the Sampler has a frequency corresponding
    to how often we see it in our sampling. In other words, for tokens sampled
    with vLLM's random SamplingType, the computed probability distribution
    encodes the sampling methodology completely.

    Greedy sampling does not normally have this property. vLLM modifies logits
    according to sampling params, then performs `argmax`, then returns the
    sampled token and the computed probability distribution. If we sample from
    the distribution, we'll find the likelihood of the greedily-sampled token
    is not always 1.0.

    Since lossless speculative decoding requires that the sampling methodology
    be encoded within the probability distribution, we are motivated to modify
    the probability distribution such that the sampled token has probability 1
    when speculative decoding is used.

    NOTE: Alternatively, we could use an extremely low temperature to achieve
    greedy sampling using multinomial computation and unite the codepaths. This
    has implications on the overall design of the sampler, e.g. how to record
    accurate logprobs for the user, so this improvement is deferred to later.
    """
1208
    # NOTE: logprobs are not modified so they can be returned to the user.
1209
1210
1211
1212
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0


def _build_sampler_output(
    maybe_deferred_sample_results: MaybeDeferredSampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
    sample_logprobs: Optional[List[SampleLogprobs]],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
    skip_sampler_cpu_output: bool = False,
    logits: Optional[torch.Tensor] = None,
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        maybe_deferred_sample_results: Pythonized per-group
            ``(next_token_ids, parent_ids)`` results. When
            ``skip_sampler_cpu_output`` is True, this is instead the
            ``SampleResultArgsType`` used to Pythonize them later.
        sampling_metadata: The metadata for a batch for sampling.
        prompt_logprobs: Per-group prompt logprobs; required unless
            ``skip_sampler_cpu_output`` is True.
        sample_logprobs: Per-group sample logprobs; required unless
            ``skip_sampler_cpu_output`` is True.
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g. in
            speculative decoding rejection sampling.
        skip_sampler_cpu_output: If True, no per-group Python outputs are
            built and the deferred arguments are stored on the result.
        logits: Optional logits stored on the result as-is.

    Returns:
        A ``SamplerOutput``. Its ``outputs`` list is empty when
        ``skip_sampler_cpu_output`` is True.
    """
    sampler_output: List[CompletionSequenceGroupOutput] = []

    if skip_sampler_cpu_output:
        assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
        deferred_sample_results_args = maybe_deferred_sample_results
    else:
        assert prompt_logprobs is not None
        assert sample_logprobs is not None
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        deferred_sample_results_args = None

        for (seq_group, sample_result, group_prompt_logprobs,
             group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                           maybe_deferred_sample_results,
                                           prompt_logprobs, sample_logprobs):
            seq_ids = seq_group.seq_ids
            next_token_ids, parent_ids = sample_result
            # Map each sample back to the seq id of the parent it came from.
            seq_outputs: List[SequenceOutput] = [
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)
                for parent_id, next_token_id, logprobs in zip(
                    parent_ids, next_token_ids, group_sample_logprobs)
            ]
            sampler_output.append(
                CompletionSequenceGroupOutput(seq_outputs,
                                              group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors
    else:
        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                                   None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
        deferred_sample_results_args=deferred_sample_results_args,
        logits=logits)


1276
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
    """Get a list of next prompt tokens to compute logprob from a
        given sequence group.

    It is used to compute prompt logprob. Imagine you have logprob for each
    query token. Query token needs to know the next prompt token id to compute
    prompt logprob. This is a helper to obtain next prompt token ids.

    This API has to be used only when the caller knows seq_group is in prefill
    stage.

    Returns:
        A list of next prompt tokens to compute logprob.
    """
    assert seq_group.is_prompt, (
        "Caller should ensure the sequence group is in a prefill stage.")
    seq_ids = seq_group.seq_ids
1293
1294
    query_len = seq_group.query_len
    assert query_len is not None
1295
1296
1297
1298
1299
1300
1301
    # prompt has only 1 seq id.
    assert len(seq_ids) == 1
    seq_data = seq_group.seq_data[seq_ids[0]]
    computed_len = seq_data.get_num_computed_tokens()
    prompt_tokens = seq_data.prompt_token_ids
    # +1 because we are looking for a next prompt token.
    next_token_index_start = computed_len + 1
1302
    next_token_index_end = min(computed_len + query_len + 1,
1303
1304
1305
1306
                               len(prompt_tokens))
    next_prompt_tokens = prompt_tokens[
        next_token_index_start:next_token_index_end]
    return next_prompt_tokens