topk_topp_sampler.py 9.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional
5
6
7

import torch
import torch.nn as nn
8
from packaging import version
9
10

from vllm import envs
11
from vllm.config import LogprobsMode
12
13
14
15
16
17
18
19
20
21
22
23
24
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Probe for the optional FlashInfer dependency. On success the module name
# `flashinfer` is bound at module level and the flag is True; on ImportError
# the flag is False and only the PyTorch-native sampling path is usable.
try:
    import flashinfer.sampling
except ImportError:
    is_flashinfer_available = False
else:
    is_flashinfer_available = True


class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    At construction time ``self.forward`` is bound to either
    ``forward_native`` (PyTorch implementation) or ``forward_cuda``
    (FlashInfer-backed implementation), depending on the logprobs mode, the
    platform, FlashInfer availability/version and the
    ``VLLM_USE_FLASHINFER_SAMPLER`` environment setting.

    Implementations may update the logits tensor in-place.
    """

    def __init__(
            self,
            logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None:
        """Select the sampling backend.

        Args:
            logprobs_mode: Which (if any) intermediate logits/logprobs the
                forward pass must return alongside the sampled tokens.
        """
        super().__init__()
        self.logprobs_mode = logprobs_mode
        self.apply_top_k_top_p = apply_top_k_top_p
        # Start on the PyTorch-native path; switch to FlashInfer only if every
        # guard below passes.
        self.forward = self.forward_native

        # The FlashInfer optimization does not apply if intermediate
        # logprobs/logits after top_k/top_p need to be returned.
        if logprobs_mode in (LogprobsMode.PROCESSED_LOGITS,
                             LogprobsMode.PROCESSED_LOGPROBS):
            return
        if not current_platform.is_cuda():
            return

        if not is_flashinfer_available:
            logger.warning_once(
                "FlashInfer is not available. Falling back to the PyTorch-"
                "native implementation of top-p & top-k sampling. For the "
                "best performance, please install FlashInfer.")
            return

        if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
            logger.warning_once(
                "FlashInfer version >= 0.2.3 required. "
                "Falling back to default sampling implementation.")
            return

        # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
        # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
        # default it is unused). For backward compatibility, we set
        # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
        # interpret it differently in V0 and V1 samplers: In V0,
        # None means False, while in V1, None means True. Hence only an
        # explicit False disables FlashInfer here.
        if envs.VLLM_USE_FLASHINFER_SAMPLER is False:
            logger.warning_once(
                "FlashInfer is available, but it is not enabled. "
                "Falling back to the PyTorch-native implementation of "
                "top-p & top-k sampling. For the best performance, "
                "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
            return

        logger.info_once("Using FlashInfer for top-p & top-k sampling.")
        self.forward = self.forward_cuda

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        PyTorch-native implementation of top-k and top-p sampling.

        The logits tensor may be updated in-place.

        Args:
            logits: Logits to filter and sample from.
            generators: Mapping from row index to a per-request RNG.
            k: Optional per-row top-k values.
            p: Optional per-row top-p values.

        Returns:
            A pair ``(token_ids, processed)``. ``processed`` is the filtered
            logits in PROCESSED_LOGITS mode, their float32 log-softmax in
            PROCESSED_LOGPROBS mode, and ``None`` otherwise.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
            logits_to_return = logits
        elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
        # softmax allocates a new tensor, so the in-place division inside
        # random_sample does not touch `logits_to_return`.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators), logits_to_return

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """More optimized implementation for top-k and top-p sampling.

        Uses FlashInfer kernels when filtering is requested and no
        per-request generators are present. Otherwise it delegates to
        ``forward_native``. The second element of the returned pair is
        always ``None`` on the FlashInfer path.
        """
        # We prefer `random_sample` over `flashinfer_sample` when sorting is
        # not needed. This is because `random_sample` does not require
        # CPU-GPU synchronization while `flashinfer_sample` does.
        if (k is None and p is None) or generators:
            if generators:
                logger.warning_once("FlashInfer 0.2.3+ does not support "
                                    "per-request generators. Falling back to "
                                    "PyTorch-native implementation.")
            return self.forward_native(logits, generators, k, p)

        # __init__ never selects this method in the PROCESSED_* modes; this
        # guards against it being called directly.
        assert self.logprobs_mode not in (
            LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS
        ), "FlashInfer does not support returning logits/logprobs"
        # flashinfer sampling functions expect contiguous logits.
        # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
        # because of slicing operation in logits_processor.
        return flashinfer_sample(logits.contiguous(), k, p, generators), None
124

125
126
127

def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    Masked-out entries are set to ``-inf``. Every kept entry retains its
    original value.

    If top-p is used, this function will sort the logits tensor, which can
    be slow for large batches. The top-k-only case is delegated to
    ``apply_top_k_only``, which avoids a full sort of the vocab.

    The logits tensor may be updated in-place.

    Args:
        logits: 2-D logits, one row per request (``[batch, vocab]``).
        k: Optional 1-D per-row top-k values. A value equal to the vocab size
            keeps every token in that row.
        p: Optional 1-D per-row top-p values.

    Returns:
        ``logits`` itself when both ``k`` and ``p`` are None. Otherwise the
        masked logits, which is a new tensor whenever ``p`` is given.
    """
    if p is None:
        if k is None:
            return logits
        # Avoid sorting vocab for top-k only case.
        return apply_top_k_only(logits, k)

    # Ascending sort: the largest logits end up at the end of each row.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Apply top-k. In an ascending row of length V, the k-th largest
        # value sits at index V - k; everything strictly below it is masked.
        kth_index = logits_sort.size(1) - k.to(torch.long)  # shape: B
        kth_value = logits_sort.gather(1, kth_index.unsqueeze(dim=1))
        logits_sort.masked_fill_(logits_sort < kth_value, -float("inf"))

    # Apply top-p (p is known to be non-None at this point).
    probs_sort = logits_sort.softmax(dim=-1)
    # Cumulative sum written in place over probs_sort.
    probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
    # Mask the low-probability prefix whose cumulative mass is <= 1 - p.
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    # Always keep at least one token: the largest logit, in the last column.
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Undo the sort: scatter each sorted value back to its original column.
    return logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)


169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Apply top-k mask to the logits.

    This implementation doesn't involve sorting the entire vocab. It takes
    ``topk`` with the largest requested k in the batch, gathers each row's
    own k-th largest value as a threshold, and masks every logit strictly
    below that threshold to ``-inf``. Ties with the threshold are kept.

    The logits tensor may be updated in-place.

    Args:
        logits: 2-D logits, ``[batch, vocab]``. Modified in place.
        k: 1-D per-row top-k values. Rows whose k equals the vocab size are
            left unmasked. The caller's ``k`` tensor is not modified.

    Returns:
        ``logits`` (the same tensor object), after masking.
    """
    no_top_k_mask = k == logits.shape[1]
    # Set non-top-k rows to 1 so that we can gather. masked_fill (not the
    # in-place variant) returns a copy, so the caller's `k` is untouched.
    k = k.masked_fill(no_top_k_mask, 1)
    max_top_k = k.max()
    # topk.values tensor has shape [batch_size, max_top_k].
    # Convert top k to 0-based index in range [0, max_top_k).
    k_index = k.sub_(1).unsqueeze(1)
    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
    # Handle non-topk rows: a -inf threshold masks nothing.
    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < top_k_mask, -float("inf"))
    return logits


194
195
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.

    Each row draws noise ``q ~ Exp(1)`` and returns ``argmax(probs / q)``.
    This picks index ``i`` with probability proportional to ``probs[i]``
    (the exponential-race form of categorical sampling).

    Args:
        probs: 2-D probabilities, one row per request. Overwritten in place
            by the division.
        generators: Mapping from row index to a per-request RNG. Those rows
            draw their noise from their own generator; all other rows use the
            default RNG.

    Returns:
        1-D tensor of sampled indices, one per row.
    """
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        q.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)


def flashinfer_sample(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the logits using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it uses rejection sampling
    instead of sorting the logits tensor.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.

    Args:
        logits: Logits to sample from.
        k: Optional per-row top-k values.
        p: Optional per-row top-p values. At least one of ``k`` / ``p`` must
            be given.
        generators: Accepted for signature parity with `random_sample` but
            not used here. `forward_cuda` falls back to the native path
            whenever this mapping is non-empty.

    Returns:
        1-D tensor of sampled token ids.
    """
    assert not (k is None and p is None)
    if k is None:
        # Top-p only.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
            probs, p, deterministic=True)
    elif p is None:
        # Top-k only.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
            probs, k, deterministic=True)
    else:
        # Both top-k and top-p.
        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
            logits, k, p, deterministic=True)

    return next_token_ids.view(-1)