# test_rejection_sampler.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any
4
from unittest.mock import Mock
5
6
7

import pytest
import torch
8
import torch.nn.functional as F
9

10
from tests.v1.sample.utils import create_allowed_token_ids
11
from vllm.platforms import current_platform
12
from vllm.v1.sample.logits_processor import LogitsProcessors
13
from vllm.v1.sample.metadata import SamplingMetadata
14
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
15
from vllm.v1.sample.sampler import Sampler, SamplerOutput
16
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
17

18
# Device type reported by the current vLLM platform; tensors the tests create
# with an explicit device are placed here.
DEVICE = current_platform.device_type
19

20
21

@pytest.fixture
def rejection_sampler():
    """Provide a ``RejectionSampler`` whose inner sampler is a ``Mock``.

    Tests set the mock's return value (the bonus tokens) with
    ``mock_sampler_output`` before invoking the rejection sampler.
    """
    inner = Mock(spec=Sampler)
    inner.logprobs_mode = "raw_logprobs"
    sampler = RejectionSampler(inner)
    return sampler


def mock_sampler_output(
    rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor
):
    """Program the mocked inner sampler to return ``bonus_token_ids``.

    After this call, invoking ``rejection_sampler.sampler`` yields a
    ``SamplerOutput`` carrying these ids and no logprobs.
    """
    canned_output = SamplerOutput(
        sampled_token_ids=bonus_token_ids,
        logprobs_tensors=None,
    )
    rejection_sampler.sampler.return_value = canned_output


def create_spec_decode_metadata(
    spec_tokens: list[list[int]], logits: torch.Tensor
) -> SpecDecodeMetadata:
    """Build ``SpecDecodeMetadata`` for ``spec_tokens`` verified against ``logits``.

    Every row of ``logits`` is used as a target-logits row. Bonus tokens are
    supplied through the mocked sampler rather than sampled from ``logits``,
    so no bonus logit rows are referenced.

    Args:
        spec_tokens: Draft token ids per request.
        logits: Target logits; its row count and device drive the indices.

    Returns:
        The metadata, with all index tensors on ``logits.device``.
    """
    device = logits.device
    metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=device)
    # Keep the overridden index tensors on the same device as the logits they
    # index (and as the rest of the metadata built by make_dummy above).
    metadata.target_logits_indices = torch.arange(logits.shape[0], device=device)
    # Output bonus token ids are mocked, so the bonus logit indices should
    # be empty.
    metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32, device=device)
    return metadata
45
46


47
def create_logits_tensor(
48
49
    output_token_ids: list[list[int]],
    vocab_size: int = 100,
50
    token_idx_to_override: int | None = None,
51
) -> torch.Tensor:
52
    """Helper function to create logits tensor that
53
    will produce desired token ids on argmax"""
54
    token_ids = [tokens[:-1] for tokens in output_token_ids]
55
56
57
58
59
60
61
    num_total_tokens = sum(len(tokens) for tokens in token_ids)
    logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
    start_loc = 0
    for tokens in token_ids:
        for j, token_id in enumerate(tokens):
            logits[start_loc + j, token_id] = 100.0
        start_loc += len(tokens)
62
63
    if token_idx_to_override:
        logits[:, token_idx_to_override] = 99.0
64
65
66
    return logits


67
def create_sampling_metadata(
    all_greedy: bool,
    output_token_ids: list[list[int]] | None = None,
    prompt_token_ids: torch.Tensor | None = None,
    spec_token_ids: list[list[int]] | None = None,
    temperature: torch.Tensor | None = None,
    top_k: torch.Tensor | None = None,
    top_p: torch.Tensor | None = None,
    generators: dict[int, Any] | None = None,
    frequency_penalties: list[float] | None = None,
    presence_penalties: list[float] | None = None,
    repetition_penalties: list[float] | None = None,
    bad_words_token_ids: dict[int, list[list[int]]] | None = None,
    allowed_token_ids_mask: torch.Tensor | None = None,
) -> SamplingMetadata:
    """Create a v1 sampling metadata object with all_greedy set
    to the given value. Either all greedy or all random sampling
    is used.

    Args:
        all_greedy: If True, every request is greedy and ``temperature`` is
            replaced by None; otherwise every request is random and
            ``temperature`` must be provided.
        output_token_ids: Previously generated tokens per request. Required
            (non-empty) when any penalty is set.
        prompt_token_ids: Prompt tokens, passed through unchanged.
        spec_token_ids: Draft tokens per request.
        temperature, top_k, top_p, generators, bad_words_token_ids,
        allowed_token_ids_mask: Passed through (None becomes an empty
            container where SamplingMetadata expects one).
        frequency_penalties, presence_penalties, repetition_penalties:
            Per-request penalty values. Penalties are enabled as soon as any
            of the three is non-empty; a list left as None is then filled
            with its neutral value (0.0, 0.0 and 1.0 respectively) for every
            request in ``output_token_ids``.

    Returns:
        The populated ``SamplingMetadata``.
    """
    generators = generators or {}
    if all_greedy:
        temperature = None
    else:
        assert temperature is not None

    no_penalties = not any(
        [frequency_penalties, presence_penalties, repetition_penalties]
    )
    if no_penalties:
        frequency = torch.tensor([])
        presence = torch.tensor([])
        repetition = torch.tensor([])
    else:
        assert output_token_ids, "penalties require non-empty output_token_ids"
        num_requests = len(output_token_ids)

        def to_tensor(values: list[float] | None, neutral: float) -> torch.Tensor:
            # Fill an omitted penalty with its no-op value instead of calling
            # torch.tensor(None), which raises.
            if values is None:
                values = [neutral] * num_requests
            return torch.tensor(values, device=DEVICE)

        frequency = to_tensor(frequency_penalties, 0.0)
        presence = to_tensor(presence_penalties, 0.0)
        repetition = to_tensor(repetition_penalties, 1.0)

    return SamplingMetadata(
        temperature=temperature,
        all_greedy=all_greedy,
        all_random=not all_greedy,
        top_p=top_p,
        top_k=top_k,
        generators=generators,
        max_num_logprobs=None,
        no_penalties=no_penalties,
        prompt_token_ids=prompt_token_ids,
        frequency_penalties=frequency,
        presence_penalties=presence,
        repetition_penalties=repetition,
        output_token_ids=[] if output_token_ids is None else output_token_ids,
        spec_token_ids=[] if spec_token_ids is None else spec_token_ids,
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids={} if bad_words_token_ids is None else bad_words_token_ids,
        logitsprocs=LogitsProcessors(),
    )


128
########################### Tests for Greedy Sampling ###################
129
def test_perfect_match(rejection_sampler):
    """A draft that fully agrees with the target is accepted plus the bonus."""
    draft = [[1, 2, 3]]
    target = [[1, 2, 3, 4]]  # trailing 4 is the bonus token

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    device = target_logits.device
    bonus = torch.tensor([target[0][-1]], device=device)
    spec_md = create_spec_decode_metadata(draft, target_logits)
    mock_sampler_output(rejection_sampler, bonus)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=device)
    assert torch.equal(result.sampled_token_ids, want)
148
149


150
def test_early_mismatch(rejection_sampler):
    """A mismatch at draft index 1 discards everything after it.

    The target's token is emitted at the mismatch; later slots, including
    the bonus slot, hold placeholders.
    """
    draft = [[1, 2, 3]]
    target = [[1, 5, 3, 4]]  # disagrees with the draft at index 1

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    device = target_logits.device
    bonus = torch.tensor([target[0][-1]], device=device)
    spec_md = create_spec_decode_metadata(draft, target_logits)
    mock_sampler_output(rejection_sampler, bonus)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor(
        [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(result.sampled_token_ids, want)
173
174


175
def test_multiple_sequences(rejection_sampler):
    """Requests with different draft lengths are verified independently.

    The shorter request's row is padded with a placeholder.
    """
    draft = [[1, 2], [3]]
    target = [[1, 2, 5], [3, 4]]  # bonus tokens are 5 and 4

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    device = target_logits.device
    bonus = torch.tensor([seq[-1] for seq in target], device=device)
    spec_md = create_spec_decode_metadata(draft, target_logits)
    mock_sampler_output(rejection_sampler, bonus)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor(
        [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=device
    )
    assert torch.equal(result.sampled_token_ids, want)
198
199


200
def test_single_token_sequence(rejection_sampler):
    """A single accepted draft token is followed by the bonus token."""
    draft = [[1]]
    target = [[1, 2]]  # trailing 2 is the bonus token

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    device = target_logits.device
    bonus = torch.tensor([target[0][-1]], device=device)
    spec_md = create_spec_decode_metadata(draft, target_logits)
    mock_sampler_output(rejection_sampler, bonus)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor([[1, 2]], dtype=torch.int, device=device)
    assert torch.equal(result.sampled_token_ids, want)
219
220


221
def test_empty_sequence(rejection_sampler):
    """With no draft tokens, only the bonus token is emitted."""
    draft: list[list[int]] = [[]]
    target = [[5]]  # nothing but the bonus token

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    device = target_logits.device
    bonus = torch.tensor([target[0][-1]], device=device)
    spec_md = create_spec_decode_metadata(draft, target_logits)
    mock_sampler_output(rejection_sampler, bonus)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor([[5]], dtype=torch.int, device=device)
    assert torch.equal(result.sampled_token_ids, want)
240
241


242
def test_multiple_mismatches(rejection_sampler):
    """Each request is truncated at its own first mismatch."""
    draft = [[1, 2, 3], [4, 5, 6]]
    # Request 0 diverges at index 2, request 1 at index 1.
    target = [[1, 2, 7, 6], [4, 8, 6, 9]]

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    device = target_logits.device
    bonus = torch.tensor([seq[-1] for seq in target], device=device)
    spec_md = create_spec_decode_metadata(draft, target_logits)
    mock_sampler_output(rejection_sampler, bonus)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor(
        [
            [1, 2, 7, PLACEHOLDER_TOKEN_ID],
            [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
        ],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(result.sampled_token_ids, want)
270
271
272
273
274


@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),  # Perfect match with bonus
        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),  # First mismatch
        (
            [[1, 2], [3, 4]],
            [[1, 5, 6], [3, 4, 7]],
            [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]],
        ),  # Mixed matches
    ],
)
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expected):
    """Greedy verification over a table of (draft, target, expected) cases.

    The last token of each ``output_tokens`` row is served as the bonus token.
    """
    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(output_tokens)
    device = target_logits.device
    bonus = torch.tensor([row[-1] for row in output_tokens], device=device)
    spec_md = create_spec_decode_metadata(spec_tokens, target_logits)
    mock_sampler_output(rejection_sampler, bonus)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=sampling_md,
    )

    assert torch.equal(
        result.sampled_token_ids,
        torch.tensor(expected, dtype=torch.int, device=device),
    )
302
303
304
305
306
307
308
309


########################### Tests for Random Sampling ###################
@pytest.mark.parametrize("k", [1, 3, 5])
@pytest.mark.parametrize("vocab_size", [1000])
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(
    rejection_sampler,
    k: int,
    vocab_size: int,
    batch_size: int,
    frac_seeded: float,
    n_rep: int,
):
    """Seeded requests must produce identical samples on every repetition.

    Requests whose uniform draw is <= ``frac_seeded`` get a
    ``torch.Generator`` seeded with their batch index. The same inputs are
    run ``n_rep`` times and every seeded request's output row is compared
    with the first repetition; unseeded rows are not checked.
    """
    num_tokens = batch_size * k
    # Only used as a shape/dtype/device template for target_logits.
    draft_probs = F.softmax(
        torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE),
        dim=-1,
    )
    target_logits = torch.rand_like(draft_probs)
    randint_kwargs = dict(low=0, high=vocab_size, dtype=torch.int64, device=DEVICE)
    bonus_token_ids = torch.randint(size=(batch_size, 1), **randint_kwargs)
    draft_token_ids = torch.randint(size=(batch_size, k), **randint_kwargs)

    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
    seeded_rows = [i for i in range(batch_size) if seeded_mask[i]]

    outputs = []
    for _ in range(n_rep):
        # Fresh generators each repetition, re-seeded identically.
        generators = {
            i: torch.Generator(device=DEVICE).manual_seed(i) for i in seeded_rows
        }
        temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
        sampling_md = create_sampling_metadata(
            all_greedy=False, temperature=temperature, generators=generators
        )
        spec_md = create_spec_decode_metadata(draft_token_ids.tolist(), target_logits)
        mock_sampler_output(rejection_sampler, bonus_token_ids)

        outcome = rejection_sampler(
            spec_md,
            draft_probs=None,
            logits=target_logits,
            sampling_metadata=sampling_md,
        )
        outputs.append(outcome.sampled_token_ids)

    reference = outputs[0]
    for row in seeded_rows:
        for later in outputs[1:]:
            assert torch.equal(later[row], reference[row])


def test_rejection_sampling_approximates_target_distribution():
    """Verify rejection sampling approximates the target distribution,
    despite proposals coming from a potentially distinct draft distribution.

    A random target distribution and a random draft distribution are
    created, rejection sampling is run with them for increasing sample
    counts, and the output tokens are histogrammed into an empirical
    distribution.

    As the number of samples grows, the empirical distribution should move
    toward the target. To measure this, its distance to the target is
    compared with its distance to a set of randomly generated reference
    distributions. The distance to the target must shrink (first/last
    ratio) at least ``expected_improvement_multiplier`` times more than the
    distance to the references.
    """
    # Scope the default device to this test. torch.set_default_device() would
    # instead change the default device for every test that runs after this
    # one in the same process.
    with torch.device(DEVICE):
        vocab_size = 10
        k = 2
        num_reference_probs = 100

        # Prepare draft, target, and reference probability distributions.
        draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), dim=-1)
        target_logits = torch.rand(vocab_size, dtype=torch.float32)
        target_probs = F.softmax(target_logits, dim=-1)
        reference_probs = F.softmax(
            torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
            dim=-1,
        )

        sample_sizes = [10, 100, 1_000, 10_000, 100_000]
        distance_wrt_reference: list[float] = []
        distance_wrt_target: list[float] = []

        for num_samples in sample_sizes:
            # Sample using rejection sampling.
            rej_sample_probs = estimate_rejection_sampling_pdf(
                draft_probs, target_logits, k, vocab_size, num_samples
            )
            rej_sample_probs = rej_sample_probs.to(DEVICE)

            # torch.dist broadcasts the empirical distribution against every
            # reference row and returns one 2-norm over all of them; dividing
            # by the row count is a coarse per-reference normalization.
            reference_vs_rejsample_dist = (
                torch.dist(reference_probs, rej_sample_probs).item()
                / reference_probs.shape[0]
            )
            target_vs_rejsample_dist = torch.dist(
                target_probs, rej_sample_probs
            ).item()

            distance_wrt_reference.append(reference_vs_rejsample_dist)
            distance_wrt_target.append(target_vs_rejsample_dist)

            relative_change_in_distance_wrt_target = get_ratio_first_to_last(
                distance_wrt_target
            )
            relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
                distance_wrt_reference
            )

            print(
                f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
                f"{reference_vs_rejsample_dist=:.05f}"
            )
            print(
                f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
                f"{relative_change_in_distance_wrt_reference=:.02f}"
            )

        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target
        )
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference
        )

        expected_improvement_multiplier = 20
        assert (
            relative_change_in_distance_wrt_target
            > relative_change_in_distance_wrt_reference
            * expected_improvement_multiplier
        )
449
450
451
452
453
454
455
456


def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return the first element divided by the last one.

    For a single-element list this is ``elements[0] / elements[0]``.
    """
    first, last = elements[0], elements[-1]
    return first / last


def estimate_rejection_sampling_pdf(
    draft_probs: torch.Tensor,
    target_logits: torch.Tensor,
    k: int,
    vocab_size: int,
    num_samples: int,
) -> torch.Tensor:
    """Estimate the probability distribution of the output tokens
    using rejection sampling.

    ``num_samples`` requests are simulated. Each proposes ``k`` draft tokens
    drawn from ``draft_probs``, and all are verified against the same
    ``target_logits``. The non-bonus output tokens of every request are
    histogrammed.

    Args:
        draft_probs: Draft probability distribution (``vocab_size`` elements).
        target_logits: Target logits (``vocab_size`` elements).
        k: Number of draft tokens proposed per request.
        vocab_size: Vocabulary size; also the number of histogram bins.
        num_samples: Number of requests to simulate.

    Returns:
        Density-normalized histogram over ``[0, vocab_size)`` with one bin per
        token id, computed on CPU.
    """
    sampler_stub = Mock(spec=Sampler)
    sampler_stub.logprobs_mode = "raw_logprobs"
    sampler = RejectionSampler(sampler_stub)

    total_draft = num_samples * k
    # One copy of the draft distribution per (request, draft position).
    tiled_draft = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
    # One copy of the target logits per draft position (requests flattened).
    tiled_target = target_logits.reshape(1, vocab_size).repeat(total_draft, 1)

    # k draft tokens per request, drawn with replacement from the draft probs.
    proposals = torch.multinomial(
        tiled_draft[:, 0, :], num_samples=k, replacement=True
    ).reshape(num_samples, k)
    flat_draft = tiled_draft.view(total_draft, vocab_size)

    # Bonus tokens are ignored by the estimate but required by the sampler.
    bonus = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat(
        num_samples, 1
    )

    temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
    sampling_md = create_sampling_metadata(all_greedy=False, temperature=temperature)
    spec_md = create_spec_decode_metadata(proposals.tolist(), tiled_target)

    mock_sampler_output(sampler, bonus)
    result = sampler(
        spec_md,
        draft_probs=flat_draft,
        logits=tiled_target,
        sampling_metadata=sampling_md,
    )
    # Drop the bonus column. Rejected slots hold placeholder ids, which
    # presumably fall outside ``range`` below and are not counted — confirm.
    accepted = result.sampled_token_ids[:, :-1].flatten()

    histogram = torch.histogram(
        accepted.to(dtype=torch.float, device="cpu"),
        bins=vocab_size,
        range=(0, vocab_size),
        density=True,
    )
    return histogram.hist
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533


def _test_masked_logits(
    rejection_sampler,
    batch_size: int,
    num_draft_tokens: int,
    vocab_size: int,
    target_logits: torch.Tensor,
    unmasked_indices: torch.Tensor,
    sampling_metadata: SamplingMetadata,
):
    """Run rejection sampling and check that every emitted token is allowed.

    Draft tokens are drawn from a random draft distribution and verified
    against ``target_logits`` under ``sampling_metadata`` (e.g. a top-k or
    top-p setting). Each non-placeholder output token at draft position
    ``i`` must be contained in ``unmasked_indices[i]``, which may be a tensor
    row or a list.
    """
    total = batch_size * num_draft_tokens

    # Random draft distribution, and one draft token sampled per position.
    draft_probs = F.softmax(
        torch.rand((total, vocab_size), dtype=torch.float32, device=DEVICE), dim=-1
    )
    proposals = (
        torch.multinomial(draft_probs, num_samples=1)
        .reshape(batch_size, num_draft_tokens)
        .tolist()
    )

    # Bonus tokens are not under test, but the sampler needs one per request.
    bonus = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)

    spec_md = create_spec_decode_metadata(proposals, target_logits)
    mock_sampler_output(rejection_sampler, bonus)
    result = rejection_sampler(
        spec_md,
        draft_probs=draft_probs,
        logits=target_logits,
        sampling_metadata=sampling_metadata,
    )

    # Drop the bonus column, leaving one entry per draft position.
    sampled = result.sampled_token_ids[:, :-1].flatten().tolist()
    for pos in range(total):
        token_id = sampled[pos]
        if token_id != PLACEHOLDER_TOKEN_ID:
            assert token_id in unmasked_indices[pos]


@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
    """Tokens outside each row's top-k set must never be emitted."""
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # A random set of top_k preferred token ids for every draft position.
    top_k_indices = torch.stack(
        [
            torch.randperm(vocab_size, device=DEVICE)[:top_k]
            for _ in range(num_tokens)
        ]
    )

    # Start from flat (uniform) logits and raise the top-k ids by only 0.1.
    # The margin is deliberately small: only effective top-k masking can keep
    # the remaining ids from ever being sampled.
    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
    for row, cols in enumerate(top_k_indices):
        target_logits[row, cols] += 0.1

    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_k=torch.tensor([top_k] * batch_size, device=DEVICE, dtype=torch.int64),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_k_indices,
        sampling_metadata=sampling_metadata,
    )


@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
    """Tokens outside each row's top-p nucleus must never be emitted."""
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # Random standard-normal target logits, one row per draft position.
    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
    # Temperature is per request but logits are per draft position, with each
    # request's num_draft_tokens rows laid out contiguously. Expand it per row
    # and broadcast across the vocabulary. Dividing by `temperature` directly
    # would scale vocabulary columns instead, and only ran because
    # batch_size == vocab_size.
    row_temperature = temperature.repeat_interleave(num_draft_tokens).unsqueeze(-1)
    rescaled_logits = target_logits / row_temperature

    # Reference nucleus: in ascending-probability order, drop the tokens whose
    # cumulative mass stays <= 1 - top_p, but always keep the most likely one.
    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
    probs_sum = logits_sort.softmax(dim=-1).cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - top_p
    top_p_mask[:, -1] = False

    top_p_indices = [
        logits_idx[i][~top_p_mask[i]].tolist() for i in range(num_tokens)
    ]

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_p_indices,
        sampling_metadata=sampling_metadata,
    )
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677


########################### Tests for Logit Processors ###################
def test_frequency_penalties(rejection_sampler):
    """Test rejection sampling with frequency penalties.

    Every logits row gives the intended token 100.0 and token 15 99.0, so a
    penalty that pushes an already-seen token below 99.0 flips the greedy
    choice to 15.
    """
    spec_tokens = [[1, 1, 1], [], [1, 1, 1]]
    output_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]]  # 1, 7 and 1 are the bonus tokens

    num_requests = len(spec_tokens)
    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=spec_tokens,
        prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE),
        frequency_penalties=[1.5, 1.5, 0.7],
        presence_penalties=[0.0] * num_requests,
        repetition_penalties=[1.0] * num_requests,
    )
    bonus_token_tensor = torch.tensor(
        [tokens[-1] for tokens in output_tokens], device=logits.device
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        logits=logits,
        sampling_metadata=metadata,
    )
    # Assuming the usual "logit -= penalty * count" rule:
    # - request 0 (penalty 1.5): the second 1 drops to 98.5 < 99.0, so 15 wins;
    # - request 1: no draft tokens, only the bonus 7;
    # - request 2 (penalty 0.7): 1 survives at 100.0 and 99.3, then drops to
    #   98.6 and 15 wins.
    expected = torch.tensor(
        [
            [1, 15, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
            [7, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
            [1, 1, 15, PLACEHOLDER_TOKEN_ID],
        ],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output.sampled_token_ids, expected)
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719


def test_bad_words(rejection_sampler):
    """Test rejection sampling with bad words constraints.

    Requests 0 and 1 ban token 2; request 2 is unconstrained. Token 15 is
    the runner-up (99.0) in every logits row, so it becomes the target's
    greedy choice wherever token 2 is banned.
    """
    spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
    output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]

    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=spec_tokens,
        # Single-token bad word [2] for requests 0 and 1. Request 2 is
        # deliberately left without bad words.
        bad_words_token_ids={0: [[2]], 1: [[2]]},
    )
    bonus_token_tensor = torch.tensor(
        [tokens[-1] for tokens in output_tokens], device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        logits=logits,
        sampling_metadata=metadata,
    )

    # Request 0: 2 is banned, the target picks 15 and rejects draft token 2.
    # Request 1: the draft already proposed 15, so the whole draft and the
    #            bonus token are accepted.
    # Request 2: unconstrained; the whole draft and the bonus are accepted.
    expected = torch.tensor(
        [
            [1, 15, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
            [1, 15, 3, 4],
            [1, 2, 3, 4],
        ],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output.sampled_token_ids, expected)
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766


def test_allowed_token_ids(rejection_sampler):
    """Test rejection sampling with allowed token ids.

    Token 15 is the runner-up (99.0) in every logits row, so it becomes the
    target's greedy choice wherever the intended token is masked out.
    """
    spec_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]]
    output_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]]
    # Not allowed tokens, per the original comment:
    #   0: 0-4
    #   1: 1-5
    #   2: 2-6
    # NOTE(review): the expected row for request 1 below accepts token 5, so
    # request 1's mask presumably does not cover token 5 — confirm against
    # create_allowed_token_ids.
    num_allowed_token_ids = 5

    # Token 15 is what the target picks whenever the intended token is masked.
    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)

    batch_size = len(output_tokens)
    _, vocab_size = logits.size()
    mask = create_allowed_token_ids(
        batch_size=batch_size,
        vocab_size=vocab_size,
        num_allowed_token_ids=num_allowed_token_ids,
        device=logits.device,
    )
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[], [], []],
        spec_token_ids=spec_tokens,
        allowed_token_ids_mask=mask,
    )
    bonus_token_tensor = torch.tensor(
        [tokens[-1] for tokens in output_tokens], device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        logits=logits,
        sampling_metadata=metadata,
    )

    expected = torch.tensor(
        [
            [15, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
            [10, 5, 10, PLACEHOLDER_TOKEN_ID],
            [7, 10, 12, 5],
        ],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output.sampled_token_ids, expected)