# SPDX-License-Identifier: Apache-2.0

3
from typing import Optional
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

import torch
import torch.nn as nn

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

try:
    import flashinfer.sampling
    is_flashinfer_available = True
except ImportError:
    is_flashinfer_available = False


class TopKTopPSampler(nn.Module):
    """Sample token ids from logits with optional top-k / top-p filtering.

    At construction time ``self.forward`` is bound to one of two
    implementations:

    * ``forward_cuda`` (FlashInfer-backed) when running on CUDA, FlashInfer
      was importable, and ``VLLM_USE_FLASHINFER_SAMPLER`` is not ``False``.
    * ``forward_native`` (pure PyTorch) in every other case.
    """

    def __init__(self) -> None:
        super().__init__()

        if not current_platform.is_cuda():
            # Non-CUDA platforms always use the PyTorch-native path.
            self.forward = self.forward_native
            return

        if not is_flashinfer_available:
            logger.warning(
                "FlashInfer is not available. Falling back to the PyTorch-"
                "native implementation of top-p & top-k sampling. For the "
                "best performance, please install FlashInfer.")
            self.forward = self.forward_native
            return

        if envs.VLLM_USE_FLASHINFER_SAMPLER is False:
            logger.warning(
                "FlashInfer is available, but it is not enabled. "
                "Falling back to the PyTorch-native implementation of "
                "top-p & top-k sampling. For the best performance, "
                "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
            self.forward = self.forward_native
            return

        # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
        # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
        # default it is unused). For backward compatibility, we set
        # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
        # interpret it differently in V0 and V1 samplers: In V0,
        # None means False, while in V1, None means True. This is
        # why the check above is `is False` rather than a plain
        # truthiness test: None must still select FlashInfer here.
        logger.info("Using FlashInfer for top-p & top-k sampling.")
        self.forward = self.forward_cuda

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """PyTorch-native implementation of top-k and top-p sampling.

        Masks the logits with `apply_top_k_top_p`, converts them to float32
        probabilities, and draws one sample per row with `random_sample`.

        Args:
            logits: Logits to sample from; softmax is taken over the last dim.
            generators: Per-row `torch.Generator` objects, keyed by row index,
                forwarded to `random_sample`.
            k: Optional top-k values; `None` disables top-k.
            p: Optional top-p values; `None` disables top-p.

        Returns:
            The sampled indices as returned by `random_sample`.
        """
        logits = apply_top_k_top_p(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """More optimized implementation for top-k and top-p sampling.

        Takes the same arguments as `forward_native`. The logits are turned
        into float32 probabilities first; filtering (if any) is delegated to
        `flashinfer_sample`.
        """
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if k is None and p is None:
            # We prefer `random_sample` over `flashinfer_sample` when sorting is
            # not needed. This is because `random_sample` does not require
            # CPU-GPU synchronization while `flashinfer_sample` does.
            return random_sample(probs, generators)
        return flashinfer_sample(probs, k, p, generators)
81
82
83
84


def apply_top_k_top_p(
    logits: torch.Tensor,
85
86
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
87
88
89
90
91
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    This function sorts the logits tensor, which can be slow for large batches.
    """
92
    if k is None and p is None:
93
94
95
        return logits
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

96
    if k is not None:
97
98
99
100
101
102
103
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

104
    if p is not None:
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.

    Each row of `probs` is divided element-wise by exponential noise and the
    argmax of the result is taken as the sample for that row.

    Args:
        probs: Probabilities to sample from, one distribution per row.
            NOTE: this tensor is modified in place (`div_`) to avoid an extra
            allocation; callers must not reuse it afterwards.
        generators: Row index -> `torch.Generator` for the rows that have
            their own seed. Rows not listed use the default generator.

    Returns:
        A 1-D tensor holding the sampled index for each row.
    """
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        # At least one row is unseeded, so fill everything in a single call.
        q.exponential_()
    # TODO(woosuk): This can be slow because we handle each request
    # one by one. Optimize this.
    for i, generator in generators.items():
        q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)


def flashinfer_sample(
    probs: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the probabilities using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it avoids sorting the logits tensor
    via rejection sampling.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.

    Args:
        probs: Probabilities to sample from; dim 0 is the batch dimension.
        k: Optional top-k values passed through to FlashInfer.
        p: Optional top-p values passed through to FlashInfer.
            At least one of `k` and `p` must be given.
        generators: Batch index -> `torch.Generator` for the rows that have
            their own seed. Rows not listed use the default generator.

    Returns:
        A 1-D tensor of sampled token ids.
    """
    assert not (k is None and p is None)

    # One row of uniform samples per round handed to FlashInfer.
    # NOTE(review): presumably the maximum number of rejection-sampling
    # rounds -- confirm against the FlashInfer API in use.
    max_top_k_round = 32
    batch_size = probs.shape[0]
    uniform_samples = torch.empty((max_top_k_round, batch_size),
                                  device=probs.device)
    if len(generators) != batch_size:
        # At least one row is unseeded, so fill everything in a single call.
        uniform_samples.uniform_()
    # Overwrite the columns of the rows that carry their own generator.
    for i, generator in generators.items():
        uniform_samples[:, i].uniform_(generator=generator)

    if k is None:
        # Top-p only.
        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
            probs, uniform_samples, p, deterministic=True)
    elif p is None:
        # Top-k only.
        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
            probs, uniform_samples, k, deterministic=True)
    else:
        # Both top-k and top-p.
        next_token_ids, success = (
            flashinfer.sampling.top_k_top_p_sampling_from_probs(
                probs, uniform_samples, k, p, deterministic=True))

    # NOTE: CPU-GPU synchronization happens here.
    if not success.all():
        # Fallback when any row is not flagged successful: renormalize the
        # probabilities under the requested top-k / top-p constraints and
        # resample the whole batch from them with the first row of samples.
        if k is not None:
            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
        if p is not None:
            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
        next_token_ids = flashinfer.sampling.sampling_from_probs(
            probs, uniform_samples[0], deterministic=True)
    return next_token_ids.view(-1)