# test_rejection_sampler.py 22.9 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
"""Tests for rejection sampling."""

4
import pytest
5
6
7
8
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
9
from vllm.model_executor.utils import set_random_seed
10

11
12
13
14
# Device ids the GPU tests are parametrized over: just "cuda:0" when exactly
# one GPU is visible, otherwise the first two ordinals ("cuda:0", "cuda:1").
CUDA_DEVICES = [
    f"cuda:{ordinal}"
    for ordinal in range(2 if torch.cuda.device_count() != 1 else 1)
]

15
16
17
18
19
20
21
22
23
24
25

def mock_causal_accepted_tensor(
        k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
    """Generate an "accepted" tensor which should yield causally-accepted tokens
    up to last accepted indices.

    Tokens after last_accepted_indices+1 may also be accepted, although they
    will not be causally accepted.
    """
    batch_size = last_accepted_indices.shape[0]

26
27
    accepted = (torch.arange(k).expand(batch_size, k)
                <= last_accepted_indices.unsqueeze(-1).broadcast_to(
28
                    batch_size, k))
29
30
31
32

    # Sprinkle accepted values after the contiguous initial accepted values.
    # This replicates the behavior of rejection sampling, which may "accept"
    # a token that cannot be accepted because of causality.
33
34
35
36
    sprinkle_candidates = (torch.arange(k).expand(
        batch_size,
        k) > last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) +
                           1)
37
    sprinkle = torch.rand(batch_size, k) > 0.5
38
39
40
41
42
43
44
45
    accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
    return accepted


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize(
    "which_tokens_accepted",
    ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
                               device: str, use_flashinfer: bool):
    """Verify the output has correct format given predetermined accepted matrix.

    ``RejectionSampler._create_output`` receives a mocked ``accepted`` mask
    plus random recovered/draft/bonus token ids. Its ``(batch_size, k + 1)``
    output is then checked per scenario:

    * all accepted: the draft tokens followed by the bonus token;
    * none accepted: the first recovered token, then ``-1`` everywhere else;
    * some accepted: the first rejected slot holds the recovered token (or
      the bonus token when every draft was accepted), and every later slot
      is ``-1``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)

    batch_size = 10
    k = 5
    vocab_size = 3000

    if which_tokens_accepted == "all_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "no_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "some_tokens_accepted":
        last_accepted_indices = torch.randint(low=-1,
                                              high=k,
                                              size=(batch_size, ))
        accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
    else:
        raise AssertionError()

    recovered_token_ids = torch.randint(low=0,
                                        high=vocab_size,
                                        size=(batch_size, k),
                                        dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)

    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)
    output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
        accepted,
        recovered_token_ids,
        draft_token_ids,
        bonus_token_ids,
    )

    expected_bonus_token_ids = bonus_token_ids.clone()

    if which_tokens_accepted == "all_tokens_accepted":
        # Expect all tokens to be equal to draft tokens.
        assert torch.equal(output_token_ids[:, :-1], draft_token_ids)

        # Expect all bonus tokens to be included.
        assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
    elif which_tokens_accepted == "no_tokens_accepted":
        # Expect first token to be equal to recovered tokens.
        assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])

        # Expect everything else to be -1.
        assert torch.equal(output_token_ids[:, 1:],
                           torch.ones_like(output_token_ids[:, 1:]) * -1)
    elif which_tokens_accepted == "some_tokens_accepted":
        # Column j < k holds the recovered token for position j; column k
        # holds the bonus token (used when every draft token was accepted).
        recovered_plus_bonus = torch.cat(
            (recovered_token_ids, expected_bonus_token_ids), dim=-1)
        # Assert first rejected token is a recovered token or bonus token.
        assert torch.equal(
            recovered_plus_bonus[torch.arange(0, batch_size),
                                 last_accepted_indices + 1],
            output_token_ids[torch.arange(0, batch_size),
                             last_accepted_indices + 1])

        # Assert every subsequent token is -1.
        subsequent_mask = torch.arange(0, k + 1).expand(
            batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
        assert torch.all(output_token_ids[subsequent_mask] == -1)


@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                    device: str, use_flashinfer: bool):
    """Smoke test: the sampler runs over a grid of k, vocab and batch sizes.

    Only the absence of an exception is checked; the output is discarded.
    """
    torch.set_default_device(device)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    # Unnormalized random "probabilities": k draft positions, and k + 1
    # target positions (the extra one accompanies the bonus token).
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                      draft_token_ids)
155
156
157
158
159
160
161
162


@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                   frac_seeded: float, n_rep: int, device: str,
                                   use_flashinfer: bool):
    """Seeded rows must yield identical outputs across repeated calls.

    A random subset of rows (about ``frac_seeded`` of the batch) gets a
    per-row ``torch.Generator`` seeded with the row index. Unseeded rows are
    not checked.
    """
    torch.set_default_device(device)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded

    results = []
    for _ in range(n_rep):
        # Fresh generators every repetition so each call starts from the
        # same generator state.
        seeded_seqs = {
            i: torch.Generator(device=device).manual_seed(i)
            for i in range(batch_size) if seeded_mask[i]
        }
        results.append(
            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                              draft_token_ids, seeded_seqs))

    for i in range(batch_size):
        if seeded_mask[i]:
            for j in range(1, n_rep):
                assert torch.equal(results[j][i], results[0][i])
202
203


204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
                            device: str, use_flashinfer: bool):
    """A mixed seeded/unseeded batch must match per-request sampling.

    Row 0 is unseeded. Rows ``1..batch_size-1`` each get a generator seeded
    with their row index. ``set_random_seed(0)`` is called before both the
    batched pass and the per-row pass. Each row's batched output must equal
    the output of sampling that row alone (as a batch of one) with the
    equivalent seeding.
    """
    torch.set_default_device(device)
    set_random_seed(0)
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    # Per-row copies, each shaped as a batch of size 1:
    # (draft_probs, draft_token_ids, target_probs, bonus_token_ids).
    single_batches = [(draft_probs[i].clone().unsqueeze(0),
                       draft_token_ids[i].clone().unsqueeze(0),
                       target_probs[i].clone().unsqueeze(0),
                       bonus_token_ids[i].clone().unsqueeze(0))
                      for i in range(batch_size)]

    set_random_seed(0)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    seeded_seqs = {
        i: torch.Generator(device=device).manual_seed(i)
        for i in range(1, batch_size)  # 0 is seed None
    }
    batch_result = rejection_sampler(target_probs.clone(),
                                     bonus_token_ids.clone(),
                                     draft_probs.clone(),
                                     draft_token_ids.clone(), seeded_seqs)

    set_random_seed(0)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    results = []
    for i, (row_draft_probs, row_draft_token_ids, row_target_probs,
            row_bonus_token_ids) in enumerate(single_batches):
        # Recreate the row's generator with the same seed so the single-row
        # call starts from the state the batched call used for this row.
        request_seeded_seqs = ({
            0: torch.Generator(device=device).manual_seed(i)
        } if i in seeded_seqs else None)
        results.append(
            rejection_sampler(row_target_probs, row_bonus_token_ids,
                              row_draft_probs, row_draft_token_ids,
                              request_seeded_seqs))

    for i in range(batch_size):
        assert torch.equal(batch_result[i], results[i].squeeze(0))


267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
                                       batch_size: int, device: str):
    """
    Test the flashinfer and nonflashinfer backend generate
    the same output metrics.

    Both backends sample the same inputs with identically seeded per-row
    generators. Their ``num_accepted_tokens``, ``num_emitted_tokens`` and
    ``num_draft_tokens`` counters must then agree.
    """
    torch.set_default_device(device)
    torch.manual_seed(0)
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    num_accepted_tokens = []
    num_emitted_tokens = []
    num_draft_tokens = []

    def get_seeded_seqs():
        """Return a fresh generator per row, seeded with the row index."""
        return {
            i: torch.Generator(device=device).manual_seed(i)
            for i in range(batch_size)
        }

    for use_flashinfer in [True, False]:
        rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
        rejection_sampler.init_gpu_tensors(device=device)
        # We use seeded sequences to ensure the same tokens are accepted
        # for both flashinfer and nonflashinfer backends.
        seeded_seqs = get_seeded_seqs()
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids, seeded_seqs)
        num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
        num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
        num_draft_tokens.append(rejection_sampler.num_draft_tokens)

    assert num_accepted_tokens[0] == num_accepted_tokens[1]
    assert num_emitted_tokens[0] == num_emitted_tokens[1]
    assert num_draft_tokens[0] == num_draft_tokens[1]


321
322
323
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str, device: str,
                               use_flashinfer: bool):
    """An out-of-vocabulary token id must make a strict sampler raise.

    The sampler is built with ``strict_mode=True``. One entry of either the
    bonus or the draft token ids is replaced by ``vocab_size + 1`` or ``-1``,
    and the call is expected to raise ``AssertionError``.
    """
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)

    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
                                         strict_mode=True)
    rejection_sampler.init_gpu_tensors(device=device)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    # Alias (not copy) the tensor to corrupt, so the in-place write below
    # modifies the tensor actually passed to the sampler.
    oob_token_ids = None
    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError()

    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError()

    oob_token_ids[0][0] = rogue_token_id

    with pytest.raises(AssertionError):
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids)
373
374
375
376


@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
        seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a set of random reference distributions. We expect the distance to
    the target to shrink much more than the distance to the reference
    distributions does.

    When draft_and_target_probs_equal=True, the draft and target
    probabilities are exactly equal. Rejection sampling should
    still work without any NaNs or exceptions.
    """
    torch.set_default_device("cpu")
    set_random_seed(seed)
    helper = _CorrectnessTestHelper(
        vocab_size=10,
        rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
    )

    draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
        draft_and_target_probs_equal)

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference: list[float] = []
    distance_wrt_target: list[float] = []

    for num_samples in sample_sizes:
        (reference_vs_rejsample_dist,
         target_vs_rejsample_dist) = helper.run_and_compare_distributions(
             draft_probs,
             target_probs,
             reference_probs,
             num_samples,
         )

        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

        # Running ratios, computed here only for the progress printout.
        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
              f"{reference_vs_rejsample_dist=:.05f}")
        print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
              f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    # The target distance must shrink at least this many times more than
    # the reference distance.
    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target
            > relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


454
def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return ``elements[0] / elements[-1]``.

    Used to measure how much a distance shrank between the smallest and the
    largest sample size. A non-empty list is required; an empty list raises
    ``IndexError``.
    """
    return elements[0] / elements[-1]


class _CorrectnessTestHelper:
    """Class that packages together logic required for the unit-level
    rejection sampling correctness test.

    It creates random draft/target/reference distributions over a small
    vocabulary, runs the rejection sampler on many repeated copies of them,
    and estimates the empirical distribution of the emitted (non-bonus)
    tokens.
    """

    def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
        self.rejection_sampler = rejection_sampler
        self.vocab_size = vocab_size
        # Histogram range used when estimating the empirical PDF.
        self.vocab_range = (0, vocab_size)

        # NOTE(review): sampler buffers go to device 0 while sampling inputs
        # are moved to "cuda" below; presumably the same GPU — confirm.
        self.rejection_sampler.init_gpu_tensors(device=0)

        # Keep test simple, use k=1
        self.k = 1

        # Bonus tokens not used, but rejection sampler requires
        # correct shape.
        self.num_bonus_tokens = 1

    def generate_probs_for_test(
        self, draft_and_target_probs_equal: bool
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create random draft, target and reference distributions.

        Each distribution is a softmax over uniform noise.

        Returns:
            ``(draft_probs, target_probs, reference_probs)``. The first two
            have shape ``(vocab_size,)``; ``reference_probs`` has shape
            ``(100, vocab_size)``. When ``draft_and_target_probs_equal`` is
            True, ``target_probs`` is a copy of ``draft_probs``.
        """
        draft_probs, target_probs = (F.softmax(
            torch.rand(self.vocab_size, dtype=torch.float32),
            dim=-1,
        ) for _ in range(2))

        num_reference_probs = 100
        reference_probs = F.softmax(
            torch.rand(num_reference_probs,
                       self.vocab_size,
                       dtype=torch.float32),
            dim=-1,
        )

        if draft_and_target_probs_equal:
            target_probs = draft_probs.clone()

        return draft_probs, target_probs, reference_probs

    def run_and_compare_distributions(self, draft_probs: torch.Tensor,
                                      target_probs: torch.Tensor,
                                      reference_probs: torch.Tensor,
                                      num_samples: int) -> tuple[float, float]:
        """Sample ``num_samples`` times and measure distances of the estimate.

        Returns:
            ``(reference_vs_rejsample_dist, target_vs_rejsample_dist)``.
            Both use ``torch.dist`` (2-norm). The reference distance is
            computed against all reference rows at once (via broadcasting)
            and divided by the number of reference rows.
        """
        # Sample using rejection sampling.
        rej_sample_probs = self._estimate_rejection_sampling_pdf(
            draft_probs, target_probs, num_samples)

        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs,
            rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()

        return reference_vs_rejsample_dist, target_vs_rejsample_dist

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        """Return a density histogram (``vocab_size`` bins) of sampled tokens.

        The bonus-token column of the sampler output is excluded.
        """
        # Repeat draft probs num_samples times.
        draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
            num_samples, 1, 1)

        # Repeat target probs num_samples * (k + 1) times.
        # Rejection sampler requires bonus token probs, but they aren't used.
        target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
            num_samples, self.k + 1, 1)

        # Randomly sample one draft token id per (sample, position) from the
        # draft probs. Flattening the (num_samples, k) dims keeps this
        # consistent with the reshape below for any k (identical for k=1).
        draft_token_ids = torch.multinomial(
            draft_probs.reshape(-1, self.vocab_size),
            num_samples=1,
            replacement=True).reshape(num_samples, self.k)

        # Bonus tokens not used but required.
        bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
                                      dtype=torch.int64,
                                      device="cuda").repeat(num_samples, 1)

        # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                  bonus_token_ids.to("cuda"),
                                                  draft_probs.to("cuda"),
                                                  draft_token_ids.to("cuda"))

        # Remove bonus tokens
        output_token_ids = output_token_ids[:, :-1].flatten()

        # Estimate probability density function (torch.histogram needs a
        # CPU float tensor).
        hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                   device="cpu"),
                               bins=self.vocab_size,
                               range=self.vocab_range,
                               density=True)

        return hist.hist