"""Tests for rejection sampling."""
from typing import List, Tuple

4
import pytest
5
6
7
8
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
9
from vllm.model_executor.utils import set_random_seed
10

# Parametrize over a single device on a host with exactly one GPU; any other
# device count (including zero) parametrizes over both cuda:0 and cuda:1.
CUDA_DEVICES = (["cuda:0"] if torch.cuda.device_count() == 1 else
                ["cuda:0", "cuda:1"])


def mock_causal_accepted_tensor(
        k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
    """Generate an "accepted" tensor which should yield causally-accepted tokens
    up to last accepted indices.

    For each row ``b``, positions ``0..last_accepted_indices[b]`` are True and
    position ``last_accepted_indices[b] + 1`` (when it is ``< k``) is False, so
    the causal prefix ends exactly at the requested index. Tokens after
    last_accepted_indices+1 may also be accepted (each with probability 1/2),
    although they will not be causally accepted.

    Args:
        k: Number of draft positions per row.
        last_accepted_indices: Integer tensor of shape ``(batch_size,)``.
            ``-1`` means no position is accepted; ``k - 1`` means all are.

    Returns:
        A bool tensor of shape ``(batch_size, k)`` on the same device as
        ``last_accepted_indices``.
    """
    batch_size = last_accepted_indices.shape[0]
    device = last_accepted_indices.device

    # (batch_size, k) grid of positions and a (batch_size, 1) per-row cutoff;
    # comparisons broadcast the cutoff across each row.
    positions = torch.arange(k, device=device).expand(batch_size, k)
    last_accepted = last_accepted_indices.unsqueeze(-1)

    # The comparison allocates a fresh tensor, so the in-place masked
    # assignment below does not write through the expanded `positions` view.
    accepted = positions <= last_accepted

    # Sprinkle accepted values after the contiguous initial accepted values.
    # This replicates the behavior of rejection sampling, which may "accept"
    # a token that cannot be accepted because of causality. Position
    # last_accepted + 1 is never a candidate, so it always stays rejected.
    sprinkle_candidates = positions > last_accepted + 1
    sprinkle = torch.rand(batch_size, k, device=device) > 0.5
    accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
    return accepted


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize(
    "which_tokens_accepted",
    ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
                               device: str, use_flashinfer: bool):
    """Verify the output has correct format given predetermined accepted matrix.

    ``_create_output`` is called directly with a mocked ``accepted`` mask. The
    test then checks the layout of the ``(batch_size, k + 1)`` output: the
    accepted draft tokens come first, then the first recovered (or bonus)
    token, then ``-1`` padding.
    """
    set_random_seed(seed)
    torch.set_default_device(device)

    batch_size = 10
    k = 5
    vocab_size = 3000

    if which_tokens_accepted == "all_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "no_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "some_tokens_accepted":
        last_accepted_indices = torch.randint(low=-1,
                                              high=k,
                                              size=(batch_size, ))
        accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
    else:
        raise AssertionError(
            f"Unknown which_tokens_accepted: {which_tokens_accepted!r}")

    recovered_token_ids = torch.randint(low=0,
                                        high=vocab_size,
                                        size=(batch_size, k),
                                        dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)

    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)
    output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
        accepted,
        recovered_token_ids,
        draft_token_ids,
        bonus_token_ids,
    )

    expected_bonus_token_ids = bonus_token_ids.clone()

    if which_tokens_accepted == "all_tokens_accepted":
        # Expect all tokens to be equal to draft tokens.
        assert torch.equal(output_token_ids[:, :-1], draft_token_ids)

        # Expect all bonus tokens to be included.
        assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
    elif which_tokens_accepted == "no_tokens_accepted":
        # Expect first token to be equal to recovered tokens.
        assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])

        # Expect everything else to be -1.
        assert torch.equal(output_token_ids[:, 1:],
                           torch.full_like(output_token_ids[:, 1:], -1))
    elif which_tokens_accepted == "some_tokens_accepted":
        # Column j of this (batch_size, k + 1) tensor is the token expected at
        # output position j if position j is the first rejected one.
        recovered_plus_bonus = torch.cat(
            (recovered_token_ids, expected_bonus_token_ids), dim=-1)

        # Assert first rejected token is a recovered token or bonus token.
        assert torch.equal(
            recovered_plus_bonus[torch.arange(0, batch_size),
                                 last_accepted_indices + 1],
            output_token_ids[torch.arange(0, batch_size),
                             last_accepted_indices + 1])

        # Assert every subsequent token is -1.
        subsequent_mask = torch.arange(0, k + 1).expand(
            batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
        assert torch.all(output_token_ids[subsequent_mask] == -1)


@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                    device: str, use_flashinfer: bool):
    """Smoke test: the sampler accepts many (k, vocab_size, batch_size) shapes.

    Inputs are raw ``torch.rand`` values (not normalized), so only the absence
    of an exception is checked, not the sampled output.
    """
    torch.set_default_device(device)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                      draft_token_ids)


@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                   frac_seeded: float, n_rep: int, device: str,
                                   use_flashinfer: bool):
    """Verify that seeded rows produce identical output on every repetition.

    Each row is seeded with probability ``frac_seeded``. A seeded row gets a
    ``torch.Generator`` seeded with its row index, recreated on every
    repetition. Only seeded rows must match across the ``n_rep`` runs;
    unseeded rows are not checked.
    """
    torch.set_default_device(device)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
    # Resolve the mask to Python row indices once. This avoids calling bool()
    # on a device-tensor element for every row on every repetition.
    seeded_rows = [
        row for row, is_seeded in enumerate(seeded_mask.tolist()) if is_seeded
    ]

    results = []
    for _ in range(n_rep):
        # Fresh generators each repetition so every run starts from the same
        # per-row seed.
        seeded_seqs = {
            row: torch.Generator(device=device).manual_seed(row)
            for row in seeded_rows
        }
        results.append(
            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                              draft_token_ids, seeded_seqs))

    for row in seeded_rows:
        for rep in range(1, n_rep):
            assert torch.equal(results[rep][row], results[0][row])


@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
                                       batch_size: int, device: str):
    """Test that the flashinfer and non-flashinfer backends report the same
    metrics.

    Both backends run on identical inputs with every row seeded identically.
    Their ``num_accepted_tokens``, ``num_emitted_tokens`` and
    ``num_draft_tokens`` counters must then agree.
    """
    torch.set_default_device(device)
    torch.manual_seed(0)
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    num_accepted_tokens = []
    num_emitted_tokens = []
    num_draft_tokens = []

    def get_seeded_seqs():
        # New generators per backend so both start from identical seeds.
        return {
            i: torch.Generator(device=device).manual_seed(i)
            for i in range(batch_size)
        }

    for use_flashinfer in [True, False]:
        rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
        rejection_sampler.init_gpu_tensors(device=device)
        # We use seeded sequences to ensure the same tokens are accepted
        # for both flashinfer and nonflashinfer backends.
        seeded_seqs = get_seeded_seqs()
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids, seeded_seqs)
        num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
        num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
        num_draft_tokens.append(rejection_sampler.num_draft_tokens)

    assert num_accepted_tokens[0] == num_accepted_tokens[1]
    assert num_emitted_tokens[0] == num_emitted_tokens[1]
    assert num_draft_tokens[0] == num_draft_tokens[1]


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str, device: str,
                               use_flashinfer: bool):
    """Verify a strict-mode sampler rejects token ids outside [0, vocab_size).

    One bonus or draft token id is overwritten with ``vocab_size + 1``
    ("above") or ``-1`` ("below"). The sampler, built with
    ``strict_mode=True``, is expected to raise ``AssertionError``.
    """
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)

    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
                                         strict_mode=True)
    rejection_sampler.init_gpu_tensors(device=device)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError(f"Unknown which_token_ids: {which_token_ids!r}")

    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError("Unknown above_or_below_vocab_range: "
                             f"{above_or_below_vocab_range!r}")

    # `oob_token_ids` aliases one of the input tensors, so this in-place write
    # is what the sampler call below observes.
    oob_token_ids[0][0] = rogue_token_id

    with pytest.raises(AssertionError):
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids)


@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
        seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a set of randomly generated reference distributions. We expect
    the distance between the observed distribution and the target
    distribution to improve much more than the distance between the
    observed distribution and the reference distributions.

    When draft_and_target_probs_equal=True, the draft and target
    probabilities are exactly equal. Rejection sampling should
    still work without any NaNs or exceptions.
    """
    # Probabilities are generated on CPU; the helper moves the sampler inputs
    # to CUDA itself.
    torch.set_default_device("cpu")
    set_random_seed(seed)
    helper = _CorrectnessTestHelper(
        vocab_size=10,
        rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
    )

    draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
        draft_and_target_probs_equal)

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference: List[float] = []
    distance_wrt_target: List[float] = []

    for num_samples in sample_sizes:
        (reference_vs_rejsample_dist,
         target_vs_rejsample_dist) = helper.run_and_compare_distributions(
             draft_probs,
             target_probs,
             reference_probs,
             num_samples,
         )

        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

        # Running ratios, computed here only for the progress printout.
        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
              f"{reference_vs_rejsample_dist=:.05f}")
        print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
              f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    # The first/last shrink ratio of the target distance must exceed the
    # reference distance's shrink ratio by at least this factor.
    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target >
            relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


def get_ratio_first_to_last(elements: List[float]) -> float:
    """Return the first element divided by the last element.

    Raises IndexError for an empty list; a single-element list yields 1.0
    when that element is non-zero.
    """
    first, last = elements[0], elements[-1]
    return first / last


class _CorrectnessTestHelper:
    """Class that packages together logic required for the unit-level
    rejection sampling correctness test.
    """

    def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
        self.rejection_sampler = rejection_sampler
        self.vocab_size = vocab_size
        self.vocab_range = (0, vocab_size)

404
        self.rejection_sampler.init_gpu_tensors(device=0)
405
406
407
408
409
410
411
412
413
414
415

        # Keep test simple, use k=1
        self.k = 1

        # Bonus tokens not used, but rejection sampler requires
        # correct shape.
        self.num_bonus_tokens = 1

    def generate_probs_for_test(
        self, draft_and_target_probs_equal: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
416
417
418
419
        draft_probs, target_probs = (F.softmax(
            torch.rand(self.vocab_size, dtype=torch.float32),
            dim=-1,
        ) for _ in range(2))
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460

        num_reference_probs = 100
        reference_probs = F.softmax(
            torch.rand(num_reference_probs,
                       self.vocab_size,
                       dtype=torch.float32),
            dim=-1,
        )

        if draft_and_target_probs_equal:
            target_probs = draft_probs.clone()

        return draft_probs, target_probs, reference_probs

    def run_and_compare_distributions(self, draft_probs: torch.Tensor,
                                      target_probs: torch.Tensor,
                                      reference_probs: torch.Tensor,
                                      num_samples: int) -> Tuple[float, float]:
        # Sample using rejection sampling.
        rej_sample_probs = self._estimate_rejection_sampling_pdf(
            draft_probs, target_probs, num_samples)

        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs,
            rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()

        return reference_vs_rejsample_dist, target_vs_rejsample_dist

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        # Repeat draft probs num_samples times.
        draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
            num_samples, 1, 1)

461
        # Repeat target probs num_samples * (k + 1) times.
462
463
        # Rejection sampler requires bonus token probs, but they aren't used.
        target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
464
            num_samples, self.k + 1, 1)
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480

        # Randomly sample draft token ids from draft probs.
        draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                            num_samples=1,
                                            replacement=True).reshape(
                                                num_samples, self.k)

        # Bonus tokens not used but required.
        bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
                                      dtype=torch.int64,
                                      device="cuda").repeat(num_samples, 1)

        # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                  bonus_token_ids.to("cuda"),
                                                  draft_probs.to("cuda"),
481
                                                  draft_token_ids.to("cuda"))
482
483
484
485
486
487
488
489
490
491
492
493

        # Remove bonus tokens
        output_token_ids = output_token_ids[:, :-1].flatten()

        # Estimate probability density function
        hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                   device="cpu"),
                               bins=self.vocab_size,
                               range=self.vocab_range,
                               density=True)

        return hist.hist