# topk_topp_sampler.py
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import Optional
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

import torch
import torch.nn as nn

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# FlashInfer is an optional dependency. Record whether its sampling module
# imports cleanly; `TopKTopPSampler.__init__` consults this flag when picking
# a sampling backend.
is_flashinfer_available = False
try:
    import flashinfer.sampling
except ImportError:
    # Not importable: callers fall back to a non-FlashInfer implementation.
    pass
else:
    is_flashinfer_available = True


class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    Implementations may update the logits tensor in-place.

    The backend is chosen once in ``__init__``. ``self.forward`` is rebound to
    ``forward_cuda`` (FlashInfer), ``forward_tpu`` or ``forward_native``,
    depending on the current platform, the installed FlashInfer version and
    ``envs.VLLM_USE_FLASHINFER_SAMPLER``.
    """

    # Oldest FlashInfer release accepted for ``forward_cuda``. Keep in sync
    # with the "0.2.3" quoted in the log messages below.
    _MIN_FLASHINFER_VERSION = (0, 2, 3)

    def __init__(self):
        """Select the sampling implementation for the current platform."""
        super().__init__()
        if current_platform.is_cuda():
            if is_flashinfer_available:
                flashinfer_version = flashinfer.__version__
                # Compare versions numerically. A plain string comparison
                # would wrongly treat e.g. "0.2.10" as older than "0.2.3".
                if (self._parse_version(flashinfer_version)
                        < self._MIN_FLASHINFER_VERSION):
                    logger.warning(
                        "FlashInfer version >= 0.2.3 required. "
                        "Falling back to default sampling implementation.")
                    self.forward = self.forward_native
                elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        elif current_platform.is_tpu():
            self.forward = self.forward_tpu
        else:
            self.forward = self.forward_native

    @staticmethod
    def _parse_version(version: str) -> tuple[int, ...]:
        """Return the leading numeric release components of ``version``.

        At most the first three dot-separated integers are kept, so local or
        pre/post-release suffixes are ignored (``"0.2.3+cu124"`` ->
        ``(0, 2, 3)``). A string without a leading number yields ``()``. That
        compares lower than any real version, so such builds fall back to the
        native implementation.
        """
        import re

        match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", version)
        if match is None:
            return ()
        return tuple(int(part) for part in match.groups() if part is not None)

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        PyTorch-native implementation of top-k and top-p sampling.

        The logits tensor may be updated in-place.

        Args:
            logits: Logits with one row per request.
            generators: Per-request RNGs keyed by row index. Rows without an
                entry draw from the default RNG.
            k: Optional per-row top-k values.
            p: Optional per-row top-p values.

        Returns:
            A 1-D tensor with one sampled token id per row.
        """
        logits = apply_top_k_top_p(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """More optimized implementation for top-k and top-p sampling.

        Uses FlashInfer when filtering is requested. Without any filtering it
        samples with ``random_sample``. When per-request generators are
        present it defers to ``forward_native``, because the FlashInfer path
        cannot honor them.
        """
        if k is None and p is None:
            # We prefer `random_sample` over `flashinfer_sample` when sorting is
            # not needed. This is because `random_sample` does not require
            # CPU-GPU synchronization while `flashinfer_sample` does.
            probs = logits.softmax(dim=-1, dtype=torch.float32)
            return random_sample(probs, generators)
        if generators:
            logger.warning("FlashInfer 0.2.3+ does not support "
                           "per-request generators. Falling back to "
                           "PyTorch-native implementation.")
            return self.forward_native(logits, generators, k, p)
        return flashinfer_sample(logits, k, p, generators)

    def forward_tpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """TPU implementation: scatter-free top-k/top-p masking followed by
        ``random_sample``. The logits tensor may be updated in-place."""
        logits = apply_top_k_top_p_tpu(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

116

117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def apply_top_k_top_p_tpu(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Apply top-k and top-p optimized for TPU.

    This algorithm avoids using torch.scatter, which is extremely slow on TPU.
    It finds a "cut-off" element in the original logits and thresholds the
    logits against it. The elements that remain constitute the top-k / top-p
    set.

    When both ``k`` and ``p`` are given, top-p is evaluated on the
    distribution renormalized over the tokens that survived top-k. This
    matches ``apply_top_k_top_p``.

    Note: in the case of a tie (i.e. multiple cut-off elements present in the
    logits), all tied elements are included in the kept set. In other words,
    this function does not break ties. Instead, the tied tokens have equal
    chance of being chosen during final sampling, so we can consider the tie
    being broken then.

    Args:
        logits: Logits with one row per request; ``shape[1]`` is treated as
            the vocabulary size. Masked in-place and returned.
        k: Optional per-row top-k values. A row whose k equals the vocabulary
            size is not filtered by top-k.
        p: Optional per-row top-p values.

    Returns:
        ``logits`` with filtered-out entries set to -inf.
    """
    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)

    if k is not None:
        # In ascending order the k-th largest probability sits at index V - k.
        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
        top_k_count = top_k_count.unsqueeze(dim=1)
        top_k_cutoff = probs_sort.gather(-1, top_k_count)

        # Make sure the no top-k rows are no-op.
        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))

        elements_to_discard = probs < top_k_cutoff
        logits.masked_fill_(elements_to_discard, -float("inf"))

        if p is not None:
            # Top-p must be measured against the distribution that remains
            # after top-k, not the original one. Recompute from the masked
            # logits.
            probs = logits.softmax(dim=-1)
            probs_sort, _ = probs.sort(dim=-1, descending=False)

    if p is not None:
        cumprob = torch.cumsum(probs_sort, dim=-1)
        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
        top_p_mask[:, -1] = False  # at least one

        # Count of the smallest entries that may be dropped. The entry at
        # that index is the smallest probability that must be kept.
        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
        top_p_cutoff = probs_sort.gather(-1, top_p_count)
        elements_to_discard = probs < top_p_cutoff
        logits.masked_fill_(elements_to_discard, -float("inf"))

    return logits


164
165
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Mask logits outside the top-k and/or top-p sets with -inf.

    Whenever top-p is requested, the whole logits tensor is sorted along the
    last dimension. This can be slow for large batches. The top-k-only case
    is delegated to ``apply_top_k_only``, which avoids the full sort.

    The logits tensor may be updated in-place.

    Args:
        logits: Logits with one row per request.
        k: Optional per-row top-k values.
        p: Optional per-row top-p values.

    Returns:
        The filtered logits, in the original vocabulary order.
    """
    if p is None:
        # Top-k alone (or no filtering) never needs a full-vocabulary sort.
        return logits if k is None else apply_top_k_only(logits, k)

    neg_inf = -float("inf")
    sorted_logits, sort_order = logits.sort(dim=-1, descending=False)

    if k is not None:
        # In ascending order the k-th largest logit sits at column V - k.
        kth_column = sorted_logits.size(1) - k.to(torch.long)  # shape: B
        kth_value = sorted_logits.gather(1, kth_column.unsqueeze(dim=1))
        sorted_logits.masked_fill_(sorted_logits < kth_value, neg_inf)

    # Top-p: p is guaranteed to be set at this point. Accumulate the
    # probabilities from smallest to largest and drop every entry whose
    # running total stays within the 1 - p tail.
    sorted_probs = sorted_logits.softmax(dim=-1)
    running_total = torch.cumsum(sorted_probs, dim=-1, out=sorted_probs)
    drop = running_total <= 1 - p.unsqueeze(dim=1)
    drop[:, -1] = False  # always keep the most likely token
    sorted_logits.masked_fill_(drop, neg_inf)

    # Undo the sort so every logit returns to its vocabulary position.
    return sorted_logits.scatter(dim=-1, index=sort_order, src=sorted_logits)


207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Mask every logit below its row's k-th largest value with -inf.

    Only the ``max(k)`` largest values per row are materialized (via
    ``topk``), so the full vocabulary is never sorted. Rows whose k equals
    the vocabulary size are left untouched. The caller's ``k`` tensor is not
    modified.

    The logits tensor may be updated in-place.
    """
    unfiltered_rows = k == logits.shape[1]
    # Substitute k=1 for unfiltered rows so every row has a valid gather
    # index; their threshold is reset to -inf further down.
    effective_k = k.masked_fill(unfiltered_rows, 1)
    widest_k = effective_k.max()

    # The topk slice has shape [batch_size, widest_k]; a row's k-th largest
    # value lives at 0-based column k - 1 of it.
    kth_column = (effective_k - 1).unsqueeze(1)
    largest = logits.topk(widest_k, dim=1).values
    threshold = largest.gather(1, kth_column.long())

    threshold.masked_fill_(unfiltered_rows.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < threshold, -float("inf"))
    return logits


232
233
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Draw one index per row of ``probs``.

    This avoids ``torch.multinomial``, which causes CPU-GPU synchronization.
    Instead, ``probs`` is divided by i.i.d. Exponential(1) noise and the
    argmax is taken. Each index is then selected with probability
    proportional to its value.

    Args:
        probs: Probabilities with one row per request. Overwritten in-place.
        generators: Optional per-row RNGs keyed by row index.

    Returns:
        A 1-D tensor with one sampled index per row.
    """
    noise = torch.empty_like(probs)
    # Common case first: fill all rows from the default RNG in a single call,
    # unless every row brings its own generator anyway.
    if len(generators) != probs.shape[0]:
        noise.exponential_()
    # Rows with a dedicated generator are then re-drawn one at a time.
    # TODO(woosuk): This per-row loop can be slow; batch it.
    for row, rng in generators.items():
        noise[row].exponential_(generator=rng)
    scores = probs.div_(noise)
    return scores.argmax(dim=-1).view(-1)


def flashinfer_sample(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample token ids from ``logits`` with FlashInfer's sampling kernels.

    The result is statistically equivalent to top-k/top-p masking followed by
    ``random_sample``. FlashInfer uses rejection sampling instead of sorting
    the logits, which makes it faster.

    NOTE: The outputs of this function do not necessarily match the outputs
    of the `random_sample` function. Only their distributions agree.

    NOTE: This function includes CPU-GPU synchronization, while
    `random_sample` does not. Call this function at the end of the forward
    pass to minimize the synchronization overhead.

    Args:
        logits: Logits with one row per request.
        k: Optional per-row top-k values.
        p: Optional per-row top-p values. At least one of ``k``/``p`` must
            be set.
        generators: Accepted for signature parity but not used here. The
            FlashInfer calls below take no per-request generator.

    Returns:
        A 1-D tensor of sampled token ids.
    """
    assert not (k is None and p is None)
    if k is not None and p is not None:
        # Both filters: FlashInfer consumes the raw logits directly.
        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
            logits, k, p, deterministic=True)
    else:
        # Exactly one filter is active; these kernels take probabilities.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if k is None:
            next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
                probs, p, deterministic=True)
        else:
            next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
                probs, k, deterministic=True)
    return next_token_ids.view(-1)