# topk_topp_sampler.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6

import torch
import torch.nn as nn
7
from packaging import version
8
9

from vllm import envs
10
from vllm.config.model import LogprobsMode
11
from vllm.logger import init_logger
12
from vllm.platforms import CpuArchEnum, current_platform
13
14
15
16
17

# Module-level logger for the sampler backend-selection and fallback messages.
logger = init_logger(__name__)


class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    The concrete ``forward`` implementation is chosen once in ``__init__``
    from the platform, ``logprobs_mode`` and ``VLLM_USE_FLASHINFER_SAMPLER``.

    Implementations may update the logits tensor in-place.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
        """
        Args:
            logprobs_mode: Which logits/logprobs the sampler reports back.
                ``"processed_logits"`` / ``"processed_logprobs"`` request the
                values after top-k/top-p filtering, which rules out the
                FlashInfer path.
        """
        super().__init__()
        self.logprobs_mode = logprobs_mode
        # flashinfer optimization does not apply if intermediate
        # logprobs/logits after top_k/top_p need to be returned
        if (
            logprobs_mode not in ("processed_logits", "processed_logprobs")
            and current_platform.is_cuda()
        ):
            if envs.VLLM_USE_FLASHINFER_SAMPLER:
                # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
                logger.info_once(
                    "Using FlashInfer for top-p & top-k sampling.",
                    scope="global",
                )
                self.forward = self.forward_cuda
            else:
                logger.debug_once(
                    "FlashInfer top-p/top-k sampling is available but disabled "
                    "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
                    "after verifying accuracy for your workloads."
                )
                self.forward = self.forward_native
        elif current_platform.is_cpu():
            arch = current_platform.get_cpu_architecture()
            # Fall back to native implementation for POWERPC and RISCV.
            # On PowerPC argmax produces incorrect output with torch.compile.
            # PR: https://github.com/vllm-project/vllm/pull/26987
            if arch in (CpuArchEnum.RISCV, CpuArchEnum.POWERPC):
                self.forward = self.forward_native
            else:
                self.forward = self.forward_cpu
        else:
            self.forward = self.forward_native

        self.apply_top_k_top_p = apply_top_k_top_p

    def _logits_to_return(self, logits: torch.Tensor) -> torch.Tensor | None:
        """Select what to report back for the already-filtered ``logits``.

        Returns ``logits`` itself for ``"processed_logits"``, their fp32
        log-softmax for ``"processed_logprobs"``, and ``None`` otherwise.
        """
        if self.logprobs_mode == "processed_logits":
            return logits
        if self.logprobs_mode == "processed_logprobs":
            return logits.log_softmax(dim=-1, dtype=torch.float32)
        return None

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling.

        Args:
            logits: Logits to filter and sample from.
            generators: Per-request generators keyed by row index, for the
                requests that carry their own seed.
            k: Optional per-row top-k values.
            p: Optional per-row top-p values.

        Returns:
            The sampled token ids, plus the processed logits/logprobs when
            ``logprobs_mode`` asks for them (otherwise ``None``).

        The logits tensor may be updated in-place.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = self._logits_to_return(logits)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators), logits_to_return

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """More optimized implementation for top-k and top-p sampling.

        Falls back to ``forward_native`` when no filtering is requested or
        when any request has its own generator.
        """
        # We prefer `random_sample` over `flashinfer_sample` when sorting is
        # not needed. This is because `random_sample` does not require
        # CPU-GPU synchronization while `flashinfer_sample` does.
        if (k is None and p is None) or generators:
            if generators:
                logger.debug_once(
                    "FlashInfer 0.2.3+ does not support "
                    "per-request generators. Falling back to "
                    "PyTorch-native implementation."
                )
            return self.forward_native(logits, generators, k, p)
        assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
            "FlashInfer does not support returning logits/logprobs"
        )
        # flashinfer sampling functions expect contiguous logits.
        # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
        # because of slicing operation in logits_processor.
        return flashinfer_sample(logits.contiguous(), k, p, generators), None

    def forward_cpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling for CPU.

        Batches without any per-request generator use the
        ``torch.compile``-d sampler. Any batch containing at least one seeded
        request goes through ``random_sample``, so every seeded row draws its
        noise from its own generator.

        The logits tensor may be updated in-place.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = self._logits_to_return(logits)

        if not generators:
            return compiled_random_sample(logits), logits_to_return

        # `compiled_random_sample` only draws from the global RNG, so it cannot
        # honor per-request seeds; this also covers mixed seeded/unseeded
        # batches, which must not silently ignore the seeded rows.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators), logits_to_return


# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@torch.compile(dynamic=True)
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
    """Draw one token per row from ``softmax(logits)``.

    Uses the exponential-race trick: dividing the probabilities by i.i.d.
    Exponential(1) noise and taking the argmax selects index ``i`` with
    probability ``p_i``. Noise comes from the global RNG only, so
    per-request generators are not supported here. ``logits`` is not
    modified.
    """
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    noise = torch.empty_like(probs).exponential_()
    scores = probs / noise
    return scores.argmax(dim=-1).view(-1)


def apply_top_k_top_p(
    logits: torch.Tensor,
154
155
    k: torch.Tensor | None,
    p: torch.Tensor | None,
156
157
158
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

159
160
161
162
    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
163
    """
164
165
166
167
168
169
170
    if p is None:
        if k is None:
            return logits

        # Avoid sorting vocab for top-k only case.
        return apply_top_k_only(logits, k)

171
172
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

173
    if k is not None:
174
        # Apply top-k.
175
        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
176
177
178
179
180
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

181
    if p is not None:
182
183
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
184
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
185
186
187
188
189
190
191
192
193
194
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Mask every logit below each row's k-th largest value with ``-inf``.

    Only ``topk(max(k))`` is computed, so the full vocabulary is never
    sorted. Rows whose ``k`` equals the vocabulary size are left untouched.

    The logits tensor is updated in-place and returned.
    """
    neg_inf = -float("inf")
    keep_all = k == logits.shape[1]
    # Give "keep everything" rows k=1 so the gather below stays in bounds;
    # their threshold is reset to -inf right after.
    clamped_k = k.masked_fill(keep_all, 1)
    # Shape [batch, max_k], sorted descending along dim 1.
    topk_values = logits.topk(clamped_k.max(), dim=1).values
    # Column k-1 holds each row's k-th largest value.
    threshold = topk_values.gather(1, (clamped_k - 1).unsqueeze(1).long())
    threshold.masked_fill_(keep_all.unsqueeze(1), neg_inf)
    return logits.masked_fill_(logits < threshold, neg_inf)


def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample one token id per row of ``probs``.

    Each row's probabilities are divided by Exponential(1) noise and the
    argmax is taken. This avoids ``torch.multinomial``, which causes CPU-GPU
    synchronization.

    Args:
        probs: Row-wise probabilities. Overwritten in-place by the division.
        generators: Generators keyed by row index for requests with their
            own seed; the other rows use the global RNG.

    Returns:
        A flat tensor with one sampled index per row.
    """
    noise = torch.empty_like(probs)
    # Fill the whole batch from the global RNG first (the common, unseeded
    # case), unless every row is about to be overwritten by its own generator.
    if len(generators) != probs.shape[0]:
        noise.exponential_()
    # TODO(woosuk): seeded rows are handled one at a time, which can be slow.
    for row, gen in generators.items():
        noise[row].exponential_(generator=gen)
    return probs.div_(noise).argmax(dim=-1).view(-1)


def flashinfer_sample(
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample token ids from the logits with FlashInfer's top-k/top-p kernels.

    The result is statistically equivalent to `random_sample` applied after
    `apply_top_k_top_p`, but it uses rejection sampling instead of sorting the
    logits tensor, which makes it faster.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. Only statistical equivalence is guaranteed.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.

    ``generators`` is accepted for signature parity but not used here.
    Callers fall back to the native path when generators are present.

    Raises:
        ImportError: If the installed FlashInfer is older than 0.2.3.
    """
    import flashinfer

    required = version.parse("0.2.3")
    if version.parse(flashinfer.__version__) < required:
        raise ImportError(
            "FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
        )

    assert not (k is None and p is None)
    sampling = flashinfer.sampling
    if k is not None and p is not None:
        # Both filters: the combined kernel consumes the logits directly.
        token_ids = sampling.top_k_top_p_sampling_from_logits(
            logits, k, p, deterministic=True
        )
    else:
        # Exactly one filter: the *_from_probs kernels get fp32 probabilities.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if k is None:
            token_ids = sampling.top_p_sampling_from_probs(
                probs, p, deterministic=True
            )
        else:
            token_ids = sampling.top_k_sampling_from_probs(
                probs, k, deterministic=True
            )

    return token_ids.view(-1)