"vllm/vscode:/vscode.git/clone" did not exist on "10398b4706ee71d0bddc32c1d33b11e73df12a27"
rejection_sampler.py 12.7 KB
Newer Older
1
from functools import cached_property
2
from typing import List, Optional, Tuple
3
4
5
6

import torch
import torch.jit

7
from vllm.model_executor.layers.spec_decode_base_sampler import (
8
    SpecDecodeStochasticBaseSampler)
9

10

11
class RejectionSampler(SpecDecodeStochasticBaseSampler):
    """Apply modified rejection sampling as described in "Accelerating Large
        Language Model Decoding with Speculative Sampling"
        https://arxiv.org/pdf/2302.01318.pdf.
    """

    def __init__(self,
                 disable_bonus_tokens: bool = True,
                 strict_mode: bool = False):
        """Create a rejection sampler.

        Args:
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens would corrupt the KV cache of
                proposal methods that rely on a KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
                         strict_mode=strict_mode)

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        generators: List[Optional[torch.Generator]],
    ) -> torch.Tensor:
        """Sample token ids using rejection sampling. This accepts or rejects
        tokens proposed by the draft model using the probability of each token
        according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is guaranteed
        one correct token will be emitted.

        In the case where all draft tokens are accepted, a bonus token will be
        accepted as it's cheap to have the target model score this speculative
        sequence.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
            shape = [batch_size, num_speculative_tokens, vocab_size]

            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
            shape = [batch_size, num_bonus_tokens]

            draft_probs: The probability distribution over token ids given
                context according to the draft model.
            shape = [batch_size, num_speculative_tokens, vocab_size]

            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
            shape = [batch_size, num_speculative_tokens]

            generators: One entry per sequence in the batch. When an entry is
                a ``torch.Generator``, that sequence's random draws (the
                acceptance test and the recovered-token sample) come from it;
                ``None`` means the global RNG is used for that sequence.

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
            shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)

        accepted, recovered_token_ids = (
            self._batch_modified_rejection_sampling(
                target_probs,
                draft_probs,
                draft_token_ids,
                generators,
            ))

        output_token_ids = self._create_output(
            accepted,
            recovered_token_ids,
            draft_token_ids,
            bonus_token_ids,
        )

        return output_token_ids

    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        generators: List[Optional[torch.Generator]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor of which tokens in each sequence is accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """
        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids, generators)

        # Flatten to one row per (sequence, speculative position) so each row
        # can be sampled independently.
        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        # Row indices are expanded by k so that row r belongs to sequence r//k.
        seed_indices, non_seed_indices = self._split_batch_by_seeded(
            generators, k=k)

        # NOTE: _multinomial overwrites recovered_probs in place; this is safe
        # because recovered_probs is a fresh tensor computed just above.
        recovered_token_ids = _multinomial(
            recovered_probs,
            num_samples=1,
            k=k,
            generators=generators,
            seed_indices=seed_indices,
            # this arg is unused when None but torch.jit requires a list
            non_seed_indices=non_seed_indices or [],
        ).reshape(batch_size, k)

        return accepted, recovered_token_ids

    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        generators: List[Optional[torch.Generator]],
    ) -> torch.Tensor:
        r"""Create bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
        :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
        to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        .. math::
            \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                           {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        spec_indices = torch.arange(k, device=target_probs.device)

        # Gather each model's probability of the token the draft proposed.
        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, spec_indices,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, spec_indices,
                                             draft_token_ids]

        # k defaults to 1 here, so the indices are plain batch indices.
        seed_indices, non_seed_indices = self._split_batch_by_seeded(
            generators)

        if len(seed_indices) == 0:
            uniform_rand = torch.rand_like(selected_target_probs)
        else:
            uniform_rand = torch.empty_like(selected_target_probs)

            # Seeded sequences draw from their own generator so results are
            # reproducible per sequence.
            for idx in seed_indices:
                uniform_rand[idx, :] = torch.rand(1,
                                                  k,
                                                  dtype=self.probs_dtype,
                                                  device=target_probs.device,
                                                  generator=generators[idx])

            if non_seed_indices:
                uniform_rand[non_seed_indices, :] = torch.rand(
                    len(non_seed_indices),
                    k,
                    dtype=self.probs_dtype,
                    device=target_probs.device)

        # min(1, q/p) without allocating a constant tensor on every call.
        capped_ratio = (selected_target_probs /
                        selected_draft_probs).clamp(max=1.0)
        accepted = uniform_rand < capped_ratio

        return accepted

    def _get_recovered_probs(
            self,
            target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
            draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
        :math:`x` given context :math:`x_1, \dots, x_n` according to the target
        model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
        according to the draft model:

        .. math::
            x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+

        where :math:`(f(x))_+` is defined as:

        .. math::
            (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}

        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
        of the draft, target, and recovered probability distributions.

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note: This batches operations on GPU and thus constructs the recovered
        distribution for all tokens, even if they are accepted. This causes
        division-by-zero errors, so we use self._smallest_positive_value to
        avoid that. This introduces some drift to the distribution.
        """
        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO(cade): Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # Clamp to a tiny positive value (instead of 0) so every row sums to
        # a strictly positive number below.
        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # Normalize over the vocab dimension.
        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1, keepdim=True)

        return recovered_probs

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to sample
        recovered tokens in the first rejection case.

        See _get_recovered_probs for more details

        Note that this isn't actually the smallest positive value representable
        by the probs dtype, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny

    @staticmethod
    def _split_batch_by_seeded(
        generators: List[Optional[torch.Generator]],
        k: int = 1,
    ) -> Tuple[List[int], Optional[List[int]]]:
        """Partition row indices by whether their sequence has a generator.

        Sequence ``i`` owns the ``k`` consecutive rows ``[k*i, k*(i+1))``.

        Returns:
            ``(seed_indices, non_seed_indices)``. When no sequence has a
            generator, this is ``([], None)`` so callers can take a fast path
            that draws everything from the global RNG.
        """
        if all(generator is None for generator in generators):
            seed_indices: List[int] = []
            non_seed_indices: Optional[List[int]] = None
        else:
            seed_indices, non_seed_indices = [], []
            for i, generator in enumerate(generators):
                rows = range(k * i, k * (i + 1))
                if generator is None:
                    non_seed_indices.extend(rows)
                else:
                    seed_indices.extend(rows)

        return seed_indices, non_seed_indices


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
306
307
308
309
    k: int,
    generators: List[Optional[torch.Generator]],
    seed_indices: List[int],
    non_seed_indices: List[int],
310
) -> torch.Tensor:
311

312
313
314
315
316
317
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
318
319
320
321
322
323
324
325
326

    q = torch.empty_like(probs)
    if len(seed_indices) == 0:
        q.exponential_(1.0)
    else:
        q[non_seed_indices].exponential_(1.0)
        for idx in seed_indices:
            q[idx].exponential_(1.0, generator=generators[idx // k])

327
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)