from functools import cached_property
from importlib.util import find_spec
from typing import Dict, List, Optional, Tuple

import torch
import torch.jit

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeStochasticBaseSampler)

logger = init_logger(__name__)

# Prefer the FlashInfer rejection sampling kernel when the package is
# installed: it is a dedicated kernel rather than a chain of torch tensor
# operations, which lets it fuse operations, reduce memory I/O and thereby
# improve performance. `chain_speculative_sampling` is None when FlashInfer
# is not available.
if find_spec("flashinfer"):
    from flashinfer.sampling import chain_speculative_sampling
else:
    chain_speculative_sampling = None


class RejectionSampler(SpecDecodeStochasticBaseSampler):
    """Apply modified rejection sampling as described in "Accelerating Large
        Language Model Decoding with Speculative Sampling"
        https://arxiv.org/pdf/2302.01318.pdf.
    """

    def __init__(self,
                 disable_bonus_tokens: bool = True,
                 strict_mode: bool = False,
                 use_flashinfer: Optional[bool] = None):
        """Create a rejection sampler.

        Args:
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens would corrupt the KV cache of
                proposal methods that rely on a KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
            use_flashinfer: Whether to use the FlashInfer rejection sampling
                kernel. If None, it is enabled when
                ``envs.VLLM_USE_FLASHINFER_SAMPLER`` is truthy and FlashInfer's
                ``chain_speculative_sampling`` could be imported. This
                parameter is only intended for testing purposes.
        """
        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
                         strict_mode=strict_mode)

        if use_flashinfer is None:
            self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
                chain_speculative_sampling is not None)
        else:
            self.use_flashinfer = use_flashinfer

        if self.use_flashinfer:
            # Per the assertion message, the FlashInfer path always emits a
            # bonus token, so it cannot be combined with
            # disable_bonus_tokens=True.
            assert not disable_bonus_tokens, \
                "flashinfer will enable bonus token by default"
            logger.info("Use flashinfer for rejection sampling.")
        else:
            logger.info("Use pytorch for rejection sampling.")

    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
    ) -> torch.Tensor:
        """Sample token ids using rejection sampling. This accepts or rejects
        tokens proposed by the draft model using the probability of each token
        according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is guaranteed
        one correct token will be emitted.

        In the case where all draft tokens are accepted, a bonus token will be
        accepted as it's cheap to have the target model score this speculative
        sequence.

        Args:
            target_with_bonus_probs: The probability distribution
                over token ids given context according to the target model.
            shape = [batch_size, num_speculative_tokens + 1, vocab_size]

            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
            shape = [batch_size, num_bonus_tokens]

            draft_probs: The probability distribution over token ids given
                context according to the draft model.
            shape = [batch_size, num_speculative_tokens, vocab_size]

            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
            shape = [batch_size, num_speculative_tokens]

            seeded_seqs: Dict of batch row index to torch generator, for
                sequences using seeded generation.

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
            shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_with_bonus_probs,
                                           draft_token_ids, bonus_token_ids,
                                           draft_probs)

        batch_size, k, _ = draft_probs.shape

        # batch_size = 0 when all requests in the batch are
        # non_spec requests. In this case, output_token_ids is
        # just an empty tensor.
        if batch_size == 0:
            return torch.empty(0, k + 1, device=draft_probs.device, dtype=int)

        # If use Flashinfer chain_speculative_sampling kernel
        # for rejection sampling
        if self.use_flashinfer:
            # The kernel is handed k + 1 uniform samples per sequence.
            uniform_samples = self._create_uniform_samples(
                seeded_seqs, batch_size, k, draft_probs.device)
            output_token_ids, accepted_token_num, emitted_token_num \
                = chain_speculative_sampling(
                draft_probs, draft_token_ids, uniform_samples,
                target_with_bonus_probs)

            # num_emitted_tokens returned by flashinfer
            # does not include the bonus token
            # Flashinfer stops at the first token that violates
            # the condition p >= q and does not include recovery/bonus token.
            # Therefore, we need to add batch_size here.
            self.num_accepted_tokens += accepted_token_num.sum()
            self.num_emitted_tokens += emitted_token_num.sum() + batch_size
            self.num_draft_tokens += batch_size * k
        else:
            # The last position of target_with_bonus_probs belongs to the
            # bonus token and is not part of the per-draft-token comparison.
            accepted, recovered_token_ids = (
                self._batch_modified_rejection_sampling(
                    target_with_bonus_probs[:, :-1],
                    draft_probs,
                    draft_token_ids,
                    seeded_seqs,
                ))

            output_token_ids = self._create_output(
                accepted,
                recovered_token_ids,
                draft_token_ids,
                bonus_token_ids,
            )

        return output_token_ids

    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        seeded_seqs: Optional[Dict[int, torch.Generator]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor of which tokens in each sequence is accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """

        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids, seeded_seqs)

        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        # NOTE: the recovered_probs are overwritten by this method.
        recovered_token_ids = _multinomial(
            recovered_probs,
            num_samples=1,
            k=k,
            seeded_seqs=seeded_seqs or {},
        ).reshape(batch_size, k)

        return accepted, recovered_token_ids

    def _create_uniform_samples(self,
                                seeded_seqs: Optional[Dict[int,
                                                           torch.Generator]],
                                batch_size: int, k: int,
                                device: torch.device) -> torch.Tensor:
        """
        Generates a batch of uniform random samples, with optional seeding
        for specific sequences.

        This method creates a tensor of shape `(batch_size, k + 1)` filled
        with uniform random values in the range [0, 1). If `seeded_seqs`
        is provided, the rows corresponding to its keys are generated with
        the provided `torch.Generator` for reproducibility; all other rows
        are drawn from the global RNG.

        The returned tensor is allocated without an explicit dtype (torch's
        default dtype); seeded rows are drawn in `self.probs_dtype` and
        copied into it.

        Args:
            seeded_seqs : Optional[Dict[int, torch.Generator]]
                A dictionary mapping indices in the batch to
                `torch.Generator` objects. If `None` or empty, all samples
                are generated without a seed.
            batch_size : int
                The number of sequences to generate.
            k : int
                One less than the number of samples per sequence; each
                sequence receives `k + 1` samples.
            device : torch.device
                The device on which to allocate the tensor.

        Returns:
            uniform_rand : torch.Tensor
                A tensor of shape `(batch_size, k + 1)` containing uniform
                random values in the range [0, 1).
        """
        if not seeded_seqs:
            return torch.rand(batch_size, k + 1, device=device)

        uniform_rand = torch.empty(batch_size, k + 1, device=device)

        non_seeded_indices = []
        for idx in range(batch_size):
            generator = seeded_seqs.get(idx)
            if generator is None:
                non_seeded_indices.append(idx)
            else:
                uniform_rand[idx, :] = torch.rand(1,
                                                  k + 1,
                                                  dtype=self.probs_dtype,
                                                  device=device,
                                                  generator=generator)
        if non_seeded_indices:
            # Index assignment (not in-place sampling on an indexed copy),
            # so the values land in uniform_rand.
            uniform_rand[non_seeded_indices, :] = torch.rand(
                len(non_seeded_indices),
                k + 1,
                dtype=self.probs_dtype,
                device=device)
        return uniform_rand

    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        seeded_seqs: Optional[Dict[int, torch.Generator]],
    ) -> torch.Tensor:
        r"""Create bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
        :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
        to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        .. math::
            \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                           {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indices = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indices,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indices,
                                             draft_token_ids]

        # _create_uniform_samples returns (k - 1) + 1 = k samples per
        # sequence, i.e. one per draft token.
        uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size,
                                                    k - 1, target_probs.device)

        # If a selected draft probability is 0, the ratio is inf (capped to 1,
        # accepted) or NaN for 0/0 (the comparison below is False, rejected).
        capped_ratio = torch.minimum(
            selected_target_probs / selected_draft_probs,
            torch.full((1, ), 1, device=target_probs.device))
        accepted = uniform_rand < capped_ratio

        return accepted

    def _get_recovered_probs(
            self,
            target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
            draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
        :math:`x` given context :math:`x_1, \dots, x_n` according to the target
        model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
        according to the draft model:

        .. math::
            x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+

        where :math:`(f(x))_+` is defined as:

        .. math::
            (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}

        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
        of the draft, target, and recovered probability distributions.

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note: This batches operations on GPU and thus constructs the recovered
        distribution for all tokens, even if they are accepted. This causes
        division-by-zero errors, so we use self._smallest_positive_value to
        avoid that. This introduces some drift to the distribution.
        """
        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO(cade): Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # Clamping to the smallest positive value (instead of 0) keeps every
        # row sum strictly positive.
        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # Normalize over the vocab dimension.
        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1, keepdim=True)

        return recovered_probs

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to sample
        recovered tokens in the first rejection case.

        See _get_recovered_probs for more details

        Note that this isn't actually the smallest positive value representable
        by the probs dtype, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
382
    k: int,
383
    seeded_seqs: Dict[int, torch.Generator],
384
) -> torch.Tensor:
385

386
387
388
389
390
391
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
392
    q = torch.empty_like(probs)
393
    if not seeded_seqs:
394
395
        q.exponential_(1.0)
    else:
396
397
398
399
400
401
402
403
404
405
406
        non_seeded_indices: List[int] = []
        start = 0
        for idx in range(len(q) // k):
            end = start + k
            generator = seeded_seqs.get(idx)
            if generator is None:
                non_seeded_indices.extend(list(range(start, end)))
            else:
                q[start:end].exponential_(1.0, generator=generator)
            start = end
        q[non_seeded_indices].exponential_(1.0)
407

408
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)