# topk_topp_sampler.py (11 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6

import torch
import torch.nn as nn
7
from packaging import version
8
9

from vllm import envs
10
from vllm.config.model import LogprobsMode
11
from vllm.logger import init_logger
12
from vllm.platforms import CpuArchEnum, current_platform
13
14
15
16
17

logger = init_logger(__name__)

# Probe for the optional FlashInfer dependency once at import time.
# `is_flashinfer_available` is consulted later when choosing the sampling
# implementation; a missing package is not an error.
try:
    import flashinfer.sampling

    is_flashinfer_available = True
except ImportError:
    is_flashinfer_available = False


class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    Implementations may update the logits tensor in-place.

    The constructor binds ``self.forward`` to one of ``forward_native``,
    ``forward_cuda`` or ``forward_cpu`` depending on the current platform,
    the requested logprobs mode and whether FlashInfer is usable. All three
    share the same signature and return a
    ``(sampled_token_ids, logits_or_logprobs_or_None)`` tuple.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
        """
        Select the sampling implementation.

        Args:
            logprobs_mode: When set to "processed_logits" or
                "processed_logprobs", the forward methods additionally
                return the filtered logits / log-probabilities.
        """
        super().__init__()
        self.logprobs_mode = logprobs_mode
        # flashinfer optimization does not apply if intermediate
        # logprobs/logits after top_k/top_p need to be returned
        if (
            logprobs_mode not in ("processed_logits", "processed_logprobs")
            and current_platform.is_cuda()
        ):
            if is_flashinfer_available:
                flashinfer_version = flashinfer.__version__
                if version.parse(flashinfer_version) < version.parse("0.2.3"):
                    logger.warning_once(
                        "FlashInfer version >= 0.2.3 required. "
                        "Falling back to default sampling implementation."
                    )
                    self.forward = self.forward_native
                elif envs.VLLM_USE_FLASHINFER_SAMPLER:
                    # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
                    logger.info_once("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.debug_once(
                        "FlashInfer top-p/top-k sampling is available but disabled "
                        "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
                        "after verifying accuracy for your workloads."
                    )
                    self.forward = self.forward_native
            else:
                logger.warning_once(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer."
                )
                self.forward = self.forward_native
        elif current_platform.is_cpu():
            # RISC-V CPUs use the plain native path; other CPU architectures
            # use the CPU-specific implementation.
            if current_platform.get_cpu_architecture() == CpuArchEnum.RISCV:
                self.forward = self.forward_native
            else:
                self.forward = self.forward_cpu
        else:
            self.forward = self.forward_native

        # Stored as an attribute so the filtering function is looked up on
        # the instance by the forward methods.
        self.apply_top_k_top_p = apply_top_k_top_p

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling.

        The logits tensor may be updated in-place.

        Args:
            logits: Logits to sample from (presumably [batch, vocab];
                softmax is taken over the last dimension).
            generators: Per-row random generators, keyed by row index.
            k: Optional per-row top-k values.
            p: Optional per-row top-p values.

        Returns:
            A tuple of the sampled token ids (flattened to 1-D) and, depending
            on ``self.logprobs_mode``, the filtered logits
            ("processed_logits"), their float32 log-softmax
            ("processed_logprobs"), or None.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators), logits_to_return

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        More optimized implementation for top-k and top-p sampling.

        Falls back to ``forward_native`` when neither ``k`` nor ``p`` is given
        or when any per-request generator is present; otherwise samples with
        FlashInfer and returns ``None`` as the second tuple element.
        """
        # We prefer `random_sample` over `flashinfer_sample` when sorting is
        # not needed. This is because `random_sample` does not require
        # CPU-GPU synchronization while `flashinfer_sample` does.
        if (k is None and p is None) or generators:
            if generators:
                logger.debug_once(
                    "FlashInfer 0.2.3+ does not support "
                    "per-request generators. Falling back to "
                    "PyTorch-native implementation."
                )
            return self.forward_native(logits, generators, k, p)
        assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
            "FlashInfer does not support returning logits/logprobs"
        )
        # flashinfer sampling functions expect contiguous logits.
        # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
        # because of slicing operation in logits_processor.
        return flashinfer_sample(logits.contiguous(), k, p, generators), None

    def forward_cpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling for CPU.

        The logits tensor may be updated in-place.

        When no per-request generators are supplied, sampling goes through a
        ``torch.compile``-d helper. As soon as at least one row has its own
        generator, sampling is done eagerly so that every seeded row draws its
        noise from its generator while the remaining rows use the global RNG.

        Returns:
            Same tuple as ``forward_native``: sampled token ids (1-D) and the
            optional processed logits / logprobs.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

        # Note: this is a workaround for
        # https://github.com/pytorch/pytorch/pull/151218
        @torch.compile(dynamic=True)
        def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
            probs = logits.softmax(dim=-1, dtype=torch.float32)
            q = torch.empty_like(probs)
            q.exponential_()
            return probs.div(q).argmax(dim=-1).view(-1)

        # The compiled helper has no notion of generators, so it may only be
        # used when none were supplied. Previously it was also taken when just
        # a subset of rows was seeded, which silently ignored those seeds.
        if not generators:
            return compiled_random_sample(logits), logits_to_return

        probs = logits.softmax(dim=-1, dtype=torch.float32)
        q = torch.empty_like(probs)
        # Fill every row from the global RNG first, then overwrite the rows
        # that have their own generator.
        q.exponential_()
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)

        return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return

164
165
166

def apply_top_k_top_p(
    logits: torch.Tensor,
167
168
    k: torch.Tensor | None,
    p: torch.Tensor | None,
169
170
171
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

172
173
174
175
    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
176
    """
177
178
179
180
181
182
183
    if p is None:
        if k is None:
            return logits

        # Avoid sorting vocab for top-k only case.
        return apply_top_k_only(logits, k)

184
185
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

186
    if k is not None:
187
        # Apply top-k.
188
        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
189
190
191
192
193
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

194
    if p is not None:
195
196
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
197
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
198
199
200
201
202
203
204
205
206
207
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Apply top-k mask to the logits.

    This implementation doesn't involve sorting the entire vocab.

    The logits tensor may be updated in-place.

    Args:
        logits: 2-D logits; ``logits.shape[1]`` is the number of candidates
            per row.
        k: Per-row top-k values. A value equal to ``logits.shape[1]`` means
            "no top-k" for that row. The caller's tensor is not modified.

    Returns:
        ``logits`` itself, with entries below each row's k-th largest value
        set to ``-inf``.
    """
    no_top_k_mask = k == logits.shape[1]
    # Set non-top-k rows to 1 so that we can gather.
    # `masked_fill` returns a new tensor, so the in-place `sub_` below only
    # touches this local copy.
    k = k.masked_fill(no_top_k_mask, 1)
    max_top_k = k.max()
    # topk.values tensor has shape [batch_size, max_top_k].
    # Convert top k to 0-based index in range [0, max_top_k).
    k_index = k.sub_(1).unsqueeze(1)
    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
    # Handle non-topk rows: a -inf threshold masks nothing in those rows.
    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < top_k_mask, -float("inf"))
    return logits


233
234
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.

    Each row of ``probs`` is divided by exponentially distributed noise and
    the argmax of the result is taken as the sample.

    Args:
        probs: 2-D probabilities, one row per request. Modified in-place
            (divided by the noise).
        generators: Per-row random generators, keyed by row index. Rows
            without an entry use the global RNG.

    Returns:
        A 1-D tensor with one sampled index per row.
    """
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        q.exponential_()
    # TODO(woosuk): This can be slow because we handle each request
    # one by one. Optimize this.
    # Iterating an empty dict is a no-op, so no separate emptiness check
    # is needed.
    for i, generator in generators.items():
        q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)


def flashinfer_sample(
258
    logits: torch.Tensor,
259
260
    k: torch.Tensor | None,
    p: torch.Tensor | None,
261
    generators: dict[int, torch.Generator],
262
) -> torch.Tensor:
263
    """Sample from the logits using FlashInfer.
264
265
266
267

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it avoids sorting the logits tensor
    via rejection sampling.
268

269
270
271
272
273
274
275
276
    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.
    """
277
278
    assert not (k is None and p is None)
    if k is None:
279
        # Top-p only.
280
        probs = logits.softmax(dim=-1, dtype=torch.float32)
281
        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
282
283
            probs, p, deterministic=True
        )
284
    elif p is None:
285
        # Top-k only.
286
        probs = logits.softmax(dim=-1, dtype=torch.float32)
287
        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
288
289
            probs, k, deterministic=True
        )
290
291
    else:
        # Both top-k and top-p.
292
        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
293
294
            logits, k, p, deterministic=True
        )
295

296
    return next_token_ids.view(-1)