# topk_topp_sampler.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6

import torch
import torch.nn as nn
7
from packaging import version
8
9

from vllm import envs
10
from vllm.config.model import LogprobsMode
11
from vllm.logger import init_logger
12
from vllm.platforms import CpuArchEnum, current_platform
13
14
15
16
17

logger = init_logger(__name__)


class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    ``forward`` is bound once in ``__init__`` to one of ``forward_native``,
    ``forward_cuda`` (FlashInfer, opt-in via ``VLLM_USE_FLASHINFER_SAMPLER``)
    or ``forward_cpu``, depending on the platform, that setting and
    ``logprobs_mode``.

    Implementations may update the logits tensor in-place.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
        """
        Args:
            logprobs_mode: Controls the second element returned by
                ``forward``. ``"processed_logits"`` returns the logits after
                top-k/top-p filtering, ``"processed_logprobs"`` returns their
                float32 log-softmax; any other mode returns ``None`` there.
        """
        super().__init__()
        self.logprobs_mode = logprobs_mode
        # flashinfer optimization does not apply if intermediate
        # logprobs/logits after top_k/top_p need to be returned
        if (
            logprobs_mode not in ("processed_logits", "processed_logprobs")
            and current_platform.is_cuda()
        ):
            if envs.VLLM_USE_FLASHINFER_SAMPLER:
                # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
                logger.info_once("Using FlashInfer for top-p & top-k sampling.")
                self.forward = self.forward_cuda
            else:
                logger.debug_once(
                    "FlashInfer top-p/top-k sampling is available but disabled "
                    "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
                    "after verifying accuracy for your workloads."
                )
                self.forward = self.forward_native
        elif current_platform.is_cpu():
            arch = current_platform.get_cpu_architecture()
            # Fall back to native implementation for POWERPC and RISCV.
            # On PowerPC argmax produces incorrect output with torch.compile.
            # PR: https://github.com/vllm-project/vllm/pull/26987
            if arch in (CpuArchEnum.RISCV, CpuArchEnum.POWERPC):
                self.forward = self.forward_native
            else:
                self.forward = self.forward_cpu
        else:
            self.forward = self.forward_native

        self.apply_top_k_top_p = apply_top_k_top_p

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling.

        The logits tensor may be updated in-place.

        Args:
            logits: Logits to filter and sample from.
            generators: Map from batch row index to that request's generator.
            k: Optional top-k values (one per row, presumably).
            p: Optional top-p values (one per row, presumably).

        Returns:
            ``(token_ids, processed)``. ``token_ids`` is flattened with
            ``view(-1)``. ``processed`` is the filtered logits or their
            float32 log-softmax, depending on ``logprobs_mode``, else ``None``.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators), logits_to_return

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """More optimized implementation for top-k and top-p sampling.

        Delegates to ``forward_native`` when no filtering is requested or when
        any request has its own generator. Otherwise samples with FlashInfer
        and always returns ``None`` as the second element.
        """
        # We prefer `random_sample` over `flashinfer_sample` when sorting is
        # not needed. This is because `random_sample` does not require
        # CPU-GPU synchronization while `flashinfer_sample` does.
        if (k is None and p is None) or generators:
            if generators:
                logger.debug_once(
                    "FlashInfer 0.2.3+ does not support "
                    "per-request generators. Falling back to "
                    "PyTorch-native implementation."
                )
            return self.forward_native(logits, generators, k, p)
        assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
            "FlashInfer does not support returning logits/logprobs"
        )
        # flashinfer sampling functions expect contiguous logits.
        # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
        # because of slicing operation in logits_processor.
        return flashinfer_sample(logits.contiguous(), k, p, generators), None

    def forward_cpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling for CPU.

        The logits tensor may be updated in-place.

        Unseeded batches are sampled by a ``torch.compile``-d helper. Any
        batch containing at least one seeded request is sampled eagerly with
        ``random_sample``, so every seeded row draws its noise from its own
        generator. Return value is the same as ``forward_native``.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

        if generators:
            # The compiled helper below only draws from the default RNG, so it
            # must not be used if *any* row is seeded, not just when all rows
            # are. Otherwise seeds in a partially seeded batch are ignored.
            probs = logits.softmax(dim=-1, dtype=torch.float32)
            return random_sample(probs, generators), logits_to_return

        # Note: this is a workaround for
        # https://github.com/pytorch/pytorch/pull/151218
        @torch.compile(dynamic=True)
        def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
            probs = logits.softmax(dim=-1, dtype=torch.float32)
            q = torch.empty_like(probs)
            q.exponential_()
            return probs.div(q).argmax(dim=-1).view(-1)

        return compiled_random_sample(logits), logits_to_return

147
148
149

def apply_top_k_top_p(
    logits: torch.Tensor,
150
151
    k: torch.Tensor | None,
    p: torch.Tensor | None,
152
153
154
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

155
156
157
158
    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
159
    """
160
161
162
163
164
165
166
    if p is None:
        if k is None:
            return logits

        # Avoid sorting vocab for top-k only case.
        return apply_top_k_only(logits, k)

167
168
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

169
    if k is not None:
170
        # Apply top-k.
171
        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
172
173
174
175
176
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

177
    if p is not None:
178
179
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
180
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
181
182
183
184
185
186
187
188
189
190
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Mask every logit outside each row's top-k with -inf, without a full sort.

    Rows whose ``k`` equals the vocabulary size (``logits.shape[1]``) are left
    unmasked. Only the ``max(k)`` largest values per row are materialized via
    ``topk``. The caller's ``k`` tensor is not modified.

    The logits tensor may be updated in-place (it is masked in place and the
    same tensor is returned).
    """
    unrestricted = k == logits.shape[1]
    # Give unrestricted rows k = 1 so the gather below stays within the
    # topk width; their threshold is neutralized to -inf afterwards.
    k_safe = k.masked_fill(unrestricted, 1)
    widest_k = k_safe.max()
    # Largest `widest_k` values of every row: shape [batch_size, widest_k].
    top_vals = logits.topk(widest_k, dim=1).values
    # The k-th largest value of each row lives at 0-based column k - 1.
    kth_col = (k_safe - 1).unsqueeze(1)
    thresholds = top_vals.gather(1, kth_col.long())
    thresholds.masked_fill_(unrestricted.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < thresholds, -float("inf"))
    return logits


216
217
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Draw one index per row from ``probs`` using exponential noise.

    Each probability is divided by an independent Exp(1) sample and the
    argmax is taken, which picks index ``j`` with probability proportional
    to ``probs[j]``. We use this instead of torch.multinomial because
    torch.multinomial causes CPU-GPU synchronization.

    Args:
        probs: Probabilities per row; divided in place.
        generators: Map from row index to the generator used for that row's
            noise. Rows not listed use the default RNG.

    Returns:
        Sampled indices, flattened with ``view(-1)``.
    """
    noise = torch.empty_like(probs)
    # NOTE(woosuk): Requests without their own seed are the common case, so
    # the whole buffer is filled from the default RNG first (unless every
    # row is seeded). Seeded rows are then overwritten from their generators.
    every_row_seeded = len(generators) == probs.shape[0]
    if not every_row_seeded:
        noise.exponential_()
    # TODO(woosuk): This can be slow because we handle each request
    # one by one. Optimize this.
    for row, gen in generators.items():
        noise[row].exponential_(generator=gen)
    scores = probs.div_(noise)
    return scores.argmax(dim=-1).view(-1)


def flashinfer_sample(
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the logits using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, it is faster because it uses rejection sampling instead of
    sorting the logits tensor.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.

    Args:
        logits: Logits to sample from.
        k: Optional top-k values; at least one of ``k``/``p`` must be given.
        p: Optional top-p values.
        generators: Accepted for signature parity but not used here. The
            FlashInfer calls below take no per-request generator.

    Raises:
        ImportError: If FlashInfer is older than 0.2.3.
    """
    import flashinfer

    required = version.parse("0.2.3")
    if version.parse(flashinfer.__version__) < required:
        raise ImportError(
            "FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
        )

    assert not (k is None and p is None)
    sampling = flashinfer.sampling
    if k is not None and p is not None:
        # Both filters: FlashInfer works directly on the raw logits.
        token_ids = sampling.top_k_top_p_sampling_from_logits(
            logits, k, p, deterministic=True
        )
    else:
        # Exactly one filter: these kernels expect float32 probabilities.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if k is None:
            token_ids = sampling.top_p_sampling_from_probs(
                probs, p, deterministic=True
            )
        else:
            token_ids = sampling.top_k_sampling_from_probs(
                probs, k, deterministic=True
            )

    return token_ids.view(-1)