test_rejection_sampler.py 22.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, Optional
4
5
6

import pytest
import torch
7
import torch.nn.functional as F
8

9
from vllm.platforms import current_platform
10
from vllm.v1.sample.metadata import SamplingMetadata
11
12
13
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                              RejectionSampler)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
14

15
DEVICE = current_platform.device_type
16

17
18

@pytest.fixture
def rejection_sampler():
    """Return a newly constructed ``RejectionSampler`` for the requesting test.

    The fixture uses pytest's default (function) scope, so the constructor is
    invoked once per test that requests it.
    """
    return RejectionSampler()


23
def create_logits_tensor(output_token_ids: list[list[int]],
                         vocab_size: int = 100) -> torch.Tensor:
    """Helper function to create logits tensor that
    will produce desired token ids on argmax.

    The last token of every sequence (the bonus token) is dropped. The
    remaining tokens of all sequences are concatenated in order, and each one
    gets its own row of logits.

    Args:
        output_token_ids: One list of token ids per sequence; the final entry
            of each list is ignored here.
        vocab_size: Number of columns of the returned tensor.

    Returns:
        A ``(num_total_tokens, vocab_size)`` tensor on ``DEVICE`` filled with
        -100.0, except that row ``r`` holds 100.0 in the column of the
        ``r``-th retained token.
    """
    # Flatten all sequences, skipping each sequence's trailing bonus token.
    flat_token_ids = [
        token_id for tokens in output_token_ids for token_id in tokens[:-1]
    ]
    logits = torch.full((len(flat_token_ids), vocab_size),
                        -100.0,
                        device=DEVICE)
    for row, token_id in enumerate(flat_token_ids):
        logits[row, token_id] = 100.0
    return logits


38
def create_sampling_metadata(
    all_greedy: bool,
    temperature: Optional[torch.Tensor] = None,
    top_k: Optional[torch.Tensor] = None,
    top_p: Optional[torch.Tensor] = None,
    generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
    """Create a v1 sampling metadata object with all_greedy set
    to the given value. Either all greedy or all random sampling
    is used.

    Args:
        all_greedy: Stored as ``all_greedy``; its negation is stored as
            ``all_random``.
        temperature: Required (asserted) when ``all_greedy`` is False.
            Discarded and replaced by ``None`` when ``all_greedy`` is True.
        top_k: Passed through unchanged.
        top_p: Passed through unchanged.
        generators: Passed through; ``None`` or an empty dict is replaced by
            a new empty dict.

    Returns:
        A ``SamplingMetadata`` whose remaining fields are fixed placeholder
        values (no logprobs, empty penalty tensors, no token restrictions).
    """
    generators = generators or {}
    if all_greedy:
        # Greedy sampling carries no temperature, whatever the caller passed.
        temperature = None
    else:
        assert temperature is not None

    return SamplingMetadata(
        temperature=temperature,
        all_greedy=all_greedy,
        all_random=not all_greedy,
        top_p=top_p,
        top_k=top_k,
        min_p=torch.empty(1, ),
        generators=generators,
        max_num_logprobs=0,
        no_penalties=False,
        prompt_token_ids=None,
        frequency_penalties=torch.tensor([]),
        presence_penalties=torch.tensor([]),
        repetition_penalties=torch.tensor([]),
        output_token_ids=[],
        min_tokens={},
        logit_bias=[None],
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
    )


77
########################### Tests for Greedy Sampling ###################
78
def test_perfect_match(rejection_sampler):
    """Test when output tokens perfectly match speculated tokens.

    The target logits peak at exactly the drafted tokens, so the expected
    output is the full draft followed by the bonus token.
    """
    spec_tokens = [[1, 2, 3]]
    output_tokens = [[1, 2, 3, 4]]  # 4 is the bonus token

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2, 3, 4]],
                            dtype=torch.int,
                            device=logits.device)
    assert torch.equal(output, expected)
101
102


103
def test_early_mismatch(rejection_sampler):
    """Test when there's an early mismatch in tokens.

    The draft proposes 2 at position 1 while the target logits peak at 5.
    The expected output keeps the target's token at the mismatch and fills
    every later slot (including the bonus slot) with PLACEHOLDER_TOKEN_ID.
    """
    spec_tokens = [[1, 2, 3]]
    output_tokens = [[1, 5, 3, 4]]  # Mismatch at position 1

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
        [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)
128
129


130
def test_multiple_sequences(rejection_sampler):
    """Test handling multiple sequences of speculated tokens.

    The two requests have different draft lengths (2 and 1). The expected
    output is padded to a common width with PLACEHOLDER_TOKEN_ID.
    """
    spec_tokens = [[1, 2], [3]]
    # Two sequences with bonus tokens 5 and 4
    output_tokens = [[1, 2, 5], [3, 4]]

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor(
        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
                            dtype=torch.int,
                            device=logits.device)
    assert torch.equal(output, expected)
154
155


156
def test_single_token_sequence(rejection_sampler):
    """Test handling sequences with single token.

    One matching draft token is expected to be followed directly by the
    bonus token.
    """
    spec_tokens = [[1]]
    output_tokens = [[1, 2]]  # Single token with bonus token 2

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
    assert torch.equal(output, expected)
177
178


179
def test_empty_sequence(rejection_sampler):
    """Test handling empty sequence of speculated tokens.

    With no draft tokens the target logits tensor has zero rows, and the
    expected output is the bonus token alone.
    """
    spec_tokens: list[list[int]] = [[]]
    output_tokens = [[5]]  # Just the bonus token

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
    assert torch.equal(output, expected)
200
201


202
def test_multiple_mismatches(rejection_sampler):
    """Test handling multiple sequences with mismatches.

    The first request diverges at position 2 and the second at position 1.
    Each expected row keeps the target's token at its first mismatch and is
    filled with PLACEHOLDER_TOKEN_ID from there on.
    """
    spec_tokens = [[1, 2, 3], [4, 5, 6]]
    # Mismatches in both sequences
    output_tokens = [[1, 2, 7, 6], [4, 8, 6, 9]]

    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor(
        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
        [[1, 2, 7, PLACEHOLDER_TOKEN_ID],
         [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)
229
230
231
232
233


@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),  # Perfect match with bonus
        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),  # First mismatch
        ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
         [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]),  # Mixed matches
    ])
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
                            expected):
    """Parametrized test for various matching scenarios.

    Args:
        spec_tokens: Draft token ids, one list per request.
        output_tokens: Tokens the target logits should favour, one list per
            request; the last entry of each list is the bonus token.
        expected: Expected sampler output as nested lists.
    """
    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected_tensor = torch.tensor(expected,
                                   dtype=torch.int,
                                   device=logits.device)
    assert torch.equal(output, expected_tensor)


########################### Tests for Random Sampling ###################
@pytest.mark.parametrize("k", [1, 3, 5])
@pytest.mark.parametrize("vocab_size", [1000])
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(
    rejection_sampler,
    k: int,
    vocab_size: int,
    batch_size: int,
    frac_seeded: float,
    n_rep: int,
):
    """Check that requests with a seeded generator sample reproducibly.

    The same random inputs are fed to the sampler ``n_rep`` times. Before
    every repetition each seeded request gets a fresh generator with the same
    seed, and its output row must be identical across all repetitions.
    Unseeded requests are not checked.

    Args:
        k: Number of draft tokens per request.
        vocab_size: Vocabulary size of the random distributions.
        batch_size: Number of requests.
        frac_seeded: Threshold compared against a uniform random draw per
            request to decide whether it is seeded. With 0.0 a request is
            seeded only if its draw is exactly 0, so in practice nothing is
            asserted for that parameter value.
        n_rep: Number of repetitions to compare.
    """
    num_tokens = batch_size * k
    draft_probs = torch.rand(num_tokens,
                             vocab_size,
                             dtype=torch.float32,
                             device=DEVICE)
    draft_probs = F.softmax(draft_probs, dim=-1)
    target_logits = torch.rand_like(draft_probs)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device=DEVICE)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device=DEVICE)

    # Pick the seeded requests once so every repetition seeds the same rows.
    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded

    results = []
    for _ in range(n_rep):
        # Rebuild the generators each time so they restart from their seed.
        seeded_seqs = {
            i: torch.Generator(device=DEVICE).manual_seed(i)
            for i in range(batch_size) if seeded_mask[i]
        }

        temperature = torch.ones(batch_size,
                                 dtype=torch.float32,
                                 device=DEVICE)
        sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                     temperature=temperature,
                                                     generators=seeded_seqs)
        spec_decode_metadata = SpecDecodeMetadata.make_dummy(
            draft_token_ids.tolist(), device=DEVICE)
        rep_result = rejection_sampler(
            spec_decode_metadata,
            draft_probs=draft_probs,
            target_logits=target_logits,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=sampling_metadata,
        )

        results.append(rep_result)

    for i in range(batch_size):
        if seeded_mask[i]:
            for j in range(1, n_rep):
                assert torch.equal(results[j][i], results[0][i])


def test_rejection_sampling_approximates_target_distribution():
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.

    The default torch device is switched to ``DEVICE`` only for the
    duration of this test, so the setting does not leak into other tests.
    """
    vocab_size = 10
    k = 2
    num_reference_probs = 100
    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference: list[float] = []
    distance_wrt_target: list[float] = []

    # Scoped replacement for torch.set_default_device(DEVICE): tensors
    # created without an explicit device land on DEVICE inside this block,
    # and the previous default is restored on exit.
    with torch.device(DEVICE):
        # Prepare draft, target, and reference probability distributions
        draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
                                dim=-1)
        target_logits = torch.rand(vocab_size, dtype=torch.float32)
        target_probs = F.softmax(target_logits, dim=-1)
        reference_probs = F.softmax(
            torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
            dim=-1,
        )

        for num_samples in sample_sizes:
            # Sample using rejection sampling.
            rej_sample_probs = estimate_rejection_sampling_pdf(
                draft_probs, target_logits, k, vocab_size, num_samples)
            rej_sample_probs = rej_sample_probs.to(DEVICE)

            # Average distance from reference probs.
            reference_vs_rejsample_dist = torch.dist(
                reference_probs,
                rej_sample_probs).item() / reference_probs.shape[0]
            target_vs_rejsample_dist = torch.dist(target_probs,
                                                  rej_sample_probs).item()

            distance_wrt_reference.append(reference_vs_rejsample_dist)
            distance_wrt_target.append(target_vs_rejsample_dist)

            # Running ratios, recomputed each round for the progress output.
            relative_change_in_distance_wrt_target = get_ratio_first_to_last(
                distance_wrt_target)
            relative_change_in_distance_wrt_reference = (
                get_ratio_first_to_last(distance_wrt_reference))

            print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
                  f"{reference_vs_rejsample_dist=:.05f}")
            print(f"{num_samples=} "
                  f"{relative_change_in_distance_wrt_target=:.02f} "
                  f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    # The distance to the target must shrink at least this many times more
    # than the distance to the unrelated reference distributions.
    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target
            > relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return the first entry of ``elements`` divided by its last entry."""
    first, last = elements[0], elements[-1]
    return first / last


def estimate_rejection_sampling_pdf(
    draft_probs: torch.Tensor,
    target_logits: torch.Tensor,
    k: int,
    vocab_size: int,
    num_samples: int,
) -> torch.Tensor:
    """Estimate the probability distribution of the output tokens
    using rejection sampling.

    Args:
        draft_probs: Draft probability distribution (``vocab_size``
            entries), shared by every draft position.
        target_logits: Target logits (``vocab_size`` entries), shared by
            every draft position.
        k: Number of draft tokens per sample.
        vocab_size: Vocabulary size; also the number of histogram bins.
        num_samples: Number of samples to draw.

    Returns:
        Estimated probability distribution of the output tokens, as a
        density histogram with ``vocab_size`` bins computed on CPU.
    """
    rejection_sampler = RejectionSampler()
    num_tokens = num_samples * k
    # Repeat draft probs num_samples * k times.
    draft_probs = draft_probs.reshape(1, 1,
                                      vocab_size).repeat(num_samples, k, 1)

    # Repeat target logits num_tokens times.
    target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)

    # Randomly sample draft token ids from draft probs.
    draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                        num_samples=k,
                                        replacement=True).reshape(
                                            num_samples, k)
    # Flatten to one row of draft probs per draft token.
    draft_probs = draft_probs.view(num_tokens, vocab_size)

    # Bonus tokens not used but required.
    bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
                                  device=DEVICE).repeat(num_samples, 1)

    temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
    sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                 temperature=temperature)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids.tolist(), device=bonus_token_ids.device)
    output_token_ids = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )
    # Drop the last (bonus) column and pool all remaining positions.
    output_token_ids = output_token_ids[:, :-1].flatten()

    # NOTE(review): placeholder ids of rejected positions are presumably
    # outside [0, vocab_size) and so not counted here -- confirm.
    hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                               device="cpu"),
                           bins=vocab_size,
                           range=(0, vocab_size),
                           density=True)

    return hist.hist
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612


def _test_masked_logits(
    rejection_sampler,
    batch_size: int,
    num_draft_tokens: int,
    vocab_size: int,
    target_logits: torch.Tensor,
    unmasked_indices: torch.Tensor,
    sampling_metadata: SamplingMetadata,
):
    """Run rejection sampling and assert outputs stay within allowed ids.

    Random draft distributions and draft tokens are generated for
    ``batch_size`` requests of ``num_draft_tokens`` tokens each. Every
    non-placeholder, non-bonus output token at flattened position ``i`` must
    be contained in ``unmasked_indices[i]``.
    """
    total_tokens = batch_size * num_draft_tokens

    # Random draft distribution, one row per draft position.
    raw_scores = torch.rand((total_tokens, vocab_size),
                            dtype=torch.float32,
                            device=DEVICE)
    draft_probs = F.softmax(raw_scores, dim=-1)

    # Draw one draft token per row and group them into per-request lists.
    sampled = torch.multinomial(draft_probs, num_samples=1)
    draft_token_ids = sampled.reshape(batch_size, num_draft_tokens).tolist()

    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids,
        device=DEVICE,
    )

    # The sampler requires bonus ids even though they are ignored below.
    bonus_token_ids = torch.zeros((batch_size, 1),
                                  dtype=torch.int64,
                                  device=DEVICE)

    sampler_output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )

    # Strip the bonus column, then flatten to one id per draft position.
    non_bonus_ids = sampler_output[:, :-1].flatten().tolist()

    for i in range(total_tokens):
        token_id = non_bonus_ids[i]
        if token_id != PLACEHOLDER_TOKEN_ID:
            assert token_id in unmasked_indices[i]


@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
    """Test rejection sampling with top-k sampling"""
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # Choose `top_k` distinct random token ids for every draft position.
    top_k_indices = torch.stack([
        torch.randperm(vocab_size, device=DEVICE)[:top_k]
        for _ in range(num_tokens)
    ])

    # Begin with all-zero (uniform) logits and lift only the chosen ids by a
    # small margin. If top-k masking works, the other ids are never sampled
    # even though their logits are almost as large.
    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
    for row, favoured_ids in enumerate(top_k_indices):
        target_logits[row, favoured_ids] += 0.1

    # Every request uses the same top-k value and a temperature of one.
    per_request_top_k = torch.full((batch_size, ),
                                   top_k,
                                   dtype=torch.int64,
                                   device=DEVICE)
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(batch_size,
                               dtype=torch.float32,
                               device=DEVICE),
        top_k=per_request_top_k,
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_k_indices,
        sampling_metadata=sampling_metadata,
    )


@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
    """Test rejection sampling with top-p sampling.

    For every draft position the set of token ids that survive top-p
    filtering is computed by hand from the target logits. The sampler must
    then only ever emit ids from that set.
    """
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # Create random (normally distributed) logits, one row per draft token.
    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
    # Temperature is per request while logits are per token. Expand it to one
    # value per token row so the division scales rows, not vocab columns, and
    # does not depend on vocab_size happening to equal batch_size.
    # Assumes token rows are grouped request by request, matching the
    # (batch_size, num_draft_tokens) reshape in _test_masked_logits.
    per_token_temperature = temperature.repeat_interleave(num_draft_tokens)
    rescaled_logits = target_logits / per_token_temperature.unsqueeze(-1)

    # Sort ascending so the cumulative sum accumulates the low-probability
    # tail first. Tokens whose tail mass is at most 1 - top_p are masked.
    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - top_p
    # at least one: the most likely token (last column) is always kept
    top_p_mask[:, -1] = False

    # Get the top-p indices.
    top_p_indices = [
        row_idx[~row_mask].tolist()
        for row_idx, row_mask in zip(logits_idx, top_p_mask)
    ]

    # Create sampling metadata
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_p=torch.tensor([top_p] * batch_size,
                           device=DEVICE,
                           dtype=torch.float32),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_p_indices,
        sampling_metadata=sampling_metadata,
    )