sampler.py 58.2 KB
Newer Older
1
"""A layer that samples the next tokens from the model's outputs."""
2
import itertools
3
import warnings
4
from dataclasses import dataclass
5
from importlib.util import find_spec
6
from math import inf
7
from typing import Dict, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
import msgspec
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
12
import torch
import torch.nn as nn

13
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
14
15
16
17
18
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    from vllm.model_executor.layers.ops.sample import sample as sample_triton

19
import vllm.envs as envs
20
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
21
22
23
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
from vllm.sampling_params import SamplingType
24
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
25
                           PromptLogprobs, SampleLogprobs, SequenceOutput)
Woosuk Kwon's avatar
Woosuk Kwon committed
26

27
28
29
30
31
32
33
34
35
36
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

    # yapf: enable
else:
    flashinfer_top_k_top_p_sampling = None

37
38
39
# (next_token_ids, parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Types of temporary data structures used for
# computing sample_result.
#
# * SampleMetadataType: for each sampling type, the indices (into
#   sampling_metadata.seq_groups) of the sequence groups that use it,
#   paired with those sequence groups.
# * MultinomialSamplesType: for each sampling type, the tensor of
#   sampled token ids that is later handed to `_random_sample()`.
# * SampleResultsDictType: sequence-group index ->
#   (next_token_ids, parent_ids).
SampleMetadataType = Dict[SamplingType, Tuple[List[int],
                                              List[SequenceGroupToSample]]]
MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]


# Encapsulates temporary data structures for computing
# sample_result.
#
# * For multi-step scheduling: must be returned
#   by `Sampler.forward()` and used later to compute the pythonized
#   sample_result
#
# * For single-step scheduling: consumed immediately
#   inside `Sampler.forward()` to compute pythonized sample_result.
@dataclass
class SampleResultArgsType:
    """Bundle of intermediate sampler state used to build sample_result.

    `get_pythonized_sample_results()` reads every field of this object to
    produce the Pythonized (CPU-side) result for each sequence group.
    """
    # Per sampling type: (sequence-group indices, sequence groups).
    sample_metadata: SampleMetadataType
    # Per sampling type: samples tensor consumed by `_random_sample()`.
    multinomial_samples: MultinomialSamplesType
    # Sequence-group index -> (next_token_ids, parent_ids); filled in place
    # while the results are being Pythonized.
    sample_results_dict: SampleResultsDictType
    # Sampling metadata of the batch; the number of its seq_groups fixes
    # the length of the Pythonized result list.
    sampling_metadata: SamplingMetadata
    # Samples tensor consumed by `_greedy_sample()`, if any.
    greedy_samples: Optional[torch.Tensor]
    # Logprobs tensor consumed by `_beam_search_sample()`, if any.
    beam_search_logprobs: Optional[torch.Tensor]

# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
# sample result types: either the already Pythonized per-group results or
# the arguments needed to Pythonize them later.
MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]

# Abbreviation of the _sample() return type: the (maybe deferred) sample
# results plus an optional tensor of sampled token ids.
SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]


class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Sampler results for a batch, one entry per sequence group.

    Every entry of `outputs` holds the candidate next tokens generated for
    one sequence group. The object forwards indexing, item assignment and
    `len()` to `outputs`, so it can stand in for a plain list, while also
    carrying optional device tensors and timing information.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # Holds either (1) the pythonized sampler result (single-step scheduling)
    # or (2) what will be arguments for later deferred pythonization of the
    # sampler result (multi-step scheduling)
    deferred_sample_results_args: Optional[SampleResultArgsType] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers
    model_forward_time: Optional[float] = None

    # Time taken in the model execute function. This will include model forward,
    # block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

    def __getitem__(self, idx: int):
        """Return the output of the sequence group at position `idx`."""
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        """Replace the output of the sequence group at position `idx`."""
        self.outputs[idx] = value

    def __len__(self):
        """Number of sequence-group outputs held."""
        return len(self.outputs)

    def __eq__(self, other: object):
        """Two outputs are equal when `outputs` match; tensors are ignored."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Render tensors by shape only, so the repr stays short."""
        probs = self.sampled_token_probs
        token_ids = self.sampled_token_ids
        probs_shape = "None" if probs is None else probs.shape
        ids_shape = "None" if token_ids is None else token_ids.shape
        fields = ", ".join((
            f"outputs={self.outputs}",
            f"sampled_token_probs={probs_shape}",
            f"sampled_token_ids={ids_shape}",
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics}",
        ))
        return f"SamplerOutput({fields})"

150

Woosuk Kwon's avatar
Woosuk Kwon committed
151
class Sampler(nn.Module):
    """Samples the next tokens from the model's logits.

    Given the logits, `forward()` does the following:
    1. Apply the min_tokens penalty (mask stop tokens while a sequence is
        shorter than min_tokens).
    2. Apply presence, frequency and repetition penalties.
    3. Apply temperature scaling.
    4. Apply top-p and top-k truncation.
    5. Apply min-p truncation.
    6. Compute probabilities / log probabilities and sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row in
    logits for the next token to be sampled; however, for a seq_group with a
    prompt request with the prompt_logprobs sampling parameter, there are rows
    in logits for each token in the input prompt.
    """

    def __init__(self):
        super().__init__()

        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding.
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False

        # Sampling state cached between forward() calls. Initialised here so
        # that a first call with `reuse_sampling_tensors=True` rebuilds the
        # tensors instead of failing on a missing attribute.
        self._sampling_tensors: Optional[SamplingTensors] = None
        self._do_penalties = False
        self._do_top_p_top_k = False
        self._do_min_p = False

    def _init_sampling_tensors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ):
        """The goal here is to reuse sampling tensors between similar decode
        runs. This is possible because sampling logic does not change between
        decodes of the same sequences.

        Rebuilds `self._sampling_tensors` and the `_do_*` flags from
        `sampling_metadata`, using the vocab size, device and dtype of
        `logits`.
        """
        _, vocab_size = logits.shape

        # First free any existing stored sampling tensors.
        # This is necessary because some sampling tensors may
        # have pinned memory.
        self._sampling_tensors = None

        # Initialize new sampling tensors
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        self._sampling_tensors = sampling_tensors
        self._do_penalties = do_penalties
        self._do_top_p_top_k = do_top_p_top_k
        self._do_min_p = do_min_p

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Single-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Pythonize sampling result & logprobs tensor

        Multi-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Defer Pythonization of sampling result & logprobs
          tensor
        * Encapsulate arguments required for deferred Pythonization
          in the :class:`SamplerOutput` structure

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        # The unpacking doubles as a check that logits is 2-D
        # (num_tokens, vocab_size).
        _num_tokens, _vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        # They are rebuilt when:
        # * the caller did not ask for reuse,
        # * nothing has been cached yet, or
        # * penalties are active: the sampling tensors then depend on the
        #   "output_tokens" of a sequence, which change between decode runs,
        #   so the cached tensors cannot be reused.
        needs_rebuild = (not sampling_metadata.reuse_sampling_tensors
                         or self._sampling_tensors is None
                         or self._do_penalties)
        if needs_rebuild:
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence, frequency and repetition penalties.
        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        # Top-k/top-p truncation on the logits is skipped when the FlashInfer
        # sampling kernel was imported.
        # NOTE(review): presumably _sample() applies top-k/top-p through
        # FlashInfer in that case — confirm against _sample().
        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization,
            # preserve GPU-side tensors in support of later
            # deferred pythonization of logprobs
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability distribution
        of greedily-sampled tokens such that multinomial sampling would sample
        the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        return self.should_modify_greedy_probs_inplace
329
330


331
def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Count how often each vocabulary id occurs in every row of `tokens`.

    Args:
        tokens: Integer tensor of token ids, one row per sequence. The id
            `vocab_size` is accepted as a padding value and is dropped from
            the result.
        vocab_size: Number of real vocabulary entries.
        num_seqs: Number of rows of the returned tensors.

    Returns:
        `(bin_counts, mask)`, both of shape (num_seqs, vocab_size):
        `bin_counts[i, t]` is the number of occurrences of token `t` in row
        `i`, and `mask` is True wherever that count is positive.
    """
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    # Drop the padding column again.
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask
346
347


348
349
350
351
def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
    have not been generated yet.

    Args:
        logits: Logits with one row per sample index / prompt-logprob index
            of all sequence groups. Modified in place.
        sampling_metadata: Metadata whose `seq_groups` describe which rows
            belong to which sequences.

    Returns:
        The same `logits` tensor, with the stop-token entries of every
        sequence that is still shorter than its `min_tokens` set to -inf.
    """
    # list of (row, token_id) indices in logits that will be set to -inf
    logits_to_penalize: List[Tuple[int, int]] = []
    logits_applied = 0
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params

        sample_indices = seq_group.sample_indices
        logits_applied += len(sample_indices) + len(
            seq_group.prompt_logprob_indices)
        if not seq_group.do_sample:
            continue

        start_idx = sample_indices[0]
        min_tokens = sampling_params.min_tokens
        token_ids_to_penalize = sampling_params.all_stop_token_ids
        if min_tokens > 0 and token_ids_to_penalize:
            # positions (within the group) of sequences that are too short
            seqs_to_penalize: List[int] = [
                j for j, seq_id in enumerate(seq_ids)
                if len(seq_group.seq_data[seq_id].output_token_ids_array) <
                min_tokens
            ]

            if seqs_to_penalize:
                # convert to the index into logits
                seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
                # itertools.product pairs each seq index with every token id
                logits_to_penalize.extend(
                    itertools.product(seqs_to_penalize, token_ids_to_penalize))

    if logits_to_penalize:
        # use zip and * to group indices along each dimension
        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
        logits[tuple(zip(*logits_to_penalize))] = -float("inf")

    # verifies that no rows in logits were missed unexpectedly
    assert logits_applied == logits.shape[0]
    return logits


395
396
397
398
399
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to `logits`.

    Args:
        logits: (num_seqs, vocab_size) logits.
        prompt_tokens_tensor: Prompt token ids, one row per sequence.
        output_tokens_tensor: Generated token ids, one row per sequence.
        presence_penalties: One value per sequence; subtracted once for
            every token that already occurs in the output.
        frequency_penalties: One value per sequence; subtracted once per
            occurrence of a token in the output.
        repetition_penalties: One value per sequence; applied to tokens that
            occur in the prompt or the output.

    Returns:
        The penalised logits. The penalty tensors passed in are left
        untouched.
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Expand to (num_seqs, vocab_size); tokens that were never seen get a
    # neutral penalty of 1.
    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    # Positive logits are divided and the others multiplied by the penalty.
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    # Use the out-of-place unsqueeze: the in-place variant would reshape the
    # caller's penalty tensors as a side effect.
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


418
def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Mask out every logit that falls outside the top-k / top-p selection.

    Args:
        logits: 2-D logits, one row per sequence. Not modified.
        p: Per-row top-p threshold (1-D, one value per row).
        k: Per-row top-k count (1-D, one value per row).

    Returns:
        A new tensor shaped like `logits` in which the rejected entries are
        -inf and the kept entries retain their value.
    """
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    # In ascending order the k-th largest value sits at index (vocab - k).
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    # Drop the low-probability prefix whose cumulative mass is <= 1 - p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    # at least one
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    # Build the inverse permutation of the sort and undo it.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits
448
449


Roy's avatar
Roy committed
450
451
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """Mask tokens whose probability is below `min_p` times the top one.

    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17

    Args:
        logits: Logits, one row per sequence. Modified in place.
        min_p: Per-row min-p value (1-D, one value per row). Left untouched,
            so the same tensor can be passed again on a later call.

    Returns:
        `logits`, with the rejected entries set to -inf.
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    # Out-of-place unsqueeze: the in-place variant would add a dimension to
    # the caller's tensor on every call and break reuse of that tensor.
    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits


467
def _greedy_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    samples: torch.Tensor,
) -> SampleResultType:
    """Run greedy sampling on a given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        samples: (num_selected_samples,) A tensor of samples. The length of
            samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    # Single GPU -> CPU transfer for the whole batch.
    samples_lst = samples.tolist()
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            # Keep the result aligned with selected_seq_groups.
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples_lst[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    random_samples: torch.Tensor,
) -> SampleResultType:
    """Run random sampling on a given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        random_samples: A tensor of samples, indexed as
            [sample row, candidate]. The number of rows could be smaller
            than selected_seq_groups if seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    # Move the samples to the CPU once, before the per-group slicing below.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            # Keep the result aligned with selected_seq_groups.
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        is_prompt = seq_group.is_prompt
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase: one row, `best_of` candidates that all descend
            # from parent 0.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase: one row per running sequence, first
            # candidate of each.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    logprobs: torch.Tensor,
) -> SampleResultType:
    """Run beam sampling on a given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
        on selected sample indices.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of returned list is
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            # Keep the result aligned with selected_seq_groups.
            results.append(([], []))
            continue

        is_prompt = seq_group.is_prompt
        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            # Rank candidates by the total logprob of the extended beam.
            cumulative_logprobs: List[float] = [
                seq_group.seq_data[seq_id].cumulative_logprob
                for seq_id in seq_ids
            ]
            cumulative_logprobs_tensor = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            # A flat index encodes (parent row, token id).
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    # Every row of logprobs must have been consumed.
    assert sample_idx == logprobs.size(0)
    return results
612
613


614
615
616
617
618
619
620
621
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
    """Draw `num_samples` token ids per row of `probs`, with replacement.

    Used instead of `torch.multinomial`, which forces a GPU<->CPU sync: each
    probability is divided by exponentially distributed noise and the argmax
    of every row is taken.

    Args:
        probs: 2-D probabilities, one row per sequence. When
            `num_samples == 1` this tensor is divided in place, so callers
            must pass a copy.
        num_samples: Number of samples to draw for every row.
        seq_groups: When given, the noise of every group's rows is drawn
            from that group's own `generator`; otherwise the default random
            generator is used.

    Returns:
        Tensor of shape (num_rows, num_samples) holding the sampled ids.
    """
    if num_samples > 1:
        # One consecutive row per requested sample.
        probs = probs.repeat_interleave(num_samples, dim=0)
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for seq_group in seq_groups:
            seq_ids = seq_group.seq_ids
            # Rows owned by this group after the repeat_interleave above.
            stride = len(seq_ids) * num_samples
            assert seq_group.generator is not None
            q[sample_idx:sample_idx +
              stride].exponential_(generator=seq_group.generator)
            sample_idx += stride
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)


641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
def _top_k_top_p_multinomial_with_flashinfer(
        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
    """Sample token ids with the FlashInfer top-k/top-p sampling kernel.

    Args:
        probs: Probabilities, one row per sequence.
        top_ks: Per-row top-k values.
        top_ps: Per-row top-p values.
        num_samples: Number of samples to draw for every row.
        seq_groups: When given, the uniform noise of every group's columns
            is drawn from that group's own `generator`; otherwise the
            default random generator is used.

    Returns:
        The sampled token ids viewed as (-1, num_samples).
    """
    # Number of rows of uniform noise handed to the kernel.
    num_rounds = 32
    if num_samples > 1:
        # One consecutive row per requested sample.
        probs = probs.repeat_interleave(num_samples, dim=0)
        top_ks = top_ks.repeat_interleave(num_samples)
        top_ps = top_ps.repeat_interleave(num_samples)

    num_rows = probs.shape[0]
    noise = torch.empty((num_rounds, num_rows), device=probs.device)
    if seq_groups is not None:
        column = 0
        for group in seq_groups:
            width = len(group.seq_ids) * num_samples
            assert group.generator is not None
            noise[:, column:column + width].uniform_(
                generator=group.generator)
            column += width
    else:
        noise.uniform_()

    token_ids, ok = flashinfer_top_k_top_p_sampling(
        probs,
        noise,
        top_ks,
        top_ps,
    )
    if not ok.all():
        # Fallback: renormalise under top-k, then top-p, and sample once
        # from the renormalised probabilities.
        warnings.warn("FlashInfer rejection sampling failed, fallback.",
                      stacklevel=1)
        probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
        probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
        token_ids = flashinfer.sampling.sampling_from_probs(probs, noise[0])
    return token_ids.view(-1, num_samples)


679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
def get_pythonized_sample_results(
        sample_result_args: SampleResultArgsType) -> SampleResultType:
    """Turn GPU-side sampler results into Pythonized CPU-side results
    (GPU -> CPU sync).

    Single-step scheduling: called at sampling time, so Pythonization is
    immediate.

    Multi-step scheduling: called only after several GPU-side steps have
    completed, i.e. Pythonization is deferred.

    Args:
      sample_result_args: GPU-side inputs to the Pythonization process.
        Its `sample_results_dict` is updated in place.

    Returns:
      One (next_token_ids, parent_ids) tuple per sequence group of the
      sampling metadata; groups without a result get ([], []).
    """
    args = sample_result_args
    metadata_by_type = args.sample_metadata
    results_by_group = args.sample_results_dict

    for kind in SamplingType:
        if kind not in metadata_by_type:
            continue
        group_ids, groups = metadata_by_type[kind]
        if kind == SamplingType.GREEDY:
            sample_results = _greedy_sample(groups, args.greedy_samples)
        elif kind in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(groups,
                                            args.multinomial_samples[kind])
        elif kind == SamplingType.BEAM:
            sample_results = _beam_search_sample(groups,
                                                 args.beam_search_logprobs)
        # File every group's result under its index in the batch.
        for group_id, group_result in zip(group_ids, sample_results):
            results_by_group[group_id] = group_result

    pythonized: SampleResultType = []
    for idx in range(len(args.sampling_metadata.seq_groups)):
        pythonized.append(results_by_group.get(idx, ([], [])))
    return pythonized


733
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    '''Torch-oriented _sample() implementation.

    Single-step scheduling:
    * Perform GPU-side sampling computation
    * Immediately Pythonize sampling result

    Multi-step scheduling:
    * Perform GPU-side sampling computation
    * Defer Pythonization & preserve GPU-side
      tensors required for Pythonization

    Args:
        probs: Probabilities; rows are selected with the per-type sample
            indices and passed to the multinomial samplers.
        logprobs: Log-probabilities; used for the greedy argmax and for
            the beam-search inputs.
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata
            (top-k / top-p values are read on the flashinfer path).
        include_gpu_probs_tensor: If True, allocate a
            (logprobs.shape[0], 1) long tensor and store the sampled token
            ids in it; otherwise the returned tensor is None.
        modify_greedy_probs: If True, rewrite `probs` in place for the
            greedily-sampled rows via `_modify_greedy_probs_inplace`.

    Returns:
        A 2-tuple whose second element is the sampled-token-id tensor (or
        None). The first element is the Pythonized sample results when
        `sampling_metadata.skip_sampler_cpu_output` is falsy, and the
        deferred `SampleResultArgsType` bundle otherwise.

    Raises:
        ValueError: If a sampling type with at least one sample index is
            not greedy, random, random-seed or beam.
    '''
    # Bucket the sequence-group indices by their sampling type.
    categorized_seq_group_ids: Dict[SamplingType,
                                    List[int]] = {t: []
                                                  for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: SampleResultsDictType = {}
    sample_metadata: SampleMetadataType = {}
    multinomial_samples: MultinomialSamplesType = {}
    greedy_samples: Optional[torch.Tensor] = None
    beam_search_logprobs: Optional[torch.Tensor] = None

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Only prompt-stage groups contribute to the number of samples
            # drawn per row.
            max_best_of_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            # The sequence groups are only forwarded for RANDOM_SEED.
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)

            if flashinfer_top_k_top_p_sampling is not None:
                multinomial_samples[
                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
                        probs[long_sample_indices],
                        sampling_tensors.top_ks[long_sample_indices],
                        sampling_tensors.top_ps[long_sample_indices],
                        max_best_of_in_batch,
                        seq_groups_arg,
                    )
            else:
                multinomial_samples[sampling_type] = _multinomial(
                    probs[long_sample_indices],
                    max_best_of_in_batch,
                    seq_groups=seq_groups_arg)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # Encapsulate arguments for computing Pythonized sampler
    # results, whether deferred or otherwise.
    maybe_deferred_args = SampleResultArgsType(
        sampling_metadata=sampling_metadata,
        sample_metadata=sample_metadata,
        multinomial_samples=multinomial_samples,
        greedy_samples=greedy_samples,
        beam_search_logprobs=beam_search_logprobs,
        sample_results_dict=sample_results_dict)

    if not sampling_metadata.skip_sampler_cpu_output:
        # GPU<->CPU sync happens here.
        # This also converts the sampler output to a Python object.
        # Return Pythonized sampler result & sampled token ids
        return get_pythonized_sample_results(
            maybe_deferred_args), sampled_token_ids_tensor
    else:
        # Defer sampler result Pythonization; return deferred
        # Pythonization args & sampled token ids
        return (
            maybe_deferred_args,
            sampled_token_ids_tensor,
        )
864
865


866
867
868
869
870
def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> SampleResultType:
    """Triton-kernel based _sample() implementation.

    A single `sample_triton` call produces the sampled tokens; the per-type
    results are then converted to Python objects.

    Args:
        probs: Probabilities passed to the Triton sampling kernel.
        logprobs: Log-probabilities passed to the kernel and used as the
            beam-search inputs.
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata
            (sampling seeds and sample indices are read here).

    Returns:
        (next_token_ids, parent_seq_ids) for each seq group in the batch;
        groups without a stored result get ([], []).

    Raises:
        ValueError: If a sampling type with at least one sample index is
            not greedy, random, random-seed or beam.
    """
    # Bucket the sequence-group indices by their sampling type.
    categorized_seq_group_ids: Dict[SamplingType,
                                    List[int]] = {t: []
                                                  for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata: Dict[SamplingType,
                          Tuple[List[int], List[SequenceGroupToSample],
                                torch.Tensor, torch.Tensor]] = {}
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            # Only prompt-stage groups contribute to the number of samples
            # drawn per row.
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    # One entry per sequence group, in batch order.
    return [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    """Sample next tokens for a batch.

    Currently a thin wrapper that forwards every argument to
    `_sample_with_torch`.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.
        include_gpu_probs_tensor: Forwarded to `_sample_with_torch`, which
            only allocates and fills the sampled-token-id tensor when this
            is True.
        modify_greedy_probs: Forwarded to `_sample_with_torch`, which
            rewrites `probs` in place for greedily-sampled rows when this
            is True.

    Returns:
        (next_token_ids, parent_seq_ids) for each seq group in a batch.
            If sampling is skipped, it returns ([], [])
        sampled_token_ids_tensor: A tensor of sampled token ids.
    """
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        sampling_tensors,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )

    # TODO: Enable once Triton kernel & associated code is faster.
    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
    #                                   sampling_tensors)


982
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the ranks of the chosen tokens in a logprob tensor.

    The rank of a chosen token is one plus the number of entries in its row
    that are strictly greater than the chosen entry, so the largest value
    has rank 1 and ties share the same rank.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
                        where N is the no. of tokens and M is the vocab dim.
        indices (torch.Tensor): List of chosen token indices.

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                    Each element in the returned tensor represents the rank
                    of the chosen token in the input logprob tensor.
    """
    # Pick x[i, indices[i]] for every row i.
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    # Boolean (N, M) mask of the entries that beat the chosen one.
    result = (x > vals[:, None])
    # Drop the reference to the gathered values before the reduction.
    del vals
    return result.sum(1).add_(1)
1001
1002


1003
def get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
    """Return sample logprobs and prompt logprobs.

    The logic consists of 3 parts.
    - Select indices to compute logprob from, ranks of token ids, and
        the top k token ids from logprobs.
    - Compute prompt logprobs if required.
    - Compute sample logprobs if required.

    Args:
        logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
            logprob per vocab. Sequence groups' query tokens are batched in a
            single flattened tensor. For example, assuming there are N
            seq groups, it is sorted by prefill tokens for seq_group_1 (if
            prompt logprob is enabled), decode tokens for seq_group_1 (if
            sampling is required), prefill tokens for seq_group_2, ...
        sampling_metadata: The sampling metadata.
        sample_results: (num_seq_groups) The tuple of (next_token_ids,
            parent_ids) for each sequence group. When beam search is enabled,
            sample_results can contain different number of seq_ids from
            sampling_metadata.seq_groups. It is because beam search creates
            2 * BEAM_WIDTH number of samples (whereas there are only up to
            BEAM_WIDTH number of seq_ids).

    Returns:
        A tuple of prompt and sample logprobs per sequence group in a batch.
        Both lists have one entry per sequence group, including when there
        is no query token to compute logprobs for.
    """
    # The index of query token to calculate logprobs. It includes both
    # prompt and sample logprob indices.
    query_indices: List[int] = []
    # The next token ids to get the logprob value from.
    next_token_ids: List[int] = []
    # The largest requested number of logprobs. We find logprobs as many as the
    # largest num logprobs in this API. If every logprobs is None, it will be
    # set to -1.
    largest_num_logprobs = -1
    # If beam search is enabled.
    use_beam_search = False

    # Select indices to compute logprob from, ranks of token ids, and the top
    # k token ids from logprobs.
    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
                                          sample_results):
        sampling_params = seq_group.sampling_params

        # Update indices and tokens for prompt logprobs.
        if (seq_group.is_prompt
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
            query_indices.extend(seq_group.prompt_logprob_indices)
            next_token_ids.extend(next_prompt_tokens)

        # Update indices and next tokens for sample logprob.
        if seq_group.do_sample:
            token_ids, parent_seq_ids = sample_result
            # NOTE: We cannot directly use sample_indices because
            # sample_indices only contain parent seq_ids of a previous step.
            # The current step may have different number of seq_ids, and
            # we can obtain it from `sample_result[1]`.
            query_idx = seq_group.sample_indices[0]
            query_indices.extend(
                [query_idx + parent_id for parent_id in parent_seq_ids])
            next_token_ids.extend(token_ids)

            if sampling_params.logprobs is not None:
                largest_num_logprobs = max(largest_num_logprobs,
                                           sampling_params.logprobs)

            use_beam_search = use_beam_search or sampling_params.use_beam_search

        assert len(next_token_ids) == len(query_indices)

    if not query_indices:
        # Nothing to compute. Still return one (empty) entry per sequence
        # group so that callers zipping the result against
        # `sampling_metadata.seq_groups` do not silently drop groups. Each
        # group gets its own list object to avoid accidental aliasing.
        num_seq_groups = len(sampling_metadata.seq_groups)
        empty_prompt_logprobs: List[Optional[PromptLogprobs]] = (
            [None] * num_seq_groups)
        empty_sample_logprobs: List[SampleLogprobs] = [
            [] for _ in range(num_seq_groups)
        ]
        return empty_prompt_logprobs, empty_sample_logprobs

    selected_logprobs, ranks = None, None
    top_logprobs, top_token_ids = None, None

    # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
    # skip the whole logprob calculation.
    if largest_num_logprobs >= 0 or use_beam_search:
        query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
        next_token_ids_gpu = torch.tensor(next_token_ids,
                                          device=logprobs.device)

        # (num_selected_query_tokens, num_logprobs). Note that query_indices can
        # contain duplicates if beam search is enabled.
        # Index with a tuple of tensors (row, column pairs); a list of
        # tensors relies on legacy sequence-as-tuple indexing.
        selected_logprobs = logprobs[query_indices_gpu, next_token_ids_gpu]
        ranks = _get_ranks(
            logprobs[query_indices_gpu],
            next_token_ids_gpu,
        )
        assert selected_logprobs.shape[0] == ranks.shape[0]

        # We need to compute top k only if there exists logprobs > 0.
        if largest_num_logprobs > 0:
            # Logprobs of topk tokens for a batch of sequence groups.
            # (num_query_tokens_across_batch).
            top_logprobs, top_token_ids = torch.topk(logprobs,
                                                     largest_num_logprobs,
                                                     dim=-1)
            top_logprobs = top_logprobs.to('cpu')
            top_token_ids = top_token_ids.to('cpu')

        selected_logprobs = selected_logprobs.to('cpu')
        ranks = ranks.to('cpu')

    # Find prompt/sample logprobs.
    prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
    sample_logprobs_per_seq_group: List[SampleLogprobs] = []
    # Running cursors into the top-k tensors and the selected-logprob
    # tensors; each helper consumes its share and returns the new offsets.
    top_logprob_idx = 0
    selected_logprobs_idx = 0

    for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                        sample_results):
        (prompt_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_prompt_logprob_if_needed(
             seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
             selected_logprobs_idx, top_logprob_idx)
        prompt_logprobs_per_seq_group.append(prompt_logprobs)

        (sampled_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_sampled_logprob_if_needed(
             seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
             top_logprobs, selected_logprobs_idx, top_logprob_idx)
        sample_logprobs_per_seq_group.append(sampled_logprobs)

    return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group


def _get_prompt_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
) -> Tuple[Optional[PromptLogprobs], int, int]:
    """Compute the prompt logprob from a sequence group if needed.

    Args:
        seq_group: The sequence group to process.
        selected_logprobs: Logprobs of the selected (next) tokens, read
            from `selected_logprobs_idx` onwards.
        ranks: Ranks of the selected tokens, indexed like
            `selected_logprobs`.
        top_token_ids: Token ids of the top-k logprobs; only read when the
            group requests more than zero prompt logprobs.
        top_logprobs: Values of the top-k logprobs, indexed like
            `top_token_ids`.
        selected_logprobs_idx: Offset of this group's first entry in
            `selected_logprobs` / `ranks`.
        top_logprob_idx: Offset of this group's first row in
            `top_token_ids` / `top_logprobs`.

    Returns:
        (prompt_logprobs, top_logprob_idx, selected_logprobs_idx).
        `prompt_logprobs` is None and both offsets are returned unchanged
        unless the group is a prompt with `prompt_logprobs` requested;
        otherwise it holds one {token_id: Logprob} dict per next prompt
        token and both offsets are advanced past the consumed entries.
    """
    sampling_params = seq_group.sampling_params
    is_prompt = seq_group.is_prompt

    # Find prompt logprobs
    prompt_logprobs: Optional[PromptLogprobs] = None
    if is_prompt and sampling_params.prompt_logprobs is not None:
        prompt_logprobs = []
        num_logprobs = sampling_params.prompt_logprobs
        next_prompt_tokens = _get_next_prompt_tokens(seq_group)
        # Pre-select indexes and create a list. It is faster than calling .item
        # repetitively.
        selected_logprob_items = selected_logprobs[
            selected_logprobs_idx:selected_logprobs_idx +
            len(next_prompt_tokens)].tolist()
        rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
                           len(next_prompt_tokens)].tolist()

        for idx, token_id in enumerate(next_prompt_tokens):
            # Calculate the prompt logprob of the real prompt tokens.
            # {token_id: (logprob, rank_from_vocab)}
            prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
                token_id: (selected_logprob_items[idx], rank_items[idx])
            }

            # Add top K prompt logprobs along with its rank.
            if num_logprobs > 0:
                top_ids = top_token_ids[
                    top_logprob_idx, :num_logprobs].tolist()
                top_probs = top_logprobs[
                    top_logprob_idx, :num_logprobs].tolist()
                # Top K is already sorted by rank, so we can use 1 ~
                # num_logprobs + 1 for rank.
                top_ranks = range(1, num_logprobs + 1)
                # If the real prompt token is also in the top K, its entry
                # is overwritten by the top-K (logprob, rank) pair.
                prompt_logprobs_dict.update({
                    top_id: (top_prob, rank)
                    for top_id, top_prob, rank in zip(top_ids, top_probs,
                                                      top_ranks)
                })
            prompt_logprobs.append({
                token_id: Logprob(*logprob_and_rank)
                for token_id, logprob_and_rank in prompt_logprobs_dict.items()
            })
            # + 1 to go to the next prompt token.
            top_logprob_idx += 1

        # + len(next_prompt_tokens) to go to the next prompt.
        selected_logprobs_idx += len(next_prompt_tokens)
    return prompt_logprobs, top_logprob_idx, selected_logprobs_idx


def _get_sampled_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    sample_result: Tuple[List[int], List[int]],
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
) -> Tuple[SampleLogprobs, int, int]:
    """Compute the sample logprob if needed.

    Args:
        seq_group: The sequence group to process.
        sample_result: (next_token_ids, parent_seq_ids) for this group.
        selected_logprobs: Logprobs of the sampled tokens, read from
            `selected_logprobs_idx` onwards. Not read when the group
            requests no logprobs and does not use beam search.
        ranks: Ranks of the sampled tokens, indexed like
            `selected_logprobs`.
        top_token_ids: Token ids of the top-k logprobs; only read when the
            group requests more than zero logprobs.
        top_logprobs: Values of the top-k logprobs, indexed like
            `top_token_ids`.
        selected_logprobs_idx: Offset of this group's first entry in
            `selected_logprobs` / `ranks`.
        top_logprob_idx: Offset of this group's first row in
            `top_token_ids` / `top_logprobs`.

    Returns:
        (sampled_logprobs, top_logprob_idx, selected_logprobs_idx).
        `sampled_logprobs` is empty and both offsets are returned unchanged
        when the group does not sample; otherwise it holds one
        {token_id: Logprob} dict per sampled token and both offsets are
        advanced to the next sequence group.
    """
    seq_ids = seq_group.seq_ids
    num_logprobs = seq_group.sampling_params.logprobs
    use_beam_search = seq_group.sampling_params.use_beam_search
    sampled_logprobs: SampleLogprobs = []
    next_token_ids, parent_seq_ids = sample_result

    if seq_group.do_sample:
        assert len(next_token_ids) > 0
        if num_logprobs is None and not use_beam_search:
            for next_token_id in next_token_ids:
                # Use a dummy logprob
                sampled_logprobs.append({next_token_id: Logprob(inf)})
        else:
            # Pre-select items from tensor. tolist() is faster than repetitive
            # `.item()` calls.
            selected_logprob_items = selected_logprobs[
                selected_logprobs_idx:selected_logprobs_idx +
                len(next_token_ids)].tolist()
            rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
                               len(next_token_ids)].tolist()
            for idx, (next_token_id, parent_id) in enumerate(
                    zip(next_token_ids, parent_seq_ids)):
                # Get the logprob of a sampled token.
                sampled_logprobs_dict = {
                    next_token_id:
                    (selected_logprob_items[idx], rank_items[idx])
                }
                if num_logprobs is not None and num_logprobs > 0:
                    # Get top K logprobs.
                    top_ids = top_token_ids[top_logprob_idx +
                                            parent_id, :num_logprobs].tolist()
                    top_probs = top_logprobs[
                        top_logprob_idx + parent_id, :num_logprobs].tolist()
                    # Top K is already sorted by rank, so we can use 1 ~
                    # num_logprobs + 1 for rank.
                    top_ranks = range(1, num_logprobs + 1)
                    sampled_logprobs_dict.update({
                        top_id: (top_prob, rank)
                        for top_id, top_prob, rank in zip(
                            top_ids, top_probs, top_ranks)
                    })

                sampled_logprobs.append({
                    token_id: Logprob(*logprob_and_rank)
                    for token_id, logprob_and_rank in
                    sampled_logprobs_dict.items()
                })

        # NOTE: This part of code is not intuitive. `selected_logprobs` include
        # logprobs for the current step, which has len(next_token_ids) tokens
        # per sequence group. `logprobs` includes logprobs from the previous
        # steps, which has len(seq_ids) tokens per sequence group.

        # Iterate to the next sequence group in a batch.
        selected_logprobs_idx += len(next_token_ids)
        # Iterate to the next sequence group in a batch.
        top_logprob_idx += len(seq_ids)
    return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
1273
1274


1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                 sample_indices: torch.Tensor,
                                 greedy_samples: torch.Tensor) -> None:
    """Modify the probability distributions of the greedily-sampled tokens such
    that each sampled token has a "probability" of 1.0. This is required by
    speculative decoding, which depends on the sampling method being encoded
    within the probability distribution for correctness.

    # Why do we only need to do this for greedy sampling?

    vLLM's sampler performs the following steps for greedy or multinomial
    (random) sampling:
        1. Get logits from model.
        2. Modify logits according to per-sequence sampling parameters.
            - Apply temperature scaling, top-k and top-p masking, penalize
                tokens according to their frequency, etc.
        3. Sample a token.
            - Random sampling simply samples from the modified probability
                distribution.
            - Greedy sampling performs `argmax` to obtain the token with the
                highest likelihood.

    Ignoring greedy sampling for a moment, we find that the computed probability
    distribution has the following property: we can sample from it independently
    and find that the token sampled by the Sampler has a frequency corresponding
    to how often we see it in our sampling. In other words, for tokens sampled
    with vLLM's random SamplingType, the computed probability distribution
    encodes the sampling methodology completely.

    Greedy sampling does not normally have this property. vLLM modifies logits
    according to sampling params, then performs `argmax`, then returns the
    sampled token and the computed probability distribution. If we sample from
    the distribution, we'll find the likelihood of the greedily-sampled token
    is not always 1.0.

    Since lossless speculative decoding requires that the sampling methodology
    be encoded within the probability distribution, we are motivated to modify
    the probability distribution such that the sampled token has probability 1
    when speculative decoding is used.

    NOTE: Alternatively, we could use an extremely low temperature to achieve
    greedy sampling using multinomial computation and unite the codepaths. This
    has implications on the overall design of the sampler, e.g. how to record
    accurate logprobs for the user, so this improvement is deferred to later.

    Args:
        logprobs: Accepted for interface symmetry; never read or written.
        probs: Probability tensor that is modified in place.
        sample_indices: Row indices of `probs` that were sampled greedily.
        greedy_samples: The greedily-chosen token id for each of those rows.
    """
    # NOTE: logprobs are not modified so they can be returned to the user.
    # Zero the whole row first, then put all the mass on the argmax token.
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0


1325
def _build_sampler_output(
    maybe_deferred_sample_results: MaybeDeferredSampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
    sample_logprobs: Optional[List[SampleLogprobs]],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
    skip_sampler_cpu_output: bool = False,
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        maybe_deferred_sample_results: Either the Pythonized per-group
            sample results, or (when `skip_sampler_cpu_output` is True) the
            `SampleResultArgsType` bundle needed to Pythonize them later.
        sampling_metadata: The metadata for a batch for sampling; its
            `seq_groups` are paired with the sample results.
        prompt_logprobs: Per-group prompt logprobs. Must not be None unless
            `skip_sampler_cpu_output` is True.
        sample_logprobs: Per-group sample logprobs. Must not be None unless
            `skip_sampler_cpu_output` is True.
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g. in
            speculative decoding rejection sampling.
        skip_sampler_cpu_output: If True, no per-group outputs are built
            and the deferred arguments are attached to the result instead.

    Returns:
        A `SamplerOutput` holding the per-group outputs (empty when CPU
        output is skipped), the on-device tensors (or None values) and the
        deferred Pythonization arguments (or None).
    """
    sampler_output: List[CompletionSequenceGroupOutput] = []

    if skip_sampler_cpu_output:
        assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
        deferred_sample_results_args = maybe_deferred_sample_results
    else:
        assert prompt_logprobs is not None
        assert sample_logprobs is not None
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        deferred_sample_results_args = None

        for (seq_group, sample_result, group_prompt_logprobs,
             group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                           maybe_deferred_sample_results,
                                           prompt_logprobs, sample_logprobs):
            seq_ids = seq_group.seq_ids
            next_token_ids, parent_ids = sample_result
            seq_outputs: List[SequenceOutput] = []
            # parent_id indexes into this group's seq_ids, mapping each
            # sampled token back to the sequence it extends.
            for parent_id, next_token_id, logprobs in zip(
                    parent_ids, next_token_ids, group_sample_logprobs):
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   logprobs))
            sampler_output.append(
                CompletionSequenceGroupOutput(seq_outputs,
                                              group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors
    else:
        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                                   None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
        deferred_sample_results_args=deferred_sample_results_args)
1384
1385


1386
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
    """Get a list of next prompt tokens to compute logprob from a
        given sequence group.

    It is used to compute prompt logprob. Imagine you have logprob for each
    query token. Query token needs to know the next prompt token id to compute
    prompt logprob. This is a helper to obtain next prompt token ids.

    This API has to be used only when the caller knows seq_group is in prefill
    stage.

    Returns:
        A list of next prompt tokens to compute logprob.
    """
    assert seq_group.is_prompt, (
        "Caller should ensure the sequence group is in a prefill stage.")
    seq_ids = seq_group.seq_ids
1403
1404
    query_len = seq_group.query_len
    assert query_len is not None
1405
1406
1407
1408
1409
1410
1411
    # prompt has only 1 seq id.
    assert len(seq_ids) == 1
    seq_data = seq_group.seq_data[seq_ids[0]]
    computed_len = seq_data.get_num_computed_tokens()
    prompt_tokens = seq_data.prompt_token_ids
    # +1 because we are looking for a next prompt token.
    next_token_index_start = computed_len + 1
1412
    next_token_index_end = min(computed_len + query_len + 1,
1413
1414
1415
1416
                               len(prompt_tokens))
    next_prompt_tokens = prompt_tokens[
        next_token_index_start:next_token_index_end]
    return next_prompt_tokens