sampler.py 50.6 KB
Newer Older
1
"""A layer that samples the next tokens from the model's outputs."""
2
import itertools
3
4
import warnings
from importlib.util import find_spec
5
from math import inf
6
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9
10

import torch
import torch.nn as nn

11
12
13
14
15
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    from vllm.model_executor.layers.ops.sample import sample as sample_triton

16
import vllm.envs as envs
17
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
18
19
20
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
from vllm.sampling_params import SamplingType
21
22
23
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           PromptLogprobs, SampleLogprobs, SamplerOutput,
                           SequenceOutput)
Woosuk Kwon's avatar
Woosuk Kwon committed
24

25
26
27
28
29
30
31
32
33
34
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

    # yapf: enable
else:
    flashinfer_top_k_top_p_sampling = None

35
36
37
# (next_token_ids, parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]

38

Woosuk Kwon's avatar
Woosuk Kwon committed
39
class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    Given the logits for the next tokens, this layer does the following:
    1. Mask out stop tokens for sequences that have not yet generated
        `min_tokens` tokens.
    2. Apply presence, frequency and repetition penalties.
    3. Apply temperature scaling.
    4. Apply top-p and top-k truncation (skipped on the logits when the
        FlashInfer top-k/top-p sampling kernel has been imported).
    5. Apply min-p truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row in
    logits for the next token to be sampled; however, for a seq_group with a
    prompt request with the prompt_logprobs sampling parameter, there are rows
    in logits for each token in the input prompt.
    """

    def __init__(self):
        super().__init__()

        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding.
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False

        # State cached by _init_sampling_tensors(). It is defined up front so
        # that a forward() call that asks to reuse sampling tensors before any
        # have been built trips the explicit assertion in forward() instead of
        # raising an AttributeError on a missing attribute.
        self._sampling_tensors = None
        self._do_penalties = False
        self._do_top_p_top_k = False
        self._do_min_p = False

    def _init_sampling_tensors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ):
        """Build and cache the sampling tensors for the current batch.

        The goal here is to reuse sampling tensors between similar decode
        runs. This is possible because sampling logic does not change between
        decodes of the same sequences.

        Args:
            logits: (num_tokens, vocab_size). Only its shape, device and dtype
                are read here.
            sampling_metadata: Metadata for sampling.
        """
        _, vocab_size = logits.shape

        # First free any existing stored sampling tensors.
        # This is necessary because some sampling tensors may
        # have pinned memory.
        self._sampling_tensors = None

        # Initialize new sampling tensors
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        self._sampling_tensors = sampling_tensors
        self._do_penalties = do_penalties
        self._do_top_p_top_k = do_top_p_top_k
        self._do_min_p = do_min_p

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from `logits`.

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.

        Returns:
            The value produced by `_build_sampler_output` for the sampled
            tokens and the requested logprobs.
        """
        assert logits is not None

        # Prepare sampling tensors with pinned memory to avoid blocking.
        # They are rebuilt when reuse is not requested, and also whenever
        # penalties are active: in that case the sampling tensors depend on
        # the "output_tokens" of a sequence, which change between decode
        # runs, so the cached tensors cannot be reused.
        if (not sampling_metadata.reuse_sampling_tensors
                or self._do_penalties):
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence, frequency and repetition penalties.
        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        # When the FlashInfer kernel is available, top-k/top-p is not applied
        # to the logits here.
        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            prompt_logprobs, sample_logprobs = _get_logprobs(
                logprobs, sampling_metadata, sample_results)

        return _build_sampler_output(
            sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability distribution
        of greedily-sampled tokens such that multinomial sampling would sample
        the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        return self.should_modify_greedy_probs_inplace
196
197


198
def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Count how often each vocabulary id occurs in every row of `tokens`.

    Args:
        tokens: Integer tensor with `num_seqs` rows of token ids. The id
            `vocab_size` is accepted as well and is treated as padding.
        vocab_size: Number of real vocabulary entries.
        num_seqs: Number of rows in `tokens`.

    Returns:
        A pair `(bin_counts, mask)`, both of shape (num_seqs, vocab_size):
        the per-row occurrence counts (dtype long) and a boolean tensor that
        is True wherever the count is non-zero.
    """
    # One extra column collects the padding id; it is sliced away below.
    padded_shape = (num_seqs, vocab_size + 1)
    padded_counts = torch.zeros(padded_shape,
                                dtype=torch.long,
                                device=tokens.device)
    increments = torch.ones_like(tokens)
    padded_counts.scatter_add_(1, tokens, increments)

    bin_counts = padded_counts[:, :vocab_size]
    return bin_counts, bin_counts > 0
213
214


215
216
217
218
def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Forbid stop tokens for sequences that are still below `min_tokens`.

    For every sampled sequence whose number of output tokens is smaller than
    its `min_tokens` setting, the logits of all of its stop token ids are set
    to -inf (in place).

    Args:
        logits: Logits with one row per sample index / prompt-logprob index
            described by `sampling_metadata`.
        sampling_metadata: Metadata for sampling.

    Returns:
        The same `logits` tensor, after the in-place update.
    """
    # (row, token_id) coordinates in logits that will be set to -inf
    penalized_coords: List[Tuple[int, int]] = []
    rows_accounted_for = 0
    for group in sampling_metadata.seq_groups:
        seq_ids = group.seq_ids
        params = group.sampling_params

        sample_indices = group.sample_indices
        rows_accounted_for += (len(sample_indices) +
                               len(group.prompt_logprob_indices))
        if not group.do_sample:
            continue

        first_row = sample_indices[0]
        min_tokens = params.min_tokens
        stop_token_ids = params.all_stop_token_ids
        if not (min_tokens > 0 and stop_token_ids):
            continue

        # Rows (as indices into logits) of the sequences that are too short.
        short_rows: List[int] = [
            first_row + offset for offset, seq_id in enumerate(seq_ids)
            if len(group.seq_data[seq_id].output_token_ids_array) < min_tokens
        ]
        if short_rows:
            # itertools.product pairs each row with every stop token id
            penalized_coords.extend(
                itertools.product(short_rows, stop_token_ids))

    if penalized_coords:
        # use zip and * to group indices along each dimension
        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
        logits[tuple(zip(*penalized_coords))] = -float("inf")

    # verifies that no rows in logits were missed unexpectedly
    assert rows_accounted_for == logits.shape[0]
    return logits


262
263
264
265
266
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to `logits`.

    Args:
        logits: (num_seqs, vocab_size) logits.
        prompt_tokens_tensor: Prompt token ids, one row per sequence.
        output_tokens_tensor: Generated token ids, one row per sequence.
        presence_penalties: Per-sequence presence penalty
            (presumably shape (num_seqs,) -- confirm against the caller).
        frequency_penalties: Per-sequence frequency penalty (same shape).
        repetition_penalties: Per-sequence repetition penalty (same shape).

    Returns:
        The penalized logits. The penalty tensors passed in are left
        untouched.
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Expand to one penalty per (sequence, token); tokens that appear neither
    # in the prompt nor in the output get the neutral penalty 1.0.
    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    # Positive logits are divided and non-positive ones multiplied by the
    # repetition penalty.
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    # NOTE: use the out-of-place unsqueeze here. The in-place variant would
    # permanently reshape the caller's penalty tensors, so a second call with
    # the same tensors would no longer broadcast against the logits.
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


285
def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Mask logits that fall outside the per-row top-k / top-p selection.

    Args:
        logits: 2-D tensor of logits, one row per sequence.
        p: Per-row top-p threshold.
        k: Per-row top-k value.

    Returns:
        A tensor shaped like `logits`, in the original token order, where
        every entry rejected by top-k or top-p is -inf.
    """
    neg_inf = -float("inf")
    # Ascending sort: the most likely tokens end up in the last columns.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    # Column that holds the k-th largest value of each row.
    kth_column = sorted_logits.size(1) - k.to(torch.long)
    kth_value = sorted_logits.gather(1, kth_column.unsqueeze(dim=1))
    sorted_logits.masked_fill_(sorted_logits < kth_value, neg_inf)

    # Apply top-p.
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    outside_nucleus = cumulative_probs <= 1 - p.unsqueeze(dim=1)
    # Always keep at least the most likely token.
    outside_nucleus[:, -1] = False
    sorted_logits.masked_fill_(outside_nucleus, neg_inf)

    # Undo the sort: write each sorted entry back to its original column.
    restored = torch.empty_like(sorted_logits).scatter_(dim=-1,
                                                        index=sorted_idx,
                                                        src=sorted_logits)
    return restored
315
316


Roy's avatar
Roy committed
317
318
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """Mask tokens whose probability is below `min_p` times the top probability.

    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17

    Args:
        logits: 2-D tensor of logits, one row per sequence. Updated in place.
        min_p: 1-D tensor with one min-p value per row. Not modified.

    Returns:
        `logits`, with every rejected token set to -inf.
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    # Out-of-place unsqueeze: `min_p` may be a cached sampling tensor that is
    # reused by later calls. Reshaping it in place would add one more
    # dimension on every call and break broadcasting against the logits.
    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits


334
def _greedy_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    samples: torch.Tensor,
) -> SampleResultType:
    """Turn greedily selected token ids into per-group sample results.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        samples: (num_selected_samples,) A tensor of samples. The length of
            samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        One (next_token_ids, parent_ids) tuple per entry of
        selected_seq_groups, in the same order. Groups with do_sample=False
        contribute ([], []).
    """
    token_ids = samples.tolist()
    outputs: SampleResultType = []
    cursor = 0
    for group in selected_seq_groups:
        if group.do_sample:
            num_parents = len(group.seq_ids)
            assert num_parents == 1, (
                "Greedy sampling should have only one seq.")
            outputs.append(([token_ids[cursor]], list(range(num_parents))))
            cursor += num_parents
        else:
            # Nothing was sampled for this group.
            outputs.append(([], []))
    return outputs


def _random_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    random_samples: torch.Tensor,
) -> SampleResultType:
    """Turn randomly drawn token ids into per-group sample results.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        random_samples: (num_selected_samples,) A tensor of samples. The
            length of samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        One (next_token_ids, parent_ids) tuple per entry of
        selected_seq_groups, in the same order. Groups with do_sample=False
        contribute ([], []).
    """
    # Bring all samples to the CPU once, before the per-group slicing.
    cpu_samples = random_samples.cpu()
    outputs: SampleResultType = []
    row = 0
    for group in selected_seq_groups:
        if not group.do_sample:
            outputs.append(([], []))
            continue

        num_parents = len(group.seq_ids)
        if group.is_prompt:
            # Prompt phase: the single prompt row yields `best_of` tokens,
            # all of which descend from parent 0.
            best_of = group.sampling_params.best_of
            parents = [0] * best_of
            tokens = cpu_samples[row, :best_of].tolist()
        else:
            # Generation phase: one token (column 0) per running sequence.
            parents = list(range(num_parents))
            tokens = cpu_samples[row:row + num_parents, 0].tolist()
        outputs.append((tokens, parents))
        row += num_parents
    return outputs


def _beam_search_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    logprobs: torch.Tensor,
) -> SampleResultType:
    """Pick beam-search candidates for every selected sequence group.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
        on selected sample indices.
    Returns:
        One (next_token_ids, parent_ids) tuple per entry of
        selected_seq_groups, in the same order. Groups with do_sample=False
        contribute ([], []).
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    outputs: SampleResultType = []
    row = 0
    for group in selected_seq_groups:
        if not group.do_sample:
            outputs.append(([], []))
            continue

        seq_ids = group.seq_ids
        num_parents = len(seq_ids)
        num_candidates = 2 * group.sampling_params.best_of
        group_logprobs = logprobs[row:row + num_parents]
        if group.is_prompt:
            # Prompt phase: every candidate descends from the single prompt.
            assert num_parents == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * num_candidates
            next_token_ids = torch.topk(group_logprobs[0],
                                        num_candidates).indices.tolist()
        else:
            # Generation phase: rank (parent, token) pairs by the parent's
            # cumulative logprob plus the logprob of the new token.
            parent_scores: List[float] = [
                group.seq_data[seq_id].cumulative_logprob
                for seq_id in seq_ids
            ]
            parent_scores_tensor = torch.tensor(parent_scores,
                                                dtype=torch.float,
                                                device=group_logprobs.device)
            candidate_scores = (group_logprobs +
                                parent_scores_tensor.unsqueeze(dim=1))
            flat_ids = torch.topk(candidate_scores.flatten(),
                                  num_candidates).indices.tolist()
            # Decode the flat index back into (parent row, token id).
            vocab_size = candidate_scores.size(-1)
            parent_ids = [flat_id // vocab_size for flat_id in flat_ids]
            next_token_ids = [flat_id % vocab_size for flat_id in flat_ids]
        outputs.append((next_token_ids, parent_ids))
        row += num_parents
    assert row == logprobs.size(0)
    return outputs
479
480


481
482
483
484
485
486
487
488
# torch.multinomial forces a GPU<->CPU sync, so an optimized implementation
# is used instead. Sampling is always done with replacement.
# `probs` is modified in place, which is fine because callers pass in a copy
# already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
    """Draw `num_samples` token ids per row of `probs`.

    Each row is divided element-wise by exponential noise and the argmax of
    the result is taken as the sample.

    Args:
        probs: 2-D tensor of probabilities, one row per sequence.
        num_samples: Number of samples to draw for each row.
        seq_groups: When given, the noise for each group's rows is drawn from
            that group's own generator; otherwise the default generator is
            used for all rows.

    Returns:
        Tensor of sampled token ids, viewed as (-1, num_samples).
    """
    if num_samples > 1:
        probs = probs.repeat_interleave(num_samples, dim=0)

    noise = torch.empty_like(probs)
    if seq_groups is None:
        noise.exponential_()
    else:
        start = 0
        for group in seq_groups:
            num_rows = len(group.seq_ids) * num_samples
            assert group.generator is not None
            end = start + num_rows
            noise[start:end].exponential_(generator=group.generator)
            start = end
    return probs.div_(noise).argmax(dim=1).view(-1, num_samples)


508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
def _top_k_top_p_multinomial_with_flashinfer(
        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
    """Sample token ids with FlashInfer's combined top-k/top-p kernel.

    Args:
        probs: 2-D tensor of probabilities, one row per sequence.
        top_ks: Per-row top-k values.
        top_ps: Per-row top-p values.
        num_samples: Number of samples to draw for each row.
        seq_groups: When given, the uniform noise for each group's rows is
            drawn from that group's own generator; otherwise the default
            generator is used for all rows.

    Returns:
        Tensor of sampled token ids, viewed as (-1, num_samples).
    """
    max_top_k_round = 32
    if num_samples > 1:
        probs = probs.repeat_interleave(num_samples, dim=0)
        top_ks = top_ks.repeat_interleave(num_samples)
        top_ps = top_ps.repeat_interleave(num_samples)

    # One row of uniform noise per round, one column per (expanded) prob row.
    num_rows = probs.shape[0]
    uniform_samples = torch.empty((max_top_k_round, num_rows),
                                  device=probs.device)
    if seq_groups is None:
        uniform_samples.uniform_()
    else:
        start = 0
        for group in seq_groups:
            width = len(group.seq_ids) * num_samples
            assert group.generator is not None
            end = start + width
            uniform_samples[:, start:end].uniform_(generator=group.generator)
            start = end

    batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
        probs,
        uniform_samples,
        top_ks,
        top_ps,
    )
    if not success.all():
        # At least one row was not sampled successfully: renormalize the
        # probabilities under top-k and top-p explicitly and sample from the
        # result using the first row of uniform noise.
        warnings.warn("FlashInfer rejection sampling failed, fallback.",
                      stacklevel=1)
        renormed = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
        renormed = flashinfer.sampling.top_p_renorm_prob(renormed, top_ps)
        batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
            renormed, uniform_samples[0])
    return batch_next_token_ids.view(-1, num_samples)


546
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
    """Sample next tokens for every sequence group using torch ops.

    Sequence groups are bucketed by their sampling type (greedy, random,
    seeded random, beam) and each bucket is sampled in one batched operation.

    Args:
        probs: Probabilities, one row per row of the logits.
        logprobs: Log-probabilities with the same layout as `probs`.
        sampling_metadata: Metadata for sampling.
        sampling_tensors: Per-row sampling parameters; only `top_ks` and
            `top_ps` are read here, and only on the FlashInfer path.
        include_gpu_probs_tensor: If True, also return a
            (logprobs.shape[0], 1) long tensor holding the sampled token ids.
            Rows that are not written by the greedy/random branches are left
            uninitialized.
        modify_greedy_probs: If True, `_modify_greedy_probs_inplace` is
            called for the greedily sampled rows.

    Returns:
        A pair `(sample_results, sampled_token_ids_tensor)`. `sample_results`
        holds one (next_token_ids, parent_ids) tuple per sequence group, or
        is empty when `sampling_metadata.skip_sampler_cpu_output` is set.
        `sampled_token_ids_tensor` is None unless `include_gpu_probs_tensor`
        is True.
    """
    categorized_seq_group_ids: Dict[SamplingType,
                                    List[int]] = {t: []
                                                  for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata: Dict[SamplingType,
                          Tuple[List[int], List[SequenceGroupToSample]]] = {}
    multinomial_samples: Dict[SamplingType, torch.Tensor] = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        # Every branch below indexes with int64 indices.
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Number of samples drawn per row: the largest best_of among the
            # prompt-phase groups of this bucket (at least 1).
            max_best_of_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            # Only seeded sampling passes the groups on, so that their
            # per-group generators are used.
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)

            if flashinfer_top_k_top_p_sampling is not None:
                multinomial_samples[
                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
                        probs[long_sample_indices],
                        sampling_tensors.top_ks[long_sample_indices],
                        sampling_tensors.top_ps[long_sample_indices],
                        max_best_of_in_batch,
                        seq_groups_arg,
                    )
            else:
                multinomial_samples[sampling_type] = _multinomial(
                    probs[long_sample_indices],
                    max_best_of_in_batch,
                    seq_groups=seq_groups_arg)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)

        elif sampling_type == SamplingType.BEAM:
            # Use the int64 indices here too, consistent with the other
            # branches.
            beam_search_logprobs = logprobs[long_sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    if not sampling_metadata.skip_sampler_cpu_output:
        for sampling_type in SamplingType:
            if sampling_type not in sample_metadata:
                continue
            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(seq_groups, greedy_samples)
            elif sampling_type in (SamplingType.RANDOM,
                                   SamplingType.RANDOM_SEED):
                sample_results = _random_sample(
                    seq_groups, multinomial_samples[sampling_type])
            elif sampling_type == SamplingType.BEAM:
                sample_results = _beam_search_sample(seq_groups,
                                                     beam_search_logprobs)
            sample_results_dict.update(zip(seq_group_id, sample_results))

        # Restore the original sequence-group order; groups for which nothing
        # was sampled get an empty result.
        sample_results = [
            sample_results_dict.get(i, ([], []))
            for i in range(len(sampling_metadata.seq_groups))
        ]
    else:
        sample_results = []

    return sample_results, sampled_token_ids_tensor
667
668


669
670
671
672
673
def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> SampleResultType:
    """Sample the next tokens for a batch with the Triton sampling kernel.

    Args:
        probs: Probabilities handed to the Triton kernel.
        logprobs: Log-probabilities. Handed to the Triton kernel and also
            indexed by the beam-search sample indices.
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Supplies the seeds and sample indices handed to
            the Triton kernel.

    Returns:
        One (next_token_ids, parent_seq_ids) tuple per sequence group, in the
        order of `sampling_metadata.seq_groups`. Groups for which no result
        was produced get ([], []).

    Raises:
        ValueError: If a sampling type that has tokens to sample is not one
            of GREEDY, RANDOM, RANDOM_SEED or BEAM.
    """
    # Positions of the sequence groups within the batch, bucketed by
    # sampling type.
    categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_type = seq_group.sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata: Dict[SamplingType,
                          Tuple[List[int], List[SequenceGroupToSample],
                                torch.Tensor, torch.Tensor]] = {}
    max_best_of_in_batch = 1
    # Bound up front so the name always exists; it is only read in the second
    # loop for BEAM, which is recorded only after it has been assigned below.
    beam_search_logprobs: Optional[torch.Tensor] = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        if len(sample_indices) == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            # Only prompt-stage groups contribute their `best_of` to the
            # number of samples requested from the kernel.
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    max_best_of_in_batch = max(
                        max_best_of_in_batch,
                        seq_group.sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.
    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            type_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            type_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            type_results = _beam_search_sample(seq_groups,
                                               beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, type_results))

    # Re-assemble the per-type results in the original batch order.
    return [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
    """Sample the next tokens for every sequence group in a batch.

    This is a thin dispatcher; it currently always delegates to
    `_sample_with_torch`.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.
        include_gpu_probs_tensor: Forwarded unchanged to
            `_sample_with_torch`.
        modify_greedy_probs: Forwarded unchanged to `_sample_with_torch`.

    Returns:
        (next_token_ids, parent_seq_ids) for each seq group in a batch.
            If sampling is skipped, it returns ([], [])
        sampled_token_ids_tensor: A tensor of sampled token ids (may be
            None, per the return annotation).
    """
    # TODO: Enable once Triton kernel & associated code is faster.
    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
    #                                   sampling_tensors)
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        sampling_tensors,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )


785
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the ranks of the chosen tokens in a logprob tensor.

    The rank of a token is 1 plus the number of entries in its row that are
    strictly greater than the token's own value, so the largest entry has
    rank 1 and tied entries share a rank.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
                        where N is the no. of tokens and M is the vocab dim.
        indices (torch.Tensor): 1D tensor holding the chosen token index for
                        each of the N rows.

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                    Each element in the returned tensor represents the rank
                    of the chosen token in the input logprob tensor.
    """
    # Pick x[row, indices[row]] for every row in one advanced-indexing call.
    row_ids = torch.arange(0, len(x), device=x.device, dtype=indices.dtype)
    vals = x[row_ids, indices]
    # (N, M) boolean mask of entries strictly greater than the chosen value.
    result = (x > vals[:, None])
    # Drop the gathered values before reducing.
    del vals
    return result.sum(1).add_(1)
804
805


806
807
def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
    """Return sample logprobs and prompt logprobs.

    The logic consists of 3 parts.
    - Select indices to compute logprob from, ranks of token ids, and
        the top k token ids from logprobs.
    - Compute prompt logprobs if required.
    - Compute sample logprobs if required.

    Args:
        logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
            logprob per vocab. Sequence groups' query tokens are batched in a
            single flattened tensor. For example, assuming there are N
            seq groups, it is sorted by prefill tokens for seq_group_1 (if
            prompt logprob is enabled), decode tokens for seq_group_1 (if
            sampling is required), prefill tokens for seq_group_2, ...
        sampling_metadata: The sampling metadata.
        sample_results: (num_seq_groups) The tuple of (next_token_ids,
            parent_ids) for each sequence group. When beam search is enabled,
            sample_results can contain different number of seq_ids from
            sampling_metadata.seq_groups. It is because beam search creates
            2 * BEAM_WIDTH number of samples (whereas there are only up to
            BEAM_WIDTH number of seq_ids).

    Returns:
        A tuple of prompt and sample logprobs per sequence group in a batch.
        When there is nothing to compute logprobs for, every sequence group
        gets a `None` prompt logprob and an empty sample logprob list.
    """
    # The index of query token to calculate logprobs. It includes both
    # prompt and sample logprob indices.
    query_indices: List[int] = []
    # The next token ids to get the logprob value from.
    next_token_ids: List[int] = []
    # The largest requested number of logprobs. We find logprobs as many as the
    # largest num logprobs in this API. If every logprobs is None, it will be
    # set to -1.
    largest_num_logprobs = -1
    # If beam search is enabled.
    use_beam_search = False

    # Select indices to compute logprob from, ranks of token ids, and the top
    # k token ids from logprobs.
    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
                                          sample_results):
        sampling_params = seq_group.sampling_params

        # Update indices and tokens for prompt logprobs.
        if (seq_group.is_prompt
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
            query_indices.extend(seq_group.prompt_logprob_indices)
            next_token_ids.extend(next_prompt_tokens)

        # Update indices and next tokens for sample logprob.
        if seq_group.do_sample:
            token_ids, parent_seq_ids = sample_result
            # NOTE: We cannot directly use sample_indices because
            # sample_indices only contain parent seq_ids of a previous step.
            # The current step may have different number of seq_ids, and
            # we can obtain it from `sample_result[1]`.
            query_idx = seq_group.sample_indices[0]
            query_indices.extend(
                [query_idx + parent_id for parent_id in parent_seq_ids])
            next_token_ids.extend(token_ids)

            if sampling_params.logprobs is not None:
                largest_num_logprobs = max(largest_num_logprobs,
                                           sampling_params.logprobs)

            use_beam_search = use_beam_search or sampling_params.use_beam_search

        assert len(next_token_ids) == len(query_indices)

    if not query_indices:
        # Nothing to look up. Still return one entry per sequence group so
        # that callers zipping the result against
        # `sampling_metadata.seq_groups` do not silently drop groups. Each
        # group gets its own empty list to avoid aliasing.
        num_seq_groups = len(sampling_metadata.seq_groups)
        empty_prompt_logprobs: List[Optional[PromptLogprobs]] = (
            [None] * num_seq_groups)
        empty_sample_logprobs: List[SampleLogprobs] = [
            [] for _ in range(num_seq_groups)
        ]
        return empty_prompt_logprobs, empty_sample_logprobs

    selected_logprobs, ranks = None, None
    top_logprobs, top_token_ids = None, None

    # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
    # skip the whole logprob calculation.
    if largest_num_logprobs >= 0 or use_beam_search:
        query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
        next_token_ids_gpu = torch.tensor(next_token_ids,
                                          device=logprobs.device)

        # (num_selected_query_tokens,). Note that query_indices can contain
        # duplicates if beam search is enabled. The two index tensors are
        # passed as an explicit tuple (row index, column index) rather than
        # wrapped in a list.
        selected_logprobs = logprobs[query_indices_gpu, next_token_ids_gpu]
        ranks = _get_ranks(
            logprobs[query_indices_gpu],
            next_token_ids_gpu,
        )
        assert selected_logprobs.shape[0] == ranks.shape[0]

        # We need to compute top k only if there exists logprobs > 0.
        if largest_num_logprobs > 0:
            # Logprobs of topk tokens for a batch of sequence groups.
            # (num_query_tokens_across_batch, largest_num_logprobs).
            top_logprobs, top_token_ids = torch.topk(logprobs,
                                                     largest_num_logprobs,
                                                     dim=-1)
            top_logprobs = top_logprobs.to('cpu')
            top_token_ids = top_token_ids.to('cpu')

        selected_logprobs = selected_logprobs.to('cpu')
        ranks = ranks.to('cpu')

    # Find prompt/sample logprobs.
    prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
    sample_logprobs_per_seq_group: List[SampleLogprobs] = []
    # Running cursors: `top_logprob_idx` walks the rows of the top-k tensors,
    # `selected_logprobs_idx` walks `selected_logprobs` / `ranks`.
    top_logprob_idx = 0
    selected_logprobs_idx = 0

    for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                        sample_results):
        (prompt_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_prompt_logprob_if_needed(
             seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
             selected_logprobs_idx, top_logprob_idx)
        prompt_logprobs_per_seq_group.append(prompt_logprobs)

        (sampled_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_sampled_logprob_if_needed(
             seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
             top_logprobs, selected_logprobs_idx, top_logprob_idx)
        sample_logprobs_per_seq_group.append(sampled_logprobs)

    return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group


def _get_prompt_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
) -> Tuple[Optional[PromptLogprobs], int, int]:
    """Compute the prompt logprob from a sequence group if needed.

    Args:
        seq_group: The sequence group to compute prompt logprobs for.
        selected_logprobs: Logprob of each selected (query, next token) pair;
            read starting at `selected_logprobs_idx`.
        ranks: Rank of each selected token, aligned with `selected_logprobs`.
        top_token_ids: Top-k token ids per query token; only read when the
            group requests more than zero prompt logprobs.
        top_logprobs: Top-k logprobs, aligned with `top_token_ids`.
        selected_logprobs_idx: Cursor into `selected_logprobs` / `ranks`.
        top_logprob_idx: Cursor into the rows of the top-k tensors.

    Returns:
        (prompt_logprobs, top_logprob_idx, selected_logprobs_idx).
        `prompt_logprobs` is None unless the group is a prompt that requested
        prompt logprobs; both cursors are returned advanced past whatever
        this group consumed (unchanged when nothing was consumed).
    """
    sampling_params = seq_group.sampling_params

    # Nothing to do unless this is a prompt that asked for prompt logprobs.
    if not (seq_group.is_prompt
            and sampling_params.prompt_logprobs is not None):
        return None, top_logprob_idx, selected_logprobs_idx

    num_logprobs = sampling_params.prompt_logprobs
    next_prompt_tokens = _get_next_prompt_tokens(seq_group)
    num_prompt_tokens = len(next_prompt_tokens)
    selected_end = selected_logprobs_idx + num_prompt_tokens

    # Pre-select indexes and create a list. It is faster than calling .item
    # repetitively.
    selected_logprob_items = selected_logprobs[
        selected_logprobs_idx:selected_end].tolist()
    rank_items = ranks[selected_logprobs_idx:selected_end].tolist()

    prompt_logprobs: PromptLogprobs = []
    for idx, token_id in enumerate(next_prompt_tokens):
        # Calculate the prompt logprob of the real prompt tokens.
        # {token_id: (logprob, rank_from_vocab)}
        prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
            token_id: (selected_logprob_items[idx], rank_items[idx])
        }

        # Add top K prompt logprobs along with its rank.
        if num_logprobs > 0:
            top_ids = top_token_ids[top_logprob_idx, :num_logprobs].tolist()
            top_probs = top_logprobs[top_logprob_idx, :num_logprobs].tolist()
            # Top K is already sorted by rank, so we can use 1 ~
            # num_logprobs + 1 for rank.
            top_ranks = range(1, num_logprobs + 1)
            prompt_logprobs_dict.update({
                top_id: (top_prob, rank)
                for top_id, top_prob, rank in zip(top_ids, top_probs,
                                                  top_ranks)
            })

        prompt_logprobs.append({
            token_id: Logprob(*logprob_and_rank)
            for token_id, logprob_and_rank in prompt_logprobs_dict.items()
        })
        # + 1 to go to the next prompt token.
        top_logprob_idx += 1

    # + len(next_prompt_tokens) to go to the next prompt.
    selected_logprobs_idx += num_prompt_tokens
    return prompt_logprobs, top_logprob_idx, selected_logprobs_idx


def _get_sampled_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    sample_result: Tuple[List[int], List[int]],
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
) -> Tuple[SampleLogprobs, int, int]:
    """Compute the sample logprob if needed.

    Args:
        seq_group: The sequence group that was sampled.
        sample_result: (next_token_ids, parent_seq_ids) for this group.
        selected_logprobs: Logprob of each selected (query, next token) pair;
            read starting at `selected_logprobs_idx`. Not read when the group
            requested no logprobs and does not use beam search.
        ranks: Rank of each selected token, aligned with `selected_logprobs`.
        top_token_ids: Top-k token ids per query token; only read when the
            group requests more than zero logprobs.
        top_logprobs: Top-k logprobs, aligned with `top_token_ids`.
        selected_logprobs_idx: Cursor into `selected_logprobs` / `ranks`.
        top_logprob_idx: Cursor into the rows of the top-k tensors.

    Returns:
        (sampled_logprobs, top_logprob_idx, selected_logprobs_idx).
        `sampled_logprobs` holds one {token_id: Logprob} dict per sampled
        token, or is empty when the group does not sample. Both cursors are
        returned advanced past this group (unchanged when it does not sample).
    """
    sampled_logprobs: SampleLogprobs = []

    if not seq_group.do_sample:
        return sampled_logprobs, top_logprob_idx, selected_logprobs_idx

    seq_ids = seq_group.seq_ids
    num_logprobs = seq_group.sampling_params.logprobs
    use_beam_search = seq_group.sampling_params.use_beam_search
    next_token_ids, parent_seq_ids = sample_result
    assert len(next_token_ids) > 0

    if num_logprobs is None and not use_beam_search:
        for next_token_id in next_token_ids:
            # Use a dummy logprob
            sampled_logprobs.append({next_token_id: Logprob(inf)})
    else:
        selected_end = selected_logprobs_idx + len(next_token_ids)
        # Pre-select items from tensor. tolist() is faster than repetitive
        # `.item()` calls.
        selected_logprob_items = selected_logprobs[
            selected_logprobs_idx:selected_end].tolist()
        rank_items = ranks[selected_logprobs_idx:selected_end].tolist()

        for idx, (next_token_id, parent_id) in enumerate(
                zip(next_token_ids, parent_seq_ids)):
            # Get the logprob of a sampled token.
            sampled_logprobs_dict = {
                next_token_id: (selected_logprob_items[idx], rank_items[idx])
            }
            if num_logprobs is not None and num_logprobs > 0:
                # Get top K logprobs. The row is that of the parent sequence
                # the token was sampled from.
                top_row = top_logprob_idx + parent_id
                top_ids = top_token_ids[top_row, :num_logprobs].tolist()
                top_probs = top_logprobs[top_row, :num_logprobs].tolist()
                # Top K is already sorted by rank, so we can use 1 ~
                # num_logprobs + 1 for rank.
                top_ranks = range(1, num_logprobs + 1)
                sampled_logprobs_dict.update({
                    top_id: (top_prob, rank)
                    for top_id, top_prob, rank in zip(top_ids, top_probs,
                                                      top_ranks)
                })

            sampled_logprobs.append({
                token_id: Logprob(*logprob_and_rank)
                for token_id, logprob_and_rank in
                sampled_logprobs_dict.items()
            })

    # NOTE: This part of code is not intuitive. `selected_logprobs` include
    # logprobs for the current step, which has len(next_token_ids) tokens
    # per sequence group. `logprobs` includes logprobs from the previous
    # steps, which has len(seq_ids) tokens per sequence group.

    # Skip past this group's entries in `selected_logprobs` / `ranks`.
    selected_logprobs_idx += len(next_token_ids)
    # Skip past this group's rows in the top-k tensors.
    top_logprob_idx += len(seq_ids)
    return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
1076
1077


1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                 sample_indices: torch.Tensor,
                                 greedy_samples: torch.Tensor) -> None:
    """Modify the probability distributions of the greedily-sampled tokens such
    that each sampled token has a "probability" of 1.0. This is required by
    speculative decoding, which depends on the sampling method being encoded
    within the probability distribution for correctness.

    Args:
        logprobs: Not read or written by this function; see the NOTE at the
            end of the body.
        probs: Probability tensor, modified in place. Every row selected by
            `sample_indices` is zeroed and then set to 1.0 at its greedy
            token.
        sample_indices: Row indices into `probs` of the greedily-sampled
            tokens.
        greedy_samples: The greedily chosen token id for each row in
            `sample_indices`.

    # Why do we only need to do this for greedy sampling?

    vLLM's sampler performs the following steps for greedy or multinomial
    (random) sampling:
        1. Get logits from model.
        2. Modify logits according to per-sequence sampling parameters.
            - Apply temperature scaling, top-k and top-p masking, penalize
                tokens according to their frequency, etc.
        3. Sample a token.
            - Random sampling simply samples from the modified probability
                distribution.
            - Greedy sampling performs `argmax` to obtain the token with the
                highest likelihood.

    Ignoring greedy sampling for a moment, we find that the computed probability
    distribution has the following property: we can sample from it independently
    and find that the token sampled by the Sampler has a frequency corresponding
    to how often we see it in our sampling. In other words, for tokens sampled
    with vLLM's random SamplingType, the computed probability distribution
    encodes the sampling methodology completely.

    Greedy sampling does not normally have this property. vLLM modifies logits
    according to sampling params, then performs `argmax`, then returns the
    sampled token and the computed probability distribution. If we sample from
    the distribution, we'll find the likelihood of the greedily-sampled token
    is not always 1.0.

    Since lossless speculative decoding requires that the sampling methodology
    be encoded within the probability distribution, we are motivated to modify
    the probability distribution such that the sampled token has probability 1
    when speculative decoding is used.

    NOTE: Alternatively, we could use an extremely low temperature to achieve
    greedy sampling using multinomial computation and unite the codepaths. This
    has implications on the overall design of the sampler, e.g. how to record
    accurate logprobs for the user, so this improvement is deferred to later.
    """
    # NOTE: logprobs are not modified so they can be returned to the user.
    # Zero the selected rows first, then put all the mass on the greedy token.
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0


1128
def _build_sampler_output(
    sample_results: SampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
    sample_logprobs: Optional[List[SampleLogprobs]],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
    skip_sampler_cpu_output: bool = False,
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        sample_results: (next_token_ids, parent_ids) per sequence group.
        sampling_metadata: The sampling metadata; its `seq_groups` provide
            the seq ids that parent ids are resolved against.
        prompt_logprobs: Prompt logprobs per sequence group. Must not be None
            unless `skip_sampler_cpu_output` is set.
        sample_logprobs: Sample logprobs per sequence group. Must not be None
            unless `skip_sampler_cpu_output` is set.
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g. in
            speculative decoding rejection sampling. It is unpacked in the
            order (sampled_token_probs, logprobs, sampled_token_ids).
        skip_sampler_cpu_output: When True, no per-sequence-group Python
            outputs are built and `outputs` is left empty.

    Returns:
        A SamplerOutput holding the per-group outputs and the on-device
        tensors (each None when `on_device_tensors` is None).
    """
    sampler_output: List[CompletionSequenceGroupOutput] = []

    if not skip_sampler_cpu_output:
        assert prompt_logprobs is not None
        assert sample_logprobs is not None

        for (seq_group, sample_result, group_prompt_logprobs,
             group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                           sample_results, prompt_logprobs,
                                           sample_logprobs):
            seq_ids = seq_group.seq_ids
            next_token_ids, parent_ids = sample_result
            # `parent_id` is a position within this group's seq_ids, so it is
            # translated to the actual sequence id here.
            seq_outputs: List[SequenceOutput] = [
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)
                for parent_id, next_token_id, logprobs in zip(
                    parent_ids, next_token_ids, group_sample_logprobs)
            ]
            sampler_output.append(
                CompletionSequenceGroupOutput(seq_outputs,
                                              group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                               None)
    if on_device_tensors is not None:
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
    )
1180
1181


1182
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
    """Get a list of next prompt tokens to compute logprob from a
        given sequence group.

    It is used to compute prompt logprob. Imagine you have logprob for each
    query token. Query token needs to know the next prompt token id to compute
    prompt logprob. This is a helper to obtain next prompt token ids.

    This API has to be used only when the caller knows seq_group is in prefill
    stage.

    Args:
        seq_group: A sequence group in the prefill stage. It must hold exactly
            one sequence and have a non-None `query_len`.

    Returns:
        A list of next prompt tokens to compute logprob. It is the slice of
        the prompt starting one token after the already-computed prefix and
        covering at most `query_len` tokens, clamped to the prompt length.
    """
    assert seq_group.is_prompt, (
        "Caller should ensure the sequence group is in a prefill stage.")
    query_len = seq_group.query_len
    assert query_len is not None
    # prompt has only 1 seq id.
    seq_ids = seq_group.seq_ids
    assert len(seq_ids) == 1

    seq_data = seq_group.seq_data[seq_ids[0]]
    computed_len = seq_data.get_num_computed_tokens()
    prompt_tokens = seq_data.prompt_token_ids
    # +1 because we are looking for a next prompt token.
    next_token_index_start = computed_len + 1
    # Clamp to the prompt length: the last prompt token has no next prompt
    # token to look up.
    next_token_index_end = min(computed_len + query_len + 1,
                               len(prompt_tokens))
    return prompt_tokens[next_token_index_start:next_token_index_end]