test_rejection_sampler.py 22.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, Optional, Union
4
5
6

import pytest
import torch
7
import torch.nn.functional as F
8

9
from vllm.platforms import current_platform
10
from vllm.v1.sample.logits_processor import LogitsProcessors
11
from vllm.v1.sample.metadata import SamplingMetadata
12
13
14
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                              RejectionSampler)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
15

16
DEVICE = current_platform.device_type
17

18
19

@pytest.fixture
def rejection_sampler():
    """Provide a freshly constructed ``RejectionSampler`` to a test."""
    return RejectionSampler()


24
def create_logits_tensor(output_token_ids: list[list[int]],
                         vocab_size: int = 100) -> torch.Tensor:
    """Build target logits whose per-row argmax reproduces given tokens.

    The last token of every sequence in ``output_token_ids`` is treated as
    the bonus token and gets no logits row. Every other token gets one row
    in which its id is set to 100.0 and all other entries to -100.0.

    Args:
        output_token_ids: Per-sequence token ids, each ending with the
            bonus token.
        vocab_size: Number of columns (vocabulary size) of the logits.

    Returns:
        A ``(num_draft_tokens, vocab_size)`` float tensor on ``DEVICE``.
        Rows for all sequences are concatenated in order.
    """
    # Drop the trailing bonus token: only draft positions get a logits row.
    token_ids = [tokens[:-1] for tokens in output_token_ids]
    num_total_tokens = sum(len(tokens) for tokens in token_ids)
    logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
    row = 0
    for tokens in token_ids:
        for token_id in tokens:
            logits[row, token_id] = 100.0
            row += 1
    return logits


39
def create_sampling_metadata(
    all_greedy: bool,
    temperature: Optional[torch.Tensor] = None,
    top_k: Optional[torch.Tensor] = None,
    top_p: Optional[torch.Tensor] = None,
    generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
    """Create a v1 ``SamplingMetadata`` for rejection-sampler tests.

    The batch is either entirely greedy or entirely random. This helper
    does not build mixed batches.

    Args:
        all_greedy: If True, ``all_greedy`` is set and ``temperature`` is
            replaced with None. Otherwise ``all_random`` is set and
            ``temperature`` must be provided.
        temperature: Per-request temperatures. Required when ``all_greedy``
            is False.
        top_k: Optional top-k tensor, passed through unchanged.
        top_p: Optional top-p tensor, passed through unchanged.
        generators: Optional mapping passed through as
            ``SamplingMetadata.generators``. Defaults to an empty dict; the
            tests in this file key it by request index.

    Returns:
        A ``SamplingMetadata`` with ``max_num_logprobs=0``, empty penalty
        tensors, no allowed-token mask and no bad words.
    """
    generators = generators or {}
    if all_greedy:
        temperature = None
    else:
        assert temperature is not None, (
            "temperature is required when all_greedy is False")

    return SamplingMetadata(
        temperature=temperature,
        all_greedy=all_greedy,
        all_random=not all_greedy,
        top_p=top_p,
        top_k=top_k,
        generators=generators,
        max_num_logprobs=0,
        no_penalties=False,
        prompt_token_ids=None,
        frequency_penalties=torch.tensor([]),
        presence_penalties=torch.tensor([]),
        repetition_penalties=torch.tensor([]),
        output_token_ids=[],
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=LogitsProcessors(),
    )


76
########################### Tests for Greedy Sampling ###################
77
def test_perfect_match(rejection_sampler):
    """Test when output tokens perfectly match speculated tokens.

    Every greedy target token equals its draft, so all drafts are accepted
    and the bonus token is appended.
    """
    spec_tokens = [[1, 2, 3]]
    output_tokens = [[1, 2, 3, 4]]  # 4 is the bonus token

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2, 3, 4]],
                            dtype=torch.int,
                            device=logits.device)
    assert torch.equal(output, expected)
100
101


102
def test_early_mismatch(rejection_sampler):
    """Test when there's an early mismatch in tokens.

    The target disagrees with the draft at position 1. The target's token
    is emitted there, and every later slot (including the bonus slot) is
    expected to hold ``PLACEHOLDER_TOKEN_ID``.
    """
    spec_tokens = [[1, 2, 3]]
    output_tokens = [[1, 5, 3, 4]]  # Mismatch at position 1

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
        [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)
127
128


129
def test_multiple_sequences(rejection_sampler):
    """Test handling multiple sequences of speculated tokens.

    Both sequences fully match. The shorter one is expected to be padded to
    the common width with ``PLACEHOLDER_TOKEN_ID``.
    """
    spec_tokens = [[1, 2], [3]]
    # Two sequences with bonus tokens 5 and 4.
    output_tokens = [[1, 2, 5], [3, 4]]

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor(
        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
                            dtype=torch.int,
                            device=logits.device)
    assert torch.equal(output, expected)
153
154


155
def test_single_token_sequence(rejection_sampler):
    """Test handling sequences with a single speculated token.

    The one draft matches, so the output is the draft plus the bonus token.
    """
    spec_tokens = [[1]]
    output_tokens = [[1, 2]]  # Single token with bonus token 2

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
    assert torch.equal(output, expected)
176
177


178
def test_empty_sequence(rejection_sampler):
    """Test handling an empty sequence of speculated tokens.

    With no drafts the target logits have zero rows, and the output is
    expected to be just the bonus token.
    """
    spec_tokens: list[list[int]] = [[]]
    output_tokens = [[5]]  # Just the bonus token

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
    assert torch.equal(output, expected)
199
200


201
def test_multiple_mismatches(rejection_sampler):
    """Test handling multiple sequences with mismatches.

    Sequence 0 first diverges at position 2 and sequence 1 at position 1.
    Each keeps tokens up to and including its first mismatch (taken from
    the target), followed by ``PLACEHOLDER_TOKEN_ID``.
    """
    spec_tokens = [[1, 2, 3], [4, 5, 6]]
    # Mismatches in both sequences.
    output_tokens = [[1, 2, 7, 6], [4, 8, 6, 9]]

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor(
        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
        [[1, 2, 7, PLACEHOLDER_TOKEN_ID],
         [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)
228
229
230
231
232


@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),  # Perfect match with bonus
        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),  # First mismatch
        ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
         [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]),  # Mixed matches
    ])
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
                            expected):
    """Parametrized greedy test for various matching scenarios.

    The last element of each ``output_tokens`` sequence is its bonus token.
    ``expected`` is the full output row per sequence, including placeholders.
    """
    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected_tensor = torch.tensor(expected,
                                   dtype=torch.int,
                                   device=logits.device)
    assert torch.equal(output, expected_tensor)


########################### Tests for Random Sampling ###################
@pytest.mark.parametrize("k", [1, 3, 5])
@pytest.mark.parametrize("vocab_size", [1000])
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(
    rejection_sampler,
    k: int,
    vocab_size: int,
    batch_size: int,
    frac_seeded: float,
    n_rep: int,
):
    """Seeded requests must produce identical outputs across repetitions.

    The draft probs, target logits and draft/bonus token ids are drawn once.
    Each of the ``n_rep`` repetitions recreates the per-request generators
    with the same seeds and reruns rejection sampling. Only rows of seeded
    requests are compared, because unseeded rows may legitimately differ.
    """
    num_tokens = batch_size * k
    draft_probs = torch.rand(num_tokens,
                             vocab_size,
                             dtype=torch.float32,
                             device=DEVICE)
    draft_probs = F.softmax(draft_probs, dim=-1)
    target_logits = torch.rand_like(draft_probs)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device=DEVICE)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device=DEVICE)

    # Strict '<': torch.rand can return exactly 0.0, and with '<=' a
    # frac_seeded of 0.0 could still seed a request.
    seeded_mask = torch.rand(batch_size, dtype=torch.float32) < frac_seeded

    results = []
    for _ in range(n_rep):
        # Fresh generators with identical seeds on every repetition.
        seeded_seqs = {
            i: torch.Generator(device=DEVICE).manual_seed(i)
            for i in range(batch_size) if seeded_mask[i]
        }

        temperature = torch.ones(batch_size,
                                 dtype=torch.float32,
                                 device=DEVICE)
        sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                     temperature=temperature,
                                                     generators=seeded_seqs)
        spec_decode_metadata = SpecDecodeMetadata.make_dummy(
            draft_token_ids.tolist(), device=DEVICE)
        rep_result = rejection_sampler(
            spec_decode_metadata,
            draft_probs=draft_probs,
            target_logits=target_logits,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=sampling_metadata,
        )
        results.append(rep_result)

    for i in range(batch_size):
        if seeded_mask[i]:
            for j in range(1, n_rep):
                assert torch.equal(results[j][i], results[0][i])


def test_rejection_sampling_approximates_target_distribution():
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.
    """
    vocab_size = 10
    k = 2
    num_reference_probs = 100

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference: list[float] = []
    distance_wrt_target: list[float] = []

    # Scope the default device to this test. A bare
    # torch.set_default_device() changes it process-wide and leaks into
    # every test that runs afterwards.
    with torch.device(DEVICE):
        # Prepare draft, target, and reference probability distributions
        draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
                                dim=-1)
        target_logits = torch.rand(vocab_size, dtype=torch.float32)
        target_probs = F.softmax(target_logits, dim=-1)
        reference_probs = F.softmax(
            torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
            dim=-1,
        )

        for num_samples in sample_sizes:
            # Sample using rejection sampling.
            rej_sample_probs = estimate_rejection_sampling_pdf(
                draft_probs, target_logits, k, vocab_size, num_samples)
            rej_sample_probs = rej_sample_probs.to(DEVICE)

            # Average distance from reference probs.
            reference_vs_rejsample_dist = torch.dist(
                reference_probs,
                rej_sample_probs).item() / reference_probs.shape[0]
            target_vs_rejsample_dist = torch.dist(target_probs,
                                                  rej_sample_probs).item()

            distance_wrt_reference.append(reference_vs_rejsample_dist)
            distance_wrt_target.append(target_vs_rejsample_dist)

            relative_change_in_distance_wrt_target = (
                get_ratio_first_to_last(distance_wrt_target))
            relative_change_in_distance_wrt_reference = (
                get_ratio_first_to_last(distance_wrt_reference))

            print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
                  f"{reference_vs_rejsample_dist=:.05f}")
            print(f"{num_samples=} "
                  f"{relative_change_in_distance_wrt_target=:.02f} "
                  f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target
            > relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return the first element divided by the last one.

    In this file it is applied to lists of distances, one per sample size,
    to measure how much a distance shrank between the smallest and the
    largest sample size. An empty list raises ``IndexError``. A
    single-element list yields 1.0 for any non-zero value.
    """
    first = elements[0]
    last = elements[-1]
    return first / last


def estimate_rejection_sampling_pdf(
    draft_probs: torch.Tensor,
    target_logits: torch.Tensor,
    k: int,
    vocab_size: int,
    num_samples: int,
) -> torch.Tensor:
    """Estimate the probability distribution of the output tokens
    using rejection sampling.

    Args:
        draft_probs: Draft probability distribution with ``vocab_size``
            elements, shared by every draft position.
        target_logits: Target logits with ``vocab_size`` elements, shared by
            every draft position.
        k: Number of draft tokens per sampled sequence.
        vocab_size: Size of the vocabulary (also the number of histogram
            bins).
        num_samples: Number of sequences (each with ``k`` draft tokens) to
            run through the rejection sampler.

    Returns:
        Estimated probability distribution of the output tokens at the
        ``k`` draft positions (bonus column excluded). It is computed as a
        density histogram on CPU with one unit-width bin per token id in
        ``[0, vocab_size)``.
    """
    rejection_sampler = RejectionSampler()
    num_tokens = num_samples * k
    # Repeat draft probs num_samples * k times.
    draft_probs = draft_probs.reshape(1, 1,
                                      vocab_size).repeat(num_samples, k, 1)

    # Repeat target probs num_tokens times.
    target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)

    # Randomly sample draft token ids from draft probs.
    draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                        num_samples=k,
                                        replacement=True).reshape(
                                            num_samples, k)
    draft_probs = draft_probs.view(num_tokens, vocab_size)

    # Bonus tokens not used but required: one per sequence.
    bonus_token_ids = torch.zeros((num_samples, 1),
                                  dtype=torch.int64,
                                  device=DEVICE)

    temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
    sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                 temperature=temperature)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids.tolist(), device=bonus_token_ids.device)
    output_token_ids = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )
    # Drop the bonus column and keep only the k draft positions.
    output_token_ids = output_token_ids[:, :-1].flatten()

    # torch.histogram runs on CPU, so move the samples there first.
    hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                               device="cpu"),
                           bins=vocab_size,
                           range=(0, vocab_size),
                           density=True)

    return hist.hist
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611


def _test_masked_logits(
    rejection_sampler,
    batch_size: int,
    num_draft_tokens: int,
    vocab_size: int,
    target_logits: torch.Tensor,
    unmasked_indices: Union[torch.Tensor, list[list[int]]],
    sampling_metadata: SamplingMetadata,
):
    """Check that random rejection sampling only emits allowed token ids.

    Random draft distributions are created and draft tokens are sampled
    from them. Rejection sampling is then run against ``target_logits``
    using ``sampling_metadata``, which carries the masking under test
    (top-k or top-p in this file). Every draft-position output that is not
    ``PLACEHOLDER_TOKEN_ID`` must be one of that position's allowed ids.

    Args:
        rejection_sampler: The sampler under test.
        batch_size: Number of requests.
        num_draft_tokens: Draft tokens per request.
        vocab_size: Vocabulary size.
        target_logits: Target logits with one row per draft token
            (``batch_size * num_draft_tokens`` rows), as built by callers.
        unmasked_indices: Allowed token ids per draft position, indexed by
            flattened draft-token position. A 2-D tensor (``test_top_k``)
            or a list of lists (``test_top_p``).
        sampling_metadata: Sampling metadata that applies the masking.
    """
    # Set up test parameters
    num_tokens = batch_size * num_draft_tokens

    # Create random draft probabilities.
    draft_probs = torch.rand((num_tokens, vocab_size),
                             dtype=torch.float32,
                             device=DEVICE)
    draft_probs = F.softmax(draft_probs, dim=-1)

    # Randomly sample draft token ids from draft probs
    draft_token_ids = torch.multinomial(draft_probs, num_samples=1)
    draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens)
    draft_token_ids = draft_token_ids.tolist()

    # Bonus tokens not used but required
    bonus_token_ids = torch.zeros((batch_size, 1),
                                  dtype=torch.int64,
                                  device=DEVICE)

    # Create spec decode metadata
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids,
        device=DEVICE,
    )

    # Run rejection sampling
    output_token_ids = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )

    # Remove bonus tokens and flatten to one entry per draft position.
    output_token_ids = output_token_ids[:, :-1].flatten().tolist()
    assert len(output_token_ids) == num_tokens, (
        f"expected {num_tokens} draft outputs, got {len(output_token_ids)}")

    # Check that all sampled tokens are within the unmasked indices.
    for i, token_id in enumerate(output_token_ids):
        if token_id == PLACEHOLDER_TOKEN_ID:
            continue
        assert token_id in unmasked_indices[i]


@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
    """Rejection sampling under top-k must never emit a non-top-k token.

    Each draft position gets its own random set of ``top_k`` distinct
    favoured ids, whose logits are raised slightly above an otherwise flat
    (all-zero) row.
    """
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # One random set of top_k distinct ids per draft position.
    top_k_indices = torch.stack([
        torch.randperm(vocab_size, device=DEVICE)[:top_k]
        for _ in range(num_tokens)
    ])

    # Flat logits with the top-k ids nudged to 0.1. The gap is tiny, so if
    # masking were ineffective, other ids would still be sampled.
    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
    target_logits.scatter_(1, top_k_indices, 0.1)

    per_request_top_k = torch.tensor([top_k] * batch_size,
                                     device=DEVICE,
                                     dtype=torch.int64)
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE),
        top_k=per_request_top_k,
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_k_indices,
        sampling_metadata=sampling_metadata,
    )


@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
    """Test rejection sampling with top-p sampling.

    For each draft position, the allowed ids are computed here:
    - Sort the temperature-scaled probabilities in ascending order.
    - Mask the low-probability tail whose cumulative mass is <= 1 - top_p,
      always keeping the most likely id.

    ``_test_masked_logits`` then checks that no sampled token falls outside
    this set.
    """
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # Create random (standard normal) logits, one row per draft token.
    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
    # Temperature is per request, but logits have one row per draft token.
    # Expand it to one value per row before dividing. The rows are assumed
    # to be grouped per request, matching the (batch_size, num_draft_tokens)
    # reshape of draft tokens in _test_masked_logits. Dividing by the
    # (batch_size,) tensor directly would broadcast across the vocab axis
    # and only work by accident when batch_size == vocab_size.
    token_temperature = temperature.repeat_interleave(num_draft_tokens)
    rescaled_logits = target_logits / token_temperature.unsqueeze(-1)

    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - top_p
    # Always keep at least one token (the most likely one, sorted last).
    top_p_mask[:, -1] = False

    # Get the top-p indices for each draft position.
    top_p_indices = [
        row_idx[~row_mask].tolist()
        for row_idx, row_mask in zip(logits_idx, top_p_mask)
    ]

    # Create sampling metadata
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_p=torch.tensor([top_p] * batch_size,
                           device=DEVICE,
                           dtype=torch.float32),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_p_indices,
        sampling_metadata=sampling_metadata,
    )