test_rejection_sampler.py 20.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, Optional, Union
4
5
6

import pytest
import torch
7
import torch.nn.functional as F
8

9
from vllm.platforms import current_platform
10
from vllm.v1.sample.logits_processor import LogitsProcessors
11
from vllm.v1.sample.metadata import SamplingMetadata
12
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
13
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
14

15
# Device type reported by vLLM's current platform. The helpers and tests in
# this module allocate their tensors (and seeded generators) on this device.
DEVICE = current_platform.device_type
16

17
18

@pytest.fixture
def rejection_sampler():
    """Provide a freshly constructed ``RejectionSampler`` to each test."""
    sampler = RejectionSampler()
    return sampler


23
24
25
def create_logits_tensor(
    output_token_ids: list[list[int]], vocab_size: int = 100
) -> torch.Tensor:
    """Build target logits whose per-row argmax reproduces the given tokens.

    The last token of every sequence is the bonus token, which the tests pass
    to the sampler separately. Only the tokens before it get a logits row.
    Each row is -100 everywhere except +100 at the desired token id.
    """
    # Flatten all sequences (minus their bonus token) into one row per token.
    target_ids = [tok for seq in output_token_ids for tok in seq[:-1]]
    logits = torch.full((len(target_ids), vocab_size), -100.0, device=DEVICE)
    for row, tok in enumerate(target_ids):
        logits[row, tok] = 100.0
    return logits


39
def create_sampling_metadata(
    all_greedy: bool,
    temperature: Optional[torch.Tensor] = None,
    top_k: Optional[torch.Tensor] = None,
    top_p: Optional[torch.Tensor] = None,
    generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
    """Build a v1 ``SamplingMetadata`` for a batch that is uniformly sampled.

    The whole batch is either greedy (``all_greedy=True``; any supplied
    ``temperature`` is discarded) or random (``all_greedy=False``; a
    ``temperature`` tensor must then be provided). Every remaining field is
    filled with an empty placeholder: empty penalty tensors, no prompt or
    output tokens, no allowed-token mask, no bad words, and a default
    ``LogitsProcessors``.

    Args:
        all_greedy: Whether every request uses greedy decoding.
        temperature: Per-request temperatures; required for random sampling.
        top_k: Optional per-request top-k values.
        top_p: Optional per-request top-p values.
        generators: Optional mapping from request index to its seeded
            ``torch.Generator``. A falsy value means "no seeded requests".
    """
    if not all_greedy:
        assert temperature is not None

    return SamplingMetadata(
        temperature=None if all_greedy else temperature,
        all_greedy=all_greedy,
        all_random=not all_greedy,
        top_p=top_p,
        top_k=top_k,
        generators=generators if generators else {},
        max_num_logprobs=0,
        no_penalties=False,
        prompt_token_ids=None,
        frequency_penalties=torch.tensor([]),
        presence_penalties=torch.tensor([]),
        repetition_penalties=torch.tensor([]),
        output_token_ids=[],
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=LogitsProcessors(),
    )


76
########################### Tests for Greedy Sampling ###################
77
def test_perfect_match(rejection_sampler):
    """All draft tokens agree with the target, so all are accepted plus bonus."""
    draft = [[1, 2, 3]]
    # The target argmax matches every draft token; 4 is the bonus token.
    target = [[1, 2, 3, 4]]

    logits = create_logits_tensor(target)
    device = logits.device
    spec_meta = SpecDecodeMetadata.make_dummy(draft, device=device)
    bonus = torch.tensor([seq[-1] for seq in target], device=device)

    output = rejection_sampler(
        spec_meta,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    want = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=device)
    assert torch.equal(output, want)
98
99


100
def test_early_mismatch(rejection_sampler):
    """A mismatch at position 1 truncates the accepted run right there."""
    draft = [[1, 2, 3]]
    # The target prefers 5 over the drafted 2 at position 1; 4 is the bonus.
    target = [[1, 5, 3, 4]]

    logits = create_logits_tensor(target)
    device = logits.device
    spec_meta = SpecDecodeMetadata.make_dummy(draft, device=device)
    bonus = torch.tensor([seq[-1] for seq in target], device=device)

    output = rejection_sampler(
        spec_meta,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    # Position 1 holds the target's token; every later slot (including the
    # bonus slot) is a placeholder.
    want = [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]]
    assert torch.equal(output, torch.tensor(want, dtype=torch.int, device=device))
125
126


127
def test_multiple_sequences(rejection_sampler):
    """Several requests with different draft lengths are handled together."""
    draft = [[1, 2], [3]]
    # Both requests fully match; their bonus tokens are 5 and 4.
    target = [[1, 2, 5], [3, 4]]

    logits = create_logits_tensor(target)
    device = logits.device
    spec_meta = SpecDecodeMetadata.make_dummy(draft, device=device)
    bonus = torch.tensor([seq[-1] for seq in target], device=device)

    output = rejection_sampler(
        spec_meta,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    # The shorter request is right-padded with a placeholder.
    want = [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]]
    assert torch.equal(output, torch.tensor(want, dtype=torch.int, device=device))
152
153


154
def test_single_token_sequence(rejection_sampler):
    """A request with a single draft token that matches gets its bonus too."""
    draft = [[1]]
    # One matching token followed by bonus token 2.
    target = [[1, 2]]

    logits = create_logits_tensor(target)
    device = logits.device
    spec_meta = SpecDecodeMetadata.make_dummy(draft, device=device)
    bonus = torch.tensor([seq[-1] for seq in target], device=device)

    output = rejection_sampler(
        spec_meta,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    want = torch.tensor([[1, 2]], dtype=torch.int, device=device)
    assert torch.equal(output, want)
175
176


177
def test_empty_sequence(rejection_sampler):
    """With no draft tokens, only the bonus token is emitted."""
    draft: list[list[int]] = [[]]
    # Nothing to verify; 5 is the bonus token.
    target = [[5]]

    logits = create_logits_tensor(target)
    device = logits.device
    spec_meta = SpecDecodeMetadata.make_dummy(draft, device=device)
    bonus = torch.tensor([seq[-1] for seq in target], device=device)

    output = rejection_sampler(
        spec_meta,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    want = torch.tensor([[5]], dtype=torch.int, device=device)
    assert torch.equal(output, want)
198
199


200
def test_multiple_mismatches(rejection_sampler):
    """Each request is truncated independently at its own first mismatch."""
    draft = [[1, 2, 3], [4, 5, 6]]
    # Request 0 diverges at position 2, request 1 at position 1.
    target = [[1, 2, 7, 6], [4, 8, 6, 9]]

    logits = create_logits_tensor(target)
    device = logits.device
    spec_meta = SpecDecodeMetadata.make_dummy(draft, device=device)
    bonus = torch.tensor([seq[-1] for seq in target], device=device)

    output = rejection_sampler(
        spec_meta,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    ph = PLACEHOLDER_TOKEN_ID
    want = [
        [1, 2, 7, ph],
        [4, 8, ph, ph],
    ]
    assert torch.equal(output, torch.tensor(want, dtype=torch.int, device=device))
230
231
232
233
234


@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        # Perfect match with bonus.
        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),
        # Mismatch at the very first position.
        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),
        # One request diverges, the other fully matches.
        (
            [[1, 2], [3, 4]],
            [[1, 5, 6], [3, 4, 7]],
            [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]],
        ),
    ],
)
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expected):
    """Table-driven greedy matching scenarios."""
    logits = create_logits_tensor(output_tokens)
    device = logits.device
    spec_meta = SpecDecodeMetadata.make_dummy(spec_tokens, device=device)
    bonus = torch.tensor([seq[-1] for seq in output_tokens], device=device)

    output = rejection_sampler(
        spec_meta,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    want = torch.tensor(expected, dtype=torch.int, device=device)
    assert torch.equal(output, want)


########################### Tests for Random Sampling ###################
@pytest.mark.parametrize("k", [1, 3, 5])
@pytest.mark.parametrize("vocab_size", [1000])
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(
    rejection_sampler,
    k: int,
    vocab_size: int,
    batch_size: int,
    frac_seeded: float,
    n_rep: int,
):
    """Seeded requests must produce identical outputs across repeated runs.

    The same draft/target inputs are fed through the sampler ``n_rep`` times.
    Each request is seeded with probability ``frac_seeded``. A seeded request
    gets a fresh ``torch.Generator`` with the same seed on every repetition,
    and every repetition must reproduce the first one exactly. Unseeded
    requests are not checked.
    """
    num_tokens = batch_size * k
    draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
    draft_probs = F.softmax(draft_probs, dim=-1)
    target_logits = torch.rand_like(draft_probs)
    bonus_token_ids = torch.randint(
        low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64, device=DEVICE
    )
    draft_token_ids = torch.randint(
        low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64, device=DEVICE
    )

    # torch.rand draws from [0, 1), so a strict comparison seeds a request with
    # probability exactly ``frac_seeded``. The previous ``<=`` could still seed
    # a request when frac_seeded == 0.0 if rand happened to return 0.0.
    seeded_mask = (torch.rand(batch_size, dtype=torch.float32) < frac_seeded).tolist()

    results = []
    for _ in range(n_rep):
        # Re-create the generators each repetition so they restart from the
        # same seed state.
        seeded_seqs = {
            i: torch.Generator(device=DEVICE).manual_seed(i)
            for i, is_seeded in enumerate(seeded_mask)
            if is_seeded
        }

        temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
        sampling_metadata = create_sampling_metadata(
            all_greedy=False, temperature=temperature, generators=seeded_seqs
        )
        spec_decode_metadata = SpecDecodeMetadata.make_dummy(
            draft_token_ids.tolist(), device=DEVICE
        )
        rep_result = rejection_sampler(
            spec_decode_metadata,
            draft_probs=draft_probs,
            target_logits=target_logits,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=sampling_metadata,
        )
        results.append(rep_result)

    for i, is_seeded in enumerate(seeded_mask):
        if not is_seeded:
            continue
        for j in range(1, n_rep):
            assert torch.equal(results[j][i], results[0][i])


def test_rejection_sampling_approximates_target_distribution():
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.

    Tensors are created under a ``torch.device(DEVICE)`` context, so the
    default device changes only for the duration of this test.
    """
    vocab_size = 10
    k = 2
    num_reference_probs = 100
    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference: list[float] = []
    distance_wrt_target: list[float] = []

    # Scope the default device to this test. torch.set_default_device() would
    # change it for the whole process and leak into every test run afterwards.
    with torch.device(DEVICE):
        # Prepare draft, target, and reference probability distributions
        draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), dim=-1)
        target_logits = torch.rand(vocab_size, dtype=torch.float32)
        target_probs = F.softmax(target_logits, dim=-1)
        reference_probs = F.softmax(
            torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
            dim=-1,
        )

        for num_samples in sample_sizes:
            # Sample using rejection sampling.
            rej_sample_probs = estimate_rejection_sampling_pdf(
                draft_probs, target_logits, k, vocab_size, num_samples
            )
            rej_sample_probs = rej_sample_probs.to(DEVICE)

            # torch.dist broadcasts rej_sample_probs against every reference
            # row and returns a single L2 norm over all of them. Dividing by
            # the number of references normalizes that norm.
            reference_vs_rejsample_dist = (
                torch.dist(reference_probs, rej_sample_probs).item()
                / reference_probs.shape[0]
            )
            target_vs_rejsample_dist = torch.dist(target_probs, rej_sample_probs).item()

            distance_wrt_reference.append(reference_vs_rejsample_dist)
            distance_wrt_target.append(target_vs_rejsample_dist)

            relative_change_in_distance_wrt_target = get_ratio_first_to_last(
                distance_wrt_target
            )
            relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
                distance_wrt_reference
            )

            print(
                f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
                f"{reference_vs_rejsample_dist=:.05f}"
            )
            print(
                f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
                f"{relative_change_in_distance_wrt_reference=:.02f}"
            )

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target
    )
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference
    )

    expected_improvement_multiplier = 20
    assert (
        relative_change_in_distance_wrt_target
        > relative_change_in_distance_wrt_reference * expected_improvement_multiplier
    )
410
411
412
413
414
415
416
417


def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return the first element divided by the last one.

    A result above 1 means the sequence shrank from its first to its last
    entry by that factor.
    """
    first, last = elements[0], elements[-1]
    return first / last


def estimate_rejection_sampling_pdf(
    draft_probs: torch.Tensor,
    target_logits: torch.Tensor,
    k: int,
    vocab_size: int,
    num_samples: int,
) -> torch.Tensor:
    """Estimate the probability distribution of the output tokens
    using rejection sampling.

    Runs ``num_samples`` requests. Each request proposes ``k`` draft tokens
    drawn from ``draft_probs`` and is verified against the same
    ``target_logits`` at every position. The emitted non-bonus tokens are
    then histogrammed.

    Args:
        draft_probs: Draft probability distribution with ``vocab_size``
            entries (it is reshaped to ``(1, 1, vocab_size)``).
        target_logits: Target logits with ``vocab_size`` entries (reshaped
            to ``(1, vocab_size)``).
        k: Number of draft tokens per request.
        vocab_size: Vocabulary size; also the number of histogram bins.
        num_samples: Number of requests (samples) to draw.

    Returns:
        A CPU tensor with ``vocab_size`` entries: the density-normalized
        histogram (bin width 1) of the emitted tokens over ``[0, vocab_size)``.
    """
    rejection_sampler = RejectionSampler()
    num_tokens = num_samples * k

    # Same draft distribution for every position of every request:
    # shape (num_samples, k, vocab_size).
    draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
    # Same target logits for each of the num_tokens rows.
    target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)

    # The draft distribution is identical across positions, so draw all k
    # draft ids of a request (with replacement) from its first position.
    draft_token_ids = torch.multinomial(
        draft_probs[:, 0, :], num_samples=k, replacement=True
    ).reshape(num_samples, k)
    draft_probs = draft_probs.view(num_tokens, vocab_size)

    # Bonus tokens are not used, but the sampler requires one per request.
    bonus_token_ids = torch.zeros((num_samples, 1), dtype=torch.int64, device=DEVICE)

    temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
    sampling_metadata = create_sampling_metadata(
        all_greedy=False, temperature=temperature
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids.tolist(), device=bonus_token_ids.device
    )
    output_token_ids = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )
    # Drop the bonus column; keep one entry per draft position.
    output_token_ids = output_token_ids[:, :-1].flatten()

    hist = torch.histogram(
        output_token_ids.to(dtype=torch.float, device="cpu"),
        bins=vocab_size,
        range=(0, vocab_size),
        density=True,
    )
    return hist.hist
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491


def _test_masked_logits(
    rejection_sampler,
    batch_size: int,
    num_draft_tokens: int,
    vocab_size: int,
    target_logits: torch.Tensor,
    unmasked_indices: Union[torch.Tensor, list[list[int]]],
    sampling_metadata: SamplingMetadata,
):
    """Assert that rejection sampling never emits a masked-out token.

    Draft tokens are sampled from random draft probabilities. They are run
    through ``rejection_sampler`` against ``target_logits`` with the given
    ``sampling_metadata``. Every emitted token that is neither the bonus
    token nor a placeholder must be one of the allowed ids for its row.

    Args:
        rejection_sampler: The sampler under test.
        batch_size: Number of requests.
        num_draft_tokens: Draft tokens per request.
        vocab_size: Vocabulary size.
        target_logits: Target logits with one row per draft token
            (``batch_size * num_draft_tokens`` rows).
        unmasked_indices: For each draft-token row, the token ids that may be
            emitted. Either a 2-D tensor (as built by ``test_top_k``) or a
            list of per-row lists (as built by ``test_top_p``, whose rows can
            differ in length).
        sampling_metadata: Metadata carrying the masking configuration
            (e.g. top-k / top-p).
    """
    num_tokens = batch_size * num_draft_tokens

    # Create random draft probabilities.
    draft_probs = torch.rand(
        (num_tokens, vocab_size), dtype=torch.float32, device=DEVICE
    )
    draft_probs = F.softmax(draft_probs, dim=-1)

    # Randomly sample one draft token id per row, grouped per request.
    draft_token_ids = (
        torch.multinomial(draft_probs, num_samples=1)
        .reshape(batch_size, num_draft_tokens)
        .tolist()
    )

    # Bonus tokens are not used, but the sampler requires them.
    bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)

    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids,
        device=DEVICE,
    )

    output_token_ids = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )

    # Drop the bonus column and flatten to one entry per draft-token row.
    output_token_ids = output_token_ids[:, :-1].flatten().tolist()

    # Accept both the tensor and the ragged-list form of the allowed ids.
    if isinstance(unmasked_indices, torch.Tensor):
        unmasked_indices = unmasked_indices.tolist()
    assert len(output_token_ids) == num_tokens
    assert len(unmasked_indices) == num_tokens

    # Check that all sampled tokens are within the unmasked indices.
    for i, (token_id, allowed) in enumerate(zip(output_token_ids, unmasked_indices)):
        if token_id == PLACEHOLDER_TOKEN_ID:
            continue
        assert token_id in allowed, (
            f"row {i}: sampled token {token_id} is outside the unmasked set"
        )


@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
    """Tokens outside each row's top-k must never be emitted."""
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # A random set of ``top_k`` distinct token ids for every draft-token row.
    allowed_ids = torch.stack(
        [torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens)]
    )

    # Start from uniform (all-zero) logits and lift only the allowed ids by a
    # small 0.1. The gap is tiny, so non-top-k ids are avoided only if the
    # top-k mask is really applied. The ids within a row are distinct, so
    # writing 0.1 is the same as adding it to zero.
    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
    target_logits.scatter_(1, allowed_ids, 0.1)

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE),
        top_k=torch.full((batch_size,), top_k, dtype=torch.int64, device=DEVICE),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=allowed_ids,
        sampling_metadata=sampling_metadata,
    )


@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
    """Test rejection sampling with top-p sampling"""
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # Random (standard-normal) target logits, one row per draft token.
    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)

    # Rows are laid out request-major (num_draft_tokens consecutive rows per
    # request). Expand the per-request temperature to one value per row and
    # divide each row by it. Dividing by the raw (batch_size,) tensor would
    # broadcast along the vocab axis instead, and only ran before because
    # vocab_size == batch_size here.
    row_temperature = temperature.repeat_interleave(num_draft_tokens).unsqueeze(-1)
    rescaled_logits = target_logits / row_temperature

    # Reference top-p mask: sort ascending and mask the lowest-probability
    # tokens whose cumulative mass is <= 1 - top_p. Always keep the most
    # likely token.
    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
    probs_sum = logits_sort.softmax(dim=-1).cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - top_p
    top_p_mask[:, -1] = False

    # Allowed (unmasked) token ids per row; rows may differ in length.
    top_p_indices = [
        logits_idx[i][~top_p_mask[i]].tolist() for i in range(num_tokens)
    ]

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_p_indices,
        sampling_metadata=sampling_metadata,
    )