# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.nn as nn

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# FlashInfer is optional: remember whether its sampling module imports so
# that TopKTopPSampler can pick a backend when it is constructed.
try:
    import flashinfer.sampling
except ImportError:
    is_flashinfer_available = False
else:
    is_flashinfer_available = True


class TopKTopPSampler(nn.Module):
    """Sample next-token ids from logits with optional top-k / top-p filters.

    ``__init__`` picks a backend once by rebinding ``self.forward`` to
    ``forward_native``, ``forward_cuda`` or ``forward_tpu``. The choice
    depends on the platform, on FlashInfer availability and version, and on
    environment flags. Every variant takes ``(logits, generators, k, p)`` and
    returns a 1-D tensor of sampled token ids.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda():
            if is_flashinfer_available:
                # Compare versions numerically. A plain string comparison
                # misorders releases such as "0.2.10" or "0.10.0" (both sort
                # below "0.2.3" lexicographically). An unparseable version is
                # treated as incompatible, so we fall back to the native path.
                flashinfer_version = self._parse_version(
                    flashinfer.__version__)
                if (flashinfer_version is None
                        or flashinfer_version >= (0, 2, 3)):
                    # FIXME(DefTruth): Currently, we have errors when using
                    # FlashInfer>=v0.2.3 for top-p & top-k sampling. As a
                    # workaround, we disable FlashInfer for top-p & top-k
                    # sampling by default while FlashInfer>=v0.2.3.
                    # The sampling API removes the success return value
                    # of all sampling API, which is not compatible with
                    # earlier design.
                    # https://github.com/flashinfer-ai/flashinfer/releases/
                    # tag/v0.2.3
                    logger.info(
                        "Currently, FlashInfer top-p & top-k sampling sampler "
                        "is disabled because FlashInfer>=v0.2.3 is not "
                        "backward compatible. Falling back to the PyTorch-"
                        "native implementation of top-p & top-k sampling.")
                    self.forward = self.forward_native
                elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        elif current_platform.is_tpu():
            if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
                logger.warning(
                    "TPU-specific optimization for top-k & top-p sampling are "
                    "disabled, falling back to PyTorch-native implementation "
                    "which could be very slow.")
                self.forward = self.forward_native
            else:
                self.forward = self.forward_tpu
        else:
            self.forward = self.forward_native

    @staticmethod
    def _parse_version(version: str) -> Optional[tuple[int, ...]]:
        """Return the leading numeric release components of ``version``.

        For example ``"0.2.10+cu124"`` -> ``(0, 2, 10)`` and ``"0.2.3rc1"``
        -> ``(0, 2, 3)``. Returns ``None`` when ``version`` does not start
        with a digit. The resulting tuples compare in release order.
        """
        import re
        match = re.match(r"\d+(?:\.\d+)*", version)
        if match is None:
            return None
        return tuple(int(part) for part in match.group().split("."))

    @staticmethod
    def _apply_top_k_only(logits: torch.Tensor,
                          k: torch.Tensor) -> torch.Tensor:
        """Mask, in place, every logit below its row's k-th largest value.

        ``k`` holds one value per row of ``logits``. ``torch.topk`` needs a
        single Python int for the whole batch, so this takes the top
        ``max(k)`` values once and gathers each row's own k-th value as that
        row's threshold. Entries tied with the k-th value are kept, matching
        ``apply_top_k_top_p``.

        Args:
            logits: 2-D logits, one row per request; modified in place.
            k: Per-row top-k values (each expected to be >= 1).

        Returns:
            ``logits`` (the same tensor) after masking.
        """
        # int() reads max(k) back to the host: topk needs a concrete width.
        max_top_k = int(k.max())
        top_values = logits.topk(max_top_k, dim=-1).values
        # Column (k - 1) of the descending top values is each row's k-th max.
        kth_index = (k.to(torch.long) - 1).unsqueeze(dim=-1)
        threshold = top_values.gather(-1, kth_index)
        return logits.masked_fill_(logits < threshold, float("-inf"))

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """PyTorch-native implementation of top-k and top-p sampling.

        Args:
            logits: Logits with one row per request.
            generators: Map from row index to that request's own generator.
            k: Optional per-row top-k values.
            p: Optional per-row top-p values.

        Returns:
            Sampled token ids, flattened to 1-D.
        """
        logits = apply_top_k_top_p(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """More optimized implementation for top-k and top-p sampling.

        Uses FlashInfer when filtering is requested. Without filtering, it
        uses the sync-free ``random_sample`` path.
        """
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if k is None and p is None:
            # We prefer `random_sample` over `flashinfer_sample` when sorting is
            # not needed. This is because `random_sample` does not require
            # CPU-GPU synchronization while `flashinfer_sample` does.
            return random_sample(probs, generators)
        return flashinfer_sample(probs, k, p, generators)

    def forward_tpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """TPU implementation of top-k and top-p sampling.

        Only the top-k-only case is currently filtered. When ``p`` is given,
        neither ``k`` nor ``p`` is applied yet (see the TODO below).
        ``logits`` may be modified in place.
        """
        # If only top-k is specified, use pytorch's builtin topk op. This leads
        # to significant speed up on TPU compared to using apply_top_k_top_p.
        if k is not None and p is None:
            logits = self._apply_top_k_only(logits, k)
        # TODO Placeholder for TPU optimized topp kernel. Until it exists,
        # requests that set p are sampled without any top-k/top-p filtering.
        # logits = apply_top_k_top_p(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Mask logits outside each row's top-k set and/or top-p nucleus.

    Each row is sorted ascending once. Masking happens in sorted order, and
    the result is scattered back to the original column order. The input
    tensor is not modified. Sorting can be slow for large batches.

    Args:
        logits: 2-D logits, one row per request.
        k: Optional per-row top-k values (shape: B).
        p: Optional per-row top-p values (shape: B).

    Returns:
        ``logits`` itself when both ``k`` and ``p`` are None. Otherwise a new
        tensor with filtered-out entries set to -inf.
    """
    if k is None and p is None:
        return logits

    sorted_logits, sort_order = logits.sort(dim=-1, descending=False)
    neg_inf = -float("inf")

    if k is not None:
        # Ascending order puts the k-th largest value at column (V - k).
        kth_col = (sorted_logits.size(1) - k.to(torch.long)).unsqueeze(dim=1)
        kth_value = sorted_logits.gather(1, kth_col)
        sorted_logits.masked_fill_(sorted_logits < kth_value, neg_inf)

    if p is not None:
        cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        drop = cumulative <= 1 - p.unsqueeze(dim=1)
        # The largest entry (last column) always survives.
        drop[:, -1] = False
        sorted_logits.masked_fill_(drop, neg_inf)

    # Undo the sort: write each sorted value back to its original column.
    return sorted_logits.scatter(dim=-1, index=sort_order, src=sorted_logits)


def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Draw one token id per row of ``probs`` without CPU-GPU synchronization.

    Each probability is divided by an independent Exp(1) draw and the row
    argmax is taken. This selects index i with probability probs[i], and it
    avoids ``torch.multinomial``, which synchronizes with the CPU. ``probs``
    is overwritten in place.

    Args:
        probs: Probabilities, one row per request.
        generators: Map from row index to a request-specific generator; rows
            not listed use the default RNG.

    Returns:
        Sampled token ids, flattened to 1-D.
    """
    noise = torch.empty_like(probs)
    num_rows = probs.shape[0]
    # NOTE(woosuk): Most requests have no seed, so fill every row from the
    # default RNG in one batched call (skipped when every row is seeded),
    # then overwrite the rows that carry their own generator.
    if len(generators) != num_rows:
        noise.exponential_()
    # TODO(woosuk): Per-row draws are serial and can be slow; optimize.
    for row, gen in generators.items():
        noise[row].exponential_(generator=gen)
    return probs.div_(noise).argmax(dim=-1).view(-1)


def flashinfer_sample(
    probs: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample token ids from ``probs`` with FlashInfer's top-k/top-p kernels.

    The result is statistically equivalent to ``random_sample`` applied to
    top-k/top-p-filtered probabilities, but it is not sample-for-sample
    identical. It avoids sorting by using FlashInfer's rejection sampling.
    If any row reports failure, all rows are re-sampled from renormalized
    probabilities.

    NOTE: Checking the success flags synchronizes CPU and GPU, unlike
    ``random_sample``. Call this at the end of the forward pass to minimize
    the synchronization overhead.

    Args:
        probs: Probabilities, one row per request.
        k: Optional per-row top-k values.
        p: Optional per-row top-p values. At least one of ``k``/``p`` must be
            given.
        generators: Map from row index to a request-specific generator.

    Returns:
        Sampled token ids, flattened to 1-D.
    """
    assert k is not None or p is not None

    # Pre-draw a (rounds, batch) table of uniforms: the default RNG fills
    # every column unless all rows are seeded, then seeded rows overwrite
    # their own column.
    num_rounds = 32
    num_rows = probs.shape[0]
    uniforms = torch.empty((num_rounds, num_rows), device=probs.device)
    if len(generators) != num_rows:
        uniforms.uniform_()
    for row, gen in generators.items():
        uniforms[:, row].uniform_(generator=gen)

    sampling = flashinfer.sampling
    if k is None:
        # Top-p only.
        token_ids, success = sampling.top_p_sampling_from_probs(
            probs, uniforms, p, deterministic=True)
    elif p is None:
        # Top-k only.
        token_ids, success = sampling.top_k_sampling_from_probs(
            probs, uniforms, k, deterministic=True)
    else:
        # Both top-k and top-p.
        token_ids, success = sampling.top_k_top_p_sampling_from_probs(
            probs, uniforms, k, p, deterministic=True)

    # NOTE: CPU-GPU synchronization happens here.
    if not success.all():
        if k is not None:
            probs = sampling.top_k_renorm_prob(probs, k)
        if p is not None:
            probs = sampling.top_p_renorm_prob(probs, p)
        token_ids = sampling.sampling_from_probs(
            probs, uniforms[0], deterministic=True)
    return token_ids.view(-1)