# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn as nn
7
from packaging import version
8
9

from vllm import envs
10
from vllm.config.model import LogprobsMode
11
from vllm.logger import init_logger
12
from vllm.platforms import CpuArchEnum, current_platform
13
14
15
16
17

logger = init_logger(__name__)

# FlashInfer is an optional accelerator: record whether its sampling module
# can be imported so the sampler can fall back to PyTorch when it cannot.
try:
    import flashinfer.sampling
except ImportError:
    is_flashinfer_available = False
else:
    is_flashinfer_available = True


class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    ``forward`` is bound once in ``__init__`` to one of ``forward_native``,
    ``forward_cuda`` (FlashInfer) or ``forward_cpu``, depending on the
    platform, FlashInfer availability/version, ``VLLM_USE_FLASHINFER_SAMPLER``
    and ``logprobs_mode``. Every variant takes ``(logits, generators, k, p)``
    and returns ``(sampled_token_ids, processed_logits_or_logprobs_or_None)``.

    Implementations may update the logits tensor in-place.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
        """Select the sampling implementation that ``forward`` will use.

        Args:
            logprobs_mode: Controls the second element returned by
                ``forward``. "processed_logits" returns the logits after
                top-k/top-p masking, "processed_logprobs" their float32
                log-softmax; any other mode returns ``None`` there. The two
                processed modes cannot be served by FlashInfer, so they never
                select ``forward_cuda``.
        """
        super().__init__()
        self.logprobs_mode = logprobs_mode
        # flashinfer optimization does not apply if intermediate
        # logprobs/logits after top_k/top_p need to be returned
        if (
            logprobs_mode not in ("processed_logits", "processed_logprobs")
            and current_platform.is_cuda()
        ):
            if is_flashinfer_available:
                flashinfer_version = flashinfer.__version__
                if version.parse(flashinfer_version) < version.parse("0.2.3"):
                    logger.warning_once(
                        "FlashInfer version >= 0.2.3 required. "
                        "Falling back to default sampling implementation."
                    )
                    self.forward = self.forward_native
                elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info_once("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning_once(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1."
                    )
                    self.forward = self.forward_native
            else:
                logger.warning_once(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer."
                )
                self.forward = self.forward_native
        elif current_platform.is_cpu():
            # RISC-V CPUs use the plain eager path; other CPU architectures
            # use forward_cpu, which relies on torch.compile for unseeded
            # batches.
            if current_platform.get_cpu_architecture() == CpuArchEnum.RISCV:
                self.forward = self.forward_native
            else:
                self.forward = self.forward_cpu
        else:
            self.forward = self.forward_native

        self.apply_top_k_top_p = apply_top_k_top_p

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling.

        Args:
            logits: Logits to sample from; may be updated in-place.
            generators: Maps a row index of ``logits`` to the generator used
                for that row's sampling noise; other rows use the global RNG.
            k: Per-row top-k values, or ``None`` to skip top-k.
            p: Per-row top-p values, or ``None`` to skip top-p.

        Returns:
            The sampled token ids (flattened to 1-D) and, depending on
            ``logprobs_mode``, the processed logits, their log-softmax, or
            ``None``.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators), logits_to_return

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """More optimized implementation for top-k and top-p sampling.

        Delegates to ``forward_native`` when no filtering is requested or
        when any request has its own generator; otherwise samples with
        FlashInfer. The second returned element is always ``None`` on the
        FlashInfer path.
        """
        # We prefer `random_sample` over `flashinfer_sample` when sorting is
        # not needed. This is because `random_sample` does not require
        # CPU-GPU synchronization while `flashinfer_sample` does.
        if (k is None and p is None) or generators:
            if generators:
                logger.debug_once(
                    "FlashInfer 0.2.3+ does not support "
                    "per-request generators. Falling back to "
                    "PyTorch-native implementation."
                )
            return self.forward_native(logits, generators, k, p)
        assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
            "FlashInfer does not support returning logits/logprobs"
        )
        # flashinfer sampling functions expect contiguous logits.
        # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
        # because of slicing operation in logits_processor.
        return flashinfer_sample(logits.contiguous(), k, p, generators), None

    def forward_cpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling for CPU.

        Unseeded batches are sampled by a torch.compile'd kernel. Batches in
        which at least one request has its own generator use the eager
        ``random_sample``, which honors per-request generators (the compiled
        kernel draws every row from the global RNG and would ignore them).

        The logits tensor may be updated in-place.

        Returns:
            The sampled token ids (flattened to 1-D) and, depending on
            ``logprobs_mode``, the processed logits, their log-softmax, or
            ``None``.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

        if not generators:
            # Note: this is a workaround for
            # https://github.com/pytorch/pytorch/pull/151218
            @torch.compile(dynamic=True)
            def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
                probs = logits.softmax(dim=-1, dtype=torch.float32)
                q = torch.empty_like(probs)
                q.exponential_()
                return probs.div(q).argmax(dim=-1).view(-1)

            return compiled_random_sample(logits), logits_to_return

        # At least one request is seeded. Previously a partially seeded batch
        # (len(generators) < batch size) took the compiled path above and its
        # seeds were silently ignored; random_sample draws global-RNG noise
        # for unseeded rows and per-generator noise for seeded rows.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators), logits_to_return

172
173
174

def apply_top_k_top_p(
    logits: torch.Tensor,
175
176
    k: torch.Tensor | None,
    p: torch.Tensor | None,
177
178
179
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

180
181
182
183
    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
184
    """
185
186
187
188
189
190
191
    if p is None:
        if k is None:
            return logits

        # Avoid sorting vocab for top-k only case.
        return apply_top_k_only(logits, k)

192
193
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

194
    if k is not None:
195
        # Apply top-k.
196
        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
197
198
199
200
201
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

202
    if p is not None:
203
204
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
205
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
206
207
208
209
210
211
212
213
214
215
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Mask every logit below the per-row k-th largest value with ``-inf``.

    Uses ``topk`` over the largest requested k instead of sorting the whole
    vocab. Rows whose k equals the vocab width are left unmasked. The
    caller's ``k`` tensor is not modified.

    The logits tensor may be updated in-place; it is also returned.
    """
    neg_inf = -float("inf")
    unrestricted = k == logits.shape[1]
    # Give unrestricted rows a dummy k of 1 so the gather below stays in
    # range; their threshold is replaced with -inf further down.
    safe_k = k.masked_fill(unrestricted, 1)
    widest = safe_k.max()
    # The topk values have shape [batch_size, widest]; column k - 1 holds
    # the k-th largest value of each row.
    kth_col = (safe_k - 1).unsqueeze(1).long()
    thresholds = logits.topk(widest, dim=1).values.gather(1, kth_col)
    thresholds.masked_fill_(unrestricted.unsqueeze(1), neg_inf)
    logits.masked_fill_(logits < thresholds, neg_inf)
    return logits


241
242
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Draw one token per row from ``probs`` with the exponential-race trick.

    ``argmax(probs / q)`` with ``q ~ Exp(1)`` samples from ``probs`` without
    ``torch.multinomial``, which would force a CPU-GPU synchronization.
    Rows listed in ``generators`` draw their noise from their own generator;
    all other rows share one draw from the global RNG. ``probs`` is divided
    in-place.
    """
    noise = torch.empty_like(probs)
    every_row_seeded = len(generators) == probs.shape[0]
    if not every_row_seeded:
        # Common case: fill the whole batch from the global RNG in one call;
        # seeded rows (if any) are overwritten below.
        noise.exponential_()
    # TODO(woosuk): This can be slow because we handle each request
    # one by one. Optimize this.
    for row, gen in generators.items():
        noise[row].exponential_(generator=gen)
    return probs.div_(noise).argmax(dim=-1).view(-1)


def flashinfer_sample(
266
    logits: torch.Tensor,
267
268
    k: torch.Tensor | None,
    p: torch.Tensor | None,
269
    generators: dict[int, torch.Generator],
270
) -> torch.Tensor:
271
    """Sample from the logits using FlashInfer.
272
273
274
275

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it avoids sorting the logits tensor
    via rejection sampling.
276

277
278
279
280
281
282
283
284
    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.
    """
285
286
    assert not (k is None and p is None)
    if k is None:
287
        # Top-p only.
288
        probs = logits.softmax(dim=-1, dtype=torch.float32)
289
        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
290
291
            probs, p, deterministic=True
        )
292
    elif p is None:
293
        # Top-k only.
294
        probs = logits.softmax(dim=-1, dtype=torch.float32)
295
        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
296
297
            probs, k, deterministic=True
        )
298
299
    else:
        # Both top-k and top-p.
300
        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
301
302
            logits, k, p, deterministic=True
        )
303

304
    return next_token_ids.view(-1)