topk_topp_sampler.py 19.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6

import torch
import torch.nn as nn
7
from packaging import version
8
9

from vllm import envs
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.config.model import LogprobsMode
12
from vllm.logger import init_logger
13
from vllm.platforms import CpuArchEnum, current_platform
14

15
16
17
18
19
20
21
22
# lightop is an optional dependency that provides the *_sampling_from_probs
# kernels used by `lightop_sample`. HAS_LIGHTOP_OPT_KERNEL records whether all
# three names could be imported.
try:
    from lightop.sampling import (
        top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs_lightop,
        top_k_sampling_from_probs as top_k_sampling_from_probs_lightop,
        top_p_sampling_from_probs as top_p_sampling_from_probs_lightop,
    )
except ImportError:
    HAS_LIGHTOP_OPT_KERNEL = False
else:
    HAS_LIGHTOP_OPT_KERNEL = True

23
24
25
26
# Module-level logger shared by all sampler implementations in this file.
logger = init_logger(__name__)


class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    Implementations may update the logits tensor in-place.

    ``__init__`` binds ``self.forward`` to exactly one implementation:
    ``forward_cuda`` (FlashInfer; CUDA with VLLM_USE_FLASHINFER_SAMPLER=1),
    ``forward_cpu`` (CPU, except RISC-V / PowerPC), ``forward_hip`` (ROCm
    with aiter enabled), ``forward_lightop_opt`` (lightop importable), or
    ``forward_native`` as the fallback. Every implementation returns
    ``(sampled_token_ids, processed_logits_or_logprobs_or_None)``.
    """

    # Logprobs modes that require returning the logits/logprobs *after*
    # top-k/top-p filtering. The kernel-backed samplers (FlashInfer, aiter,
    # lightop) only return token ids, so they must not be used in these modes.
    _PROCESSED_LOGPROBS_MODES = ("processed_logits", "processed_logprobs")

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
        super().__init__()
        self.logprobs_mode = logprobs_mode
        # flashinfer/aiter/lightop optimizations do not apply if intermediate
        # logprobs/logits after top_k/top_p need to be returned.
        needs_processed = logprobs_mode in self._PROCESSED_LOGPROBS_MODES

        if not needs_processed and current_platform.is_cuda():
            if envs.VLLM_USE_FLASHINFER_SAMPLER:
                from vllm.v1.attention.backends.flashinfer import FlashInferBackend

                capability = current_platform.get_device_capability()
                assert capability is not None
                if not FlashInferBackend.supports_compute_capability(capability):
                    capability_str = capability.as_version_str()
                    raise RuntimeError(
                        "FlashInfer does not support compute capability "
                        f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1."
                    )
                # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
                logger.info_once(
                    "Using FlashInfer for top-p & top-k sampling.",
                    scope="global",
                )
                self.forward = self.forward_cuda
            else:
                logger.debug_once(
                    "FlashInfer top-p/top-k sampling is available but disabled "
                    "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
                    "after verifying accuracy for your workloads."
                )
                self.forward = self.forward_native

        elif current_platform.is_cpu():
            arch = current_platform.get_cpu_architecture()
            # Fall back to native implementation for POWERPC and RISCV.
            # On PowerPC argmax produces incorrect output with torch.compile.
            # PR: https://github.com/vllm-project/vllm/pull/26987
            if arch in (CpuArchEnum.RISCV, CpuArchEnum.POWERPC):
                self.forward = self.forward_native
            else:
                self.forward = self.forward_cpu

        elif not needs_processed and rocm_aiter_ops.is_enabled():
            try:
                import aiter.ops.sampling  # noqa: F401

                self.aiter_ops = torch.ops.aiter
                logger.info_once(
                    "Using aiter sampler on ROCm (lazy import, sampling-only)."
                )
                self.forward = self.forward_hip
            except ImportError:
                logger.warning_once(
                    "aiter.ops.sampling is not available on ROCm. "
                    "Falling back to forward_native implementation."
                )
                self.forward = self.forward_native

        elif not needs_processed and HAS_LIGHTOP_OPT_KERNEL:
            # Gated on `needs_processed` like FlashInfer/aiter: the lightop
            # path returns token ids only (its second output is always None).
            self.forward = self.forward_lightop_opt

        else:
            self.forward = self.forward_native

        self.apply_top_k_top_p = apply_top_k_top_p

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
        *,
        max_top_k: int | None = None,
        has_any_no_top_k: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling.

        The logits tensor may be updated in-place.

        Args:
            logits: Logits to filter and sample from.
            generators: Per-request ``torch.Generator`` keyed by row index,
                for requests that carry their own seed.
            k: Per-row top-k values, or ``None`` to disable top-k.
            p: Per-row top-p values, or ``None`` to disable top-p.
            max_top_k: Expected to be >= every value in ``k`` (the reduced
                fast path keeps only ``max_top_k`` candidates per row).
            has_any_no_top_k: Presumably ``True`` when some row has top-k
                disabled; the reduced fast path is skipped in that case.

        Returns:
            ``(sampled_token_ids, processed)`` where ``processed`` is the
            filtered logits (``processed_logits``), their float32
            log-softmax (``processed_logprobs``), or ``None`` otherwise.
        """
        # Fast path: when top-k is enabled, avoid full-vocab sort/softmax by
        # sampling from only the reduced candidate set.
        if (
            self.logprobs_mode not in self._PROCESSED_LOGPROBS_MODES
            and envs.VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
            and k is not None
            and p is not None
            and max_top_k is not None
            and not has_any_no_top_k
            and max_top_k <= 4096
        ):
            try:
                return (
                    sample_top_k_top_p_reduced(
                        logits,
                        generators,
                        k,
                        p,
                        max_top_k=max_top_k,
                    ),
                    None,
                )
            except Exception:
                # Deliberate best-effort: fall back to the reference
                # implementation for safety.
                logger.debug_once(
                    "Reduced top-k/top-p sampler failed; falling back to the "
                    "reference implementation."
                )

        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators), logits_to_return

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """More optimized implementation for top-k and top-p sampling.

        Uses FlashInfer unless no filtering is requested or some request has
        its own generator, in which case ``forward_native`` is used.
        """
        # We prefer `random_sample` over `flashinfer_sample` when sorting is
        # not needed. This is because `random_sample` does not require
        # CPU-GPU synchronization while `flashinfer_sample` does.
        if (k is None and p is None) or generators:
            if generators:
                logger.debug_once(
                    "FlashInfer 0.2.3+ does not support "
                    "per-request generators. Falling back to "
                    "PyTorch-native implementation."
                )
            return self.forward_native(logits, generators, k, p)
        assert self.logprobs_mode not in self._PROCESSED_LOGPROBS_MODES, (
            "FlashInfer does not support returning logits/logprobs"
        )
        # flashinfer sampling functions expect contiguous logits.
        # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
        # because of slicing operation in logits_processor.
        return flashinfer_sample(logits.contiguous(), k, p, generators), None

    def forward_lightop_opt(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Top-k and top-p sampling optimized by lightop.

        Falls back to ``forward_native`` when:

        * no filtering is requested;
        * a request has its own generator (``lightop_sample`` does not pass
          generators to the lightop kernels);
        * processed logits/logprobs must be returned (``lightop_sample`` only
          produces token ids).
        """
        if (
            (k is None and p is None)
            or generators
            or self.logprobs_mode in self._PROCESSED_LOGPROBS_MODES
        ):
            return self.forward_native(logits, generators, k, p)

        return lightop_sample(logits.contiguous(), k, p, generators), None

    def forward_cpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling for CPU.

        The logits tensor may be updated in-place.

        Every per-request generator is honored. When no request has a seed,
        the ``torch.compile``'d ``compiled_random_sample`` is used instead.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

        if not generators:
            return compiled_random_sample(logits), logits_to_return

        # Some (possibly not all) requests carry their own seed. The previous
        # `len(generators) != batch` check sent partially-seeded batches down
        # the compiled path and silently ignored their seeds. `random_sample`
        # draws shared noise for unseeded rows and per-generator noise for
        # seeded rows.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators), logits_to_return

    def forward_hip(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Optimized ROCm/aiter path (same structure as forward_cuda).

        While ``DISABLE_AITER_SAMPLER`` is ``True`` every call ends in
        ``forward_native``; ``aiter_sample`` is currently unreachable.
        """
        # FIXME: Fix aiter_sampler's accuracy issue and remove this flag
        DISABLE_AITER_SAMPLER = True

        if (k is None and p is None) or generators:
            if generators:
                logger.warning_once(
                    "aiter sampler does not support per-request generators; "
                    "falling back to PyTorch-native."
                )
            return self.forward_native(logits, generators, k, p)
        assert self.logprobs_mode not in self._PROCESSED_LOGPROBS_MODES, (
            "aiter sampler does not support returning logits/logprobs."
        )

        if DISABLE_AITER_SAMPLER:
            return self.forward_native(logits, generators, k, p)

        return self.aiter_sample(logits, k, p, generators), None

    def aiter_sample(
        self,
        logits: torch.Tensor,
        k: torch.Tensor | None,
        p: torch.Tensor | None,
        generators: dict[int, torch.Generator],
    ) -> torch.Tensor:
        """Sample from logits using aiter ops.

        Requires ``self.aiter_ops`` (set to ``torch.ops.aiter`` in
        ``__init__``). ``generators`` is not used. The top-k-only path
        samples with ``torch.multinomial`` after aiter's renormalization.

        Raises:
            RuntimeError: if neither ``k`` nor ``p`` is given.
        """
        use_top_k = k is not None
        use_top_p = p is not None
        # Joint k+p path
        if use_top_p and use_top_k:
            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
            next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
                probs,
                None,
                *_to_tensor_scalar_tuple(k),
                *_to_tensor_scalar_tuple(p),
                deterministic=True,
            )
            return next_token_ids.view(-1)
        # Top-p only path
        elif use_top_p:
            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
            next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
                probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
            )
            return next_token_ids.view(-1)
        # Top-k only path
        elif use_top_k:
            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
            renorm_probs = self.aiter_ops.top_k_renorm_probs(
                probs, *_to_tensor_scalar_tuple(k)
            )
            return torch.multinomial(renorm_probs, num_samples=1).view(-1)
        raise RuntimeError("aiter_sample was called with no active top-k or top-p.")

287

288
289
290
291
292
293
294
295
296
297
# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@torch.compile(dynamic=True)
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    q.exponential_()
    return probs.div(q).argmax(dim=-1).view(-1)


298
299
def apply_top_k_top_p(
    logits: torch.Tensor,
300
301
    k: torch.Tensor | None,
    p: torch.Tensor | None,
302
303
304
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

305
306
307
308
    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
309
    """
310
311
312
313
314
315
316
    if p is None:
        if k is None:
            return logits

        # Avoid sorting vocab for top-k only case.
        return apply_top_k_only(logits, k)

317
318
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

319
    if k is not None:
320
        # Apply top-k.
321
        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
322
323
324
325
326
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

327
    if p is not None:
328
329
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
330
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
331
332
333
334
335
336
337
338
339
340
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Apply top-k mask to the logits.

    This implementation doesn't involve sorting the entire vocab.

    The logits tensor may be updated in-place.

    Args:
        logits: 2-D logits; dim 1 is the vocabulary (compared via ``shape[1]``).
        k: Per-row top-k values. A row whose ``k`` equals the vocab size is
            left unfiltered. The caller's ``k`` tensor is not modified
            (``masked_fill`` returns a copy before the in-place ``sub_``).

    Returns:
        ``logits`` with entries below each row's k-th largest value set to
        ``-inf``.
    """
    no_top_k_mask = k == logits.shape[1]
    # Set non-top-k rows to 1 so that we can gather.
    k = k.masked_fill(no_top_k_mask, 1)
    max_top_k = k.max()
    # topk.values tensor has shape [batch_size, max_top_k].
    # Convert top k to 0-based index in range [0, max_top_k).
    k_index = k.sub_(1).unsqueeze(1)
    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
    # Handle non-topk rows: a -inf threshold masks nothing.
    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < top_k_mask, -float("inf"))
    return logits


366
367
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.

    Each row picks ``argmax(probs / q)`` with ``q ~ Exp(1)``. Rows listed in
    ``generators`` draw their noise from their own generator; all other rows
    share one unseeded draw.

    NOTE: ``probs`` is overwritten in place (``div_``).

    Args:
        probs: Per-row (possibly unnormalized) non-negative weights.
        generators: ``torch.Generator`` per row index for seeded requests.

    Returns:
        1-D tensor with one sampled index per row.
    """
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        q.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)


390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
def sample_top_k_top_p_reduced(
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: torch.Tensor,
    p: torch.Tensor | None,
    *,
    max_top_k: int,
) -> torch.Tensor:
    """Sample logits from only the top-k candidate set.

    Only the ``max_top_k`` largest logits per row are considered. Per-row
    ``k`` and optional ``p`` are applied inside that set, and the sampled
    position is mapped back to a vocabulary id. Falls back to the full-vocab
    ``apply_top_k_top_p`` + ``random_sample`` path when ``max_top_k`` is not
    in ``(0, vocab_size)``.
    """
    vocab_size = logits.shape[-1]
    # Guard for extreme values that can defeat the purpose of this fast path.
    if not 0 < max_top_k < vocab_size:
        filtered = apply_top_k_top_p(logits, k, p)
        full_probs = filtered.softmax(dim=-1, dtype=torch.float32)
        return random_sample(full_probs, generators)

    cand_logits, cand_ids = logits.topk(max_top_k, dim=-1)

    # Per-row top-k inside the candidate set: position j survives iff j < k.
    positions = torch.arange(max_top_k, device=logits.device).unsqueeze(0)
    beyond_k = positions >= k.to(torch.long).unsqueeze(1)
    cand_logits = cand_logits.masked_fill(beyond_k, -float("inf"))

    cand_probs = cand_logits.softmax(dim=-1, dtype=torch.float32)

    if p is not None:
        # Candidates are in descending-logit order; keep one while the mass
        # strictly before it is still below p.
        mass_before = torch.cumsum(cand_probs, dim=-1) - cand_probs
        cand_probs = cand_probs * (mass_before < p.unsqueeze(1))

    # Sample a position in the reduced set, then translate to vocab ids.
    chosen = random_sample(cand_probs, generators)
    return cand_ids.gather(1, chosen.unsqueeze(1)).squeeze(1)


431
def flashinfer_sample(
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the logits using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it avoids sorting the logits tensor
    via rejection sampling.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.

    Args:
        logits: Logits to sample from (``forward_cuda`` passes them contiguous).
        k: Per-row top-k values, or ``None``.
        p: Per-row top-p values, or ``None``. At least one of ``k``/``p``
            must be set.
        generators: Accepted for signature parity but not used. Per-request
            generators are not supported here, and ``forward_cuda`` routes
            such batches to ``forward_native``.

    Returns:
        1-D tensor of sampled token ids.

    Raises:
        ImportError: If the installed FlashInfer is older than 0.2.3.
    """
    import flashinfer

    if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
        raise ImportError(
            "FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
        )

    assert not (k is None and p is None)
    if k is None:
        # Top-p only.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
            probs, p, deterministic=True
        )
    elif p is None:
        # Top-k only.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
            probs, k, deterministic=True
        )
    else:
        # Both top-k and top-p: this FlashInfer entry point takes the raw
        # logits, so no softmax is computed here.
        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
            logits, k, p, deterministic=True
        )

    return next_token_ids.view(-1)
478

479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
def lightop_sample(
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the logits using lightop.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it avoids sorting the logits tensor
    via rejection sampling.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.
    """
    assert not (k is None and p is None)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    if k is None:
        # Top-p only.
        next_token_ids = top_p_sampling_from_probs_lightop(
            probs, p, deterministic=True
        )
    elif p is None:
        # Top-k only.
        next_token_ids = top_k_sampling_from_probs_lightop(
            probs, k, deterministic=True
        )
    else:
        # Both top-k and top-p.
        next_token_ids = top_k_top_p_sampling_from_probs_lightop(
            probs, k, p, deterministic=True
        )

    return next_token_ids.view(-1)

519
520
521
522
523
524

def _to_tensor_scalar_tuple(x):
    if isinstance(x, torch.Tensor):
        return (x, 0)
    else:
        return (None, x)