test_rejection_sampler.py 31.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any
4
from unittest.mock import Mock
5
6
7

import pytest
import torch
8
import torch.nn.functional as F
9

10
from tests.v1.sample.utils import create_allowed_token_ids
11
from vllm.platforms import current_platform
12
from vllm.v1.sample.logits_processor import LogitsProcessors
13
from vllm.v1.sample.metadata import SamplingMetadata
14
15
16
17
18
from vllm.v1.sample.rejection_sampler import (
    PLACEHOLDER_TOKEN_ID,
    RejectionSampler,
    sample_recovered_tokens,
)
19
from vllm.v1.sample.sampler import Sampler, SamplerOutput
20
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
21

22
# Device type reported by the current vLLM platform; the helpers and tests in
# this module pass it as `device=` when allocating tensors.
DEVICE = current_platform.device_type
23

24
25

@pytest.fixture
def rejection_sampler():
    """Provide a `RejectionSampler` that wraps a mocked `Sampler`.

    The wrapped sampler is a `Mock`, so each test decides what it returns
    (see `mock_sampler_output`) instead of running a real sampler.
    """
    sampler_stub = Mock(spec=Sampler)
    sampler_stub.logprobs_mode = "raw_logprobs"
    return RejectionSampler(sampler_stub)


def mock_sampler_output(
    rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor
):
    """Make the (mocked) wrapped sampler return the given bonus token ids.

    Args:
        rejection_sampler: Rejection sampler whose `sampler` attribute is a mock.
        bonus_token_ids: Token ids reported as `sampled_token_ids`.
    """
    canned_output = SamplerOutput(
        sampled_token_ids=bonus_token_ids,
        logprobs_tensors=None,
    )
    rejection_sampler.sampler.return_value = canned_output


def create_spec_decode_metadata(
    spec_tokens: list[list[int]], logits: torch.Tensor
) -> SpecDecodeMetadata:
    """Build dummy spec-decode metadata matching a hand-made logits tensor.

    Args:
        spec_tokens: Draft (speculated) token ids, one list per request.
        logits: Logits tensor the metadata will index into.

    Returns:
        Dummy metadata whose target-logit indices cover every row of `logits`
        and whose bonus-logit indices are empty.
    """
    spec_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device)
    num_logit_rows = logits.shape[0]
    # Every row of `logits` is treated as a target logit, in order.
    spec_metadata.target_logits_indices = torch.arange(num_logit_rows)
    # The bonus token ids come from the mocked sampler, so no row of `logits`
    # is referenced as a bonus logit.
    spec_metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32)
    return spec_metadata
49
50


51
def create_logits_tensor(
    output_token_ids: list[list[int]],
    vocab_size: int = 100,
    token_idx_to_override: int | None = None,
) -> torch.Tensor:
    """Create a logits tensor whose row-wise argmax yields the desired tokens.

    The last token of every sequence is dropped (the tests treat it as the
    bonus token), so the result has one row per remaining token, laid out
    sequence after sequence.

    Args:
        output_token_ids: Desired tokens per sequence, bonus token last.
        vocab_size: Number of columns of the logits tensor.
        token_idx_to_override: Optional token id whose logit is set to 99.0 in
            every row. Token id 0 is accepted.

    Returns:
        A `(num_tokens, vocab_size)` tensor on `DEVICE`, filled with -100.0
        and holding 100.0 at each row's desired token.
    """
    target_rows = [tokens[:-1] for tokens in output_token_ids]
    num_total_tokens = sum(len(tokens) for tokens in target_rows)
    logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)

    row = 0
    for tokens in target_rows:
        for token_id in tokens:
            logits[row, token_id] = 100.0
            row += 1

    # Compare against None instead of relying on truthiness: token id 0 is
    # falsy but is still a legitimate column to override.
    if token_idx_to_override is not None:
        # Written last, so it also replaces a 100.0 placed in that column.
        logits[:, token_idx_to_override] = 99.0

    return logits


71
def create_sampling_metadata(
    all_greedy: bool,
    output_token_ids: list[list[int]] | None = None,
    prompt_token_ids: torch.Tensor | None = None,
    spec_token_ids: torch.Tensor | None = None,
    temperature: torch.Tensor | None = None,
    top_k: torch.Tensor | None = None,
    top_p: torch.Tensor | None = None,
    generators: dict[int, Any] | None = None,
    frequency_penalties: list[float] | None = None,
    presence_penalties: list[float] | None = None,
    repetition_penalties: list[float] | None = None,
    bad_words_token_ids: dict[int, list[list[int]]] | None = None,
    allowed_token_ids_mask: torch.Tensor | None = None,
) -> SamplingMetadata:
    """Create a v1 `SamplingMetadata` that is either all-greedy or all-random.

    Args:
        all_greedy: If true, greedy sampling is used and `temperature` is
            discarded; otherwise `temperature` must be provided.
        output_token_ids: Previously generated tokens per request; must be
            non-empty when any penalty is given.
        prompt_token_ids: Forwarded unchanged.
        spec_token_ids: Forwarded unchanged (empty list when omitted).
            NOTE(review): tests in this file pass nested lists here although
            the annotation says tensor - confirm the intended type.
        temperature: Temperature tensor, required for random sampling.
        top_k: Forwarded unchanged.
        top_p: Forwarded unchanged.
        generators: Per-request generators; an empty dict when omitted.
        frequency_penalties: Per-request values, converted to a tensor.
        presence_penalties: Per-request values, converted to a tensor.
        repetition_penalties: Per-request values, converted to a tensor.
        bad_words_token_ids: Forwarded unchanged (empty dict when omitted).
        allowed_token_ids_mask: Forwarded unchanged.

    Returns:
        The populated `SamplingMetadata`.
    """
    if not generators:
        generators = {}

    if not all_greedy:
        assert temperature is not None
    else:
        temperature = None

    penalty_inputs = (frequency_penalties, presence_penalties, repetition_penalties)
    use_penalties = any(penalty_inputs)
    if use_penalties:
        # Penalties need earlier output tokens to act on.
        assert output_token_ids
        freq_tensor, pres_tensor, rep_tensor = (
            torch.tensor(values, device=DEVICE) for values in penalty_inputs
        )
    else:
        freq_tensor = torch.tensor([])
        pres_tensor = torch.tensor([])
        rep_tensor = torch.tensor([])

    if output_token_ids is None:
        output_token_ids = []
    if spec_token_ids is None:
        spec_token_ids = []
    if bad_words_token_ids is None:
        bad_words_token_ids = {}

    return SamplingMetadata(
        temperature=temperature,
        all_greedy=all_greedy,
        all_random=not all_greedy,
        top_p=top_p,
        top_k=top_k,
        generators=generators,
        max_num_logprobs=None,
        no_penalties=not use_penalties,
        prompt_token_ids=prompt_token_ids,
        frequency_penalties=freq_tensor,
        presence_penalties=pres_tensor,
        repetition_penalties=rep_tensor,
        output_token_ids=output_token_ids,
        spec_token_ids=spec_token_ids,
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids=bad_words_token_ids,
        logitsprocs=LogitsProcessors(),
    )


132
########################### Tests for Greedy Sampling ###################
133
def test_perfect_match(rejection_sampler):
    """Every draft token agrees with the target, so the bonus token is kept."""
    draft_tokens = [[1, 2, 3]]
    target_tokens = [[1, 2, 3, 4]]  # trailing 4 is the bonus token

    target_logits = create_logits_tensor(target_tokens)
    device = target_logits.device
    bonus_ids = torch.tensor([target_tokens[0][-1]], device=device)
    mock_sampler_output(rejection_sampler, bonus_ids)

    result = rejection_sampler(
        create_spec_decode_metadata(draft_tokens, target_logits),
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=device)
    assert torch.equal(result.sampled_token_ids, expected)
152
153


154
def test_early_mismatch(rejection_sampler):
    """A mismatch at position 1 leaves only placeholders after that position."""
    draft_tokens = [[1, 2, 3]]
    target_tokens = [[1, 5, 3, 4]]  # target disagrees with the draft at index 1

    target_logits = create_logits_tensor(target_tokens)
    device = target_logits.device
    bonus_ids = torch.tensor([target_tokens[0][-1]], device=device)
    mock_sampler_output(rejection_sampler, bonus_ids)

    result = rejection_sampler(
        create_spec_decode_metadata(draft_tokens, target_logits),
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    expected_row = [1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]
    expected = torch.tensor([expected_row], dtype=torch.int, device=device)
    assert torch.equal(result.sampled_token_ids, expected)
177
178


179
def test_multiple_sequences(rejection_sampler):
    """Two fully matching requests of different draft lengths in one batch."""
    draft_tokens = [[1, 2], [3]]
    target_tokens = [[1, 2, 5], [3, 4]]  # bonus tokens are 5 and 4

    target_logits = create_logits_tensor(target_tokens)
    device = target_logits.device
    bonus_ids = torch.tensor(
        [target_tokens[0][-1], target_tokens[1][-1]], device=device
    )
    mock_sampler_output(rejection_sampler, bonus_ids)

    result = rejection_sampler(
        create_spec_decode_metadata(draft_tokens, target_logits),
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    # The shorter request is padded with a placeholder to the batch width.
    expected_rows = [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]]
    expected = torch.tensor(expected_rows, dtype=torch.int, device=device)
    assert torch.equal(result.sampled_token_ids, expected)
202
203


204
def test_single_token_sequence(rejection_sampler):
    """A request with exactly one (matching) draft token plus its bonus token."""
    draft_tokens = [[1]]
    target_tokens = [[1, 2]]  # one draft token, bonus token 2

    target_logits = create_logits_tensor(target_tokens)
    device = target_logits.device
    bonus_ids = torch.tensor([target_tokens[0][-1]], device=device)
    mock_sampler_output(rejection_sampler, bonus_ids)

    result = rejection_sampler(
        create_spec_decode_metadata(draft_tokens, target_logits),
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    expected = torch.tensor([[1, 2]], dtype=torch.int, device=device)
    assert torch.equal(result.sampled_token_ids, expected)
223
224


225
def test_empty_sequence(rejection_sampler):
    """A request without draft tokens yields only the bonus token."""
    draft_tokens: list[list[int]] = [[]]
    target_tokens = [[5]]  # nothing but the bonus token

    target_logits = create_logits_tensor(target_tokens)
    device = target_logits.device
    bonus_ids = torch.tensor([target_tokens[0][-1]], device=device)
    mock_sampler_output(rejection_sampler, bonus_ids)

    result = rejection_sampler(
        create_spec_decode_metadata(draft_tokens, target_logits),
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    expected = torch.tensor([[5]], dtype=torch.int, device=device)
    assert torch.equal(result.sampled_token_ids, expected)
244
245


246
def test_multiple_mismatches(rejection_sampler):
    """Two requests that each diverge from their draft at a different position."""
    draft_tokens = [[1, 2, 3], [4, 5, 6]]
    target_tokens = [[1, 2, 7, 6], [4, 8, 6, 9]]  # both requests mismatch

    target_logits = create_logits_tensor(target_tokens)
    device = target_logits.device
    bonus_ids = torch.tensor(
        [target_tokens[0][-1], target_tokens[1][-1]], device=device
    )
    mock_sampler_output(rejection_sampler, bonus_ids)

    result = rejection_sampler(
        create_spec_decode_metadata(draft_tokens, target_logits),
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    expected_rows = [
        [1, 2, 7, PLACEHOLDER_TOKEN_ID],
        [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
    ]
    expected = torch.tensor(expected_rows, dtype=torch.int, device=device)
    assert torch.equal(result.sampled_token_ids, expected)
274
275
276
277
278


@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),  # Perfect match with bonus
        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),  # First mismatch
        (
            [[1, 2], [3, 4]],
            [[1, 5, 6], [3, 4, 7]],
            [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]],
        ),  # Mixed matches
    ],
)
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expected):
    """Greedy rejection sampling over a table of match / mismatch scenarios."""
    target_logits = create_logits_tensor(output_tokens)
    device = target_logits.device

    # The last token of each target sequence is that request's bonus token.
    bonus_ids = torch.tensor([tokens[-1] for tokens in output_tokens], device=device)
    mock_sampler_output(rejection_sampler, bonus_ids)

    result = rejection_sampler(
        create_spec_decode_metadata(spec_tokens, target_logits),
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=create_sampling_metadata(all_greedy=True),
    )

    expected_tensor = torch.tensor(expected, dtype=torch.int, device=device)
    assert torch.equal(result.sampled_token_ids, expected_tensor)
306
307
308
309
310
311
312
313


########################### Tests for Random Sampling ###################
@pytest.mark.parametrize("k", [1, 3, 5])
@pytest.mark.parametrize("vocab_size", [1000])
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(
    rejection_sampler,
    k: int,
    vocab_size: int,
    batch_size: int,
    frac_seeded: float,
    n_rep: int,
):
    """Requests given a seeded generator must sample the same tokens each run.

    The same inputs are pushed through the rejection sampler `n_rep` times;
    for every request selected by the seeded mask, all repetitions must agree
    with the first one.
    """
    num_tokens = batch_size * k
    # NOTE(review): `draft_probs` only serves as the shape/dtype/device template
    # for `target_logits`; the sampler itself is called with draft_probs=None.
    # Confirm that this is intended.
    draft_probs = F.softmax(
        torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE),
        dim=-1,
    )
    target_logits = torch.rand_like(draft_probs)
    bonus_token_ids = torch.randint(
        low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64, device=DEVICE
    )
    draft_token_ids = torch.randint(
        low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64, device=DEVICE
    )

    # Randomly choose which requests get their own seeded generator.
    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
    seeded_requests = [i for i in range(batch_size) if seeded_mask[i]]

    results = []
    for _ in range(n_rep):
        # Fresh generators with identical seeds for every repetition.
        generators = {
            i: torch.Generator(device=DEVICE).manual_seed(i) for i in seeded_requests
        }
        sampling_metadata = create_sampling_metadata(
            all_greedy=False,
            temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE),
            generators=generators,
        )
        spec_decode_metadata = create_spec_decode_metadata(
            draft_token_ids.tolist(), target_logits
        )

        mock_sampler_output(rejection_sampler, bonus_token_ids)
        rep_result = rejection_sampler(
            spec_decode_metadata,
            draft_probs=None,
            logits=target_logits,
            sampling_metadata=sampling_metadata,
        )
        results.append(rep_result.sampled_token_ids)

    baseline = results[0]
    for i in seeded_requests:
        for repeated in results[1:]:
            assert torch.equal(repeated[i], baseline[i])


def test_rejection_sampling_approximates_target_distribution():
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a set of random reference distributions. We expect the distance
    between the observed distribution and the target distribution to
    improve much more than the distance improvement between the observed
    distribution and the reference distributions.
    """
    # The default device is process-wide state. Remember it and restore it
    # afterwards so this test does not change the device that tests running
    # later in the same session allocate on.
    previous_default_device = torch.get_default_device()
    torch.set_default_device(DEVICE)
    try:
        vocab_size = 10
        k = 2
        num_reference_probs = 100

        # Prepare draft, target, and reference probability distributions.
        draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), dim=-1)
        target_logits = torch.rand(vocab_size, dtype=torch.float32)
        target_probs = F.softmax(target_logits, dim=-1)
        reference_probs = F.softmax(
            torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
            dim=-1,
        )

        sample_sizes = [10, 100, 1_000, 10_000, 100_000]
        distance_wrt_reference: list[float] = []
        distance_wrt_target: list[float] = []

        for num_samples in sample_sizes:
            # Sample using rejection sampling.
            rej_sample_probs = estimate_rejection_sampling_pdf(
                draft_probs, target_logits, k, vocab_size, num_samples
            )
            rej_sample_probs = rej_sample_probs.to(DEVICE)

            # Average distance from reference probs.
            reference_vs_rejsample_dist = (
                torch.dist(reference_probs, rej_sample_probs).item()
                / reference_probs.shape[0]
            )
            target_vs_rejsample_dist = torch.dist(
                target_probs, rej_sample_probs
            ).item()

            distance_wrt_reference.append(reference_vs_rejsample_dist)
            distance_wrt_target.append(target_vs_rejsample_dist)

            # Running ratios, computed here only for the progress output.
            relative_change_in_distance_wrt_target = get_ratio_first_to_last(
                distance_wrt_target
            )
            relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
                distance_wrt_reference
            )

            print(
                f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
                f"{reference_vs_rejsample_dist=:.05f}"
            )
            print(
                f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
                f"{relative_change_in_distance_wrt_reference=:.02f}"
            )

        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target
        )
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference
        )

        # The distance to the target must shrink at least this many times
        # more than the distance to the reference distributions does.
        expected_improvement_multiplier = 20
        required_improvement = (
            relative_change_in_distance_wrt_reference
            * expected_improvement_multiplier
        )
        assert relative_change_in_distance_wrt_target > required_improvement
    finally:
        torch.set_default_device(previous_default_device)
453
454
455
456
457
458
459
460


def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return the first element divided by the last element of `elements`."""
    first, last = elements[0], elements[-1]
    return first / last


def estimate_rejection_sampling_pdf(
    draft_probs: torch.Tensor,
    target_logits: torch.Tensor,
    k: int,
    vocab_size: int,
    num_samples: int,
) -> torch.Tensor:
    """Estimate the probability distribution of the output tokens
    using rejection sampling.

    Args:
        draft_probs: Draft probability distribution (`vocab_size` entries).
        target_logits: Target logits (`vocab_size` entries).
        k: Number of draft tokens per sample.
        vocab_size: Vocabulary size; also the number of histogram bins.
        num_samples: Number of samples to draw.

    Returns:
        Estimated probability distribution of the output tokens, as a
        density-normalized histogram over the range `(0, vocab_size)`.
    """
    sampler_stub = Mock(spec=Sampler)
    sampler_stub.logprobs_mode = "raw_logprobs"
    sampler = RejectionSampler(sampler_stub)

    num_tokens = num_samples * k

    # Tile the single draft distribution to [num_samples, k, vocab_size] and
    # the single target logits row to [num_tokens, vocab_size].
    tiled_draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(
        num_samples, k, 1
    )
    tiled_target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)

    # Draw k draft token ids per sample from the draft distribution.
    draft_token_ids = torch.multinomial(
        tiled_draft_probs[:, 0, :], num_samples=k, replacement=True
    ).reshape(num_samples, k)
    flat_draft_probs = tiled_draft_probs.view(num_tokens, vocab_size)

    # Bonus tokens are required by the sampler but dropped from the estimate.
    bonus_token_ids = torch.zeros((num_samples, 1), dtype=torch.int64, device=DEVICE)

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(num_samples, dtype=torch.float32, device=DEVICE),
    )
    spec_decode_metadata = create_spec_decode_metadata(
        draft_token_ids.tolist(), tiled_target_logits
    )

    mock_sampler_output(sampler, bonus_token_ids)
    sampler_output = sampler(
        spec_decode_metadata,
        draft_probs=flat_draft_probs,
        logits=tiled_target_logits,
        sampling_metadata=sampling_metadata,
    )

    # Strip the bonus column; the histogram is computed on CPU floats.
    output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten()
    histogram = torch.histogram(
        output_token_ids.to(dtype=torch.float, device="cpu"),
        bins=vocab_size,
        range=(0, vocab_size),
        density=True,
    )
    return histogram.hist
523
524


525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
def native_sample_recovered_tokens(
    max_spec_len: int,
    num_draft_tokens: list[int],
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
    draft_token_ids: torch.Tensor,  # [num_tokens]
    draft_probs: torch.Tensor | None,  # [num_tokens, vocab_size]
    target_probs: torch.Tensor,  # [num_tokens, vocab_size]
    sampling_metadata: SamplingMetadata,
    device: torch.device,
) -> torch.Tensor:
    """Compute recovered token ids with plain PyTorch ops and Python loops.

    For each draft position a residual distribution is built: `target_probs`
    with the draft token zeroed when `draft_probs` is None, otherwise
    `target_probs - draft_probs` clamped at zero. The recovered token is the
    argmax of that residual divided by per-request exponential noise.

    Args:
        max_spec_len: Upper bound on the positions visited per request.
        num_draft_tokens: Number of draft tokens of each request.
        cu_num_draft_tokens: Cumulative (inclusive) draft-token counts.
        draft_token_ids: Flattened draft token ids.
        draft_probs: Flattened draft probabilities, or None.
        target_probs: Flattened target probabilities.
        sampling_metadata: Only its `generators` mapping is read here.
        device: Device on which the noise tensor is allocated.

    Returns:
        Tensor shaped like `draft_token_ids` holding the recovered token ids.
    """
    batch_size = len(num_draft_tokens)
    vocab_size = target_probs.shape[-1]

    noise = torch.empty((batch_size, vocab_size), dtype=torch.float32, device=device)
    noise.exponential_()

    generators = sampling_metadata.generators
    saved_states = {req: gen.get_state() for req, gen in generators.items()}
    for req, gen in generators.items():
        # Requests without draft tokens draw nothing from their generator;
        # this can matter for reproducibility.
        if num_draft_tokens[req] > 0:
            noise[req].exponential_(generator=gen)

        # Drawing advances the generator, so rewind it to allow the same
        # exponential numbers to be generated again later.
        gen.set_state(saved_states[req])

    inv_noise = noise.reciprocal()
    recovered = torch.empty_like(draft_token_ids)

    seq_start = 0
    for req_idx in range(batch_size):
        seq_end = int(cu_num_draft_tokens[req_idx].item())

        for pos in range(min(max_spec_len, seq_end - seq_start)):
            token_idx = seq_start + pos

            if draft_probs is not None:
                residual = (
                    target_probs[token_idx] - draft_probs[token_idx]
                ).clamp_min_(0.0)
            else:
                # Target distribution with the draft token's mass removed.
                residual = target_probs[token_idx].clone()
                residual[draft_token_ids[token_idx]] = 0.0

            weighted = residual * inv_noise[req_idx]
            recovered[token_idx] = torch.argmax(weighted, dim=-1)

        # The next request starts where this one ended.
        seq_start = seq_end
    return recovered


589
590
591
592
593
594
595
596
597
598
599
600
601
def _test_masked_logits(
    rejection_sampler,
    batch_size: int,
    num_draft_tokens: int,
    vocab_size: int,
    target_logits: torch.Tensor,
    unmasked_indices: torch.Tensor,
    sampling_metadata: SamplingMetadata,
):
    """Run rejection sampling on random drafts and check tokens stay unmasked.

    Args:
        rejection_sampler: Sampler under test (its inner sampler is mocked).
        batch_size: Number of requests.
        num_draft_tokens: Draft tokens per request.
        vocab_size: Vocabulary size of the draft probabilities.
        target_logits: Target logits passed to the sampler.
        unmasked_indices: Per token position, the token ids that may be
            sampled; indexed by position and tested with `in`.
        sampling_metadata: Sampling metadata carrying the masking parameters.
    """
    num_tokens = batch_size * num_draft_tokens

    # Random draft distributions, one per token position.
    draft_probs = F.softmax(
        torch.rand((num_tokens, vocab_size), dtype=torch.float32, device=DEVICE),
        dim=-1,
    )

    # One draft token per position, regrouped per request as nested lists.
    sampled_drafts = torch.multinomial(draft_probs, num_samples=1)
    draft_token_ids = sampled_drafts.reshape(batch_size, num_draft_tokens).tolist()

    # Bonus tokens are required by the sampler but not checked.
    bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)

    spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)

    mock_sampler_output(rejection_sampler, bonus_token_ids)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        logits=target_logits,
        sampling_metadata=sampling_metadata,
    )

    # Drop the bonus column and flatten to one entry per token position.
    sampled_ids = output.sampled_token_ids[:, :-1].flatten().tolist()

    # Every token that was actually emitted must be an unmasked one.
    for i in range(num_tokens):
        token_id = sampled_ids[i]
        if token_id != PLACEHOLDER_TOKEN_ID:
            assert token_id in unmasked_indices[i]


@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
    """Rejection sampling with top-k must only emit tokens from the top-k set."""
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # One random set of `top_k` distinct token ids per token position.
    top_k_indices = torch.stack(
        [torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens)]
    )

    # Start from all-zero (uniform) logits and raise only the chosen ids by a
    # small amount. If the top-k masking works, the remaining ids are never
    # sampled even though their logits are almost as large.
    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
    for i, boosted_ids in enumerate(top_k_indices):
        target_logits[i, boosted_ids] += 0.1

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE),
        top_k=torch.tensor([top_k] * batch_size, device=DEVICE, dtype=torch.int64),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_k_indices,
        sampling_metadata=sampling_metadata,
    )


@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
    """Test rejection sampling with top-p sampling.

    The expected top-p set of every token position is computed here directly
    from the logits; the sampler must never emit a token outside of it.
    """
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # Create random (standard-normal) target logits.
    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)

    # Temperature is per request while the logits are per token (request-major
    # layout), so expand it to one divisor per logits row. Dividing by the raw
    # [batch_size] tensor would broadcast along the vocabulary axis and only
    # works when vocab_size happens to equal batch_size.
    per_token_temperature = temperature.repeat_interleave(num_draft_tokens)
    rescaled_logits = target_logits / per_token_temperature.unsqueeze(-1)

    # Sort ascending and mask the low-probability tail whose cumulative mass
    # stays within 1 - top_p.
    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - top_p
    # Always keep at least the most likely token.
    top_p_mask[:, -1] = False

    # Token ids that remain unmasked, per token position.
    top_p_indices = [
        logits_idx[i][~top_p_mask[i]].tolist() for i in range(num_tokens)
    ]

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_p_indices,
        sampling_metadata=sampling_metadata,
    )
721
722
723
724
725
726
727
728


########################### Tests for Logit Processors ###################
def test_frequency_penalties(rejection_sampler):
    """Test rejection sampling with frequency penalties.

    The draft repeats token 1. With frequency penalties enabled, the expected
    output replaces a repeated 1 with token 15 (the override column of the
    logits) and holds placeholders after that position.
    """
    spec_tokens = [[1, 1, 1], [], [1, 1, 1]]
    output_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]]  # 1, 7 and 1 are the bonus tokens

    num_requests = len(spec_tokens)
    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=spec_tokens,
        prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE),
        frequency_penalties=[1.5, 1.5, 0.7],
        presence_penalties=[0.0] * num_requests,
        repetition_penalties=[1.0] * num_requests,
    )
    # The last token of each target sequence is that request's bonus token.
    bonus_token_tensor = torch.tensor(
        [tokens[-1] for tokens in output_tokens], device=logits.device
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        logits=logits,
        sampling_metadata=metadata,
    )

    # Use the shared placeholder constant, as the other tests in this file do.
    pad = PLACEHOLDER_TOKEN_ID
    expected = torch.tensor(
        [[1, 15, pad, pad], [7, pad, pad, pad], [1, 1, 15, pad]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output.sampled_token_ids, expected)
759
760
761


def test_bad_words(rejection_sampler):
    """Test rejection sampling with bad words constraints.

    This test applies bad words to non-consecutive requests (0 and 2, but not 1)
    to verify correct logit indexing when iterating over requests with bad words.
    """
    spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
    output_tokens = [[1, 2, 3, 4], [1, 15, 3, 4], [1, 2, 3, 4]]

    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=spec_tokens,
        bad_words_token_ids={
            0: [[2]],
            # Request 1 has no bad words (to test non-consecutive request handling)
            2: [[2]],
        },
    )
    # The last token of each target sequence is that request's bonus token.
    bonus_token_tensor = torch.tensor(
        [tokens[-1] for tokens in output_tokens], device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        logits=logits,
        sampling_metadata=metadata,
    )

    # Request 0: bad word [2] matches prefix, so token 2 is rejected -> 15
    # Request 1: no bad words, all tokens match -> [1, 15, 3, 4]
    # Request 2: bad word [2] matches prefix, so token 2 is rejected -> 15
    # Use the shared placeholder constant, as the other tests in this file do.
    pad = PLACEHOLDER_TOKEN_ID
    expected = torch.tensor(
        [[1, 15, pad, pad], [1, 15, 3, 4], [1, 15, pad, pad]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output.sampled_token_ids, expected)
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833


def test_allowed_token_ids(rejection_sampler):
    """Test rejection sampling with an allowed-token-ids mask."""
    spec_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]]
    output_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]]
    # Not allowed tokens, per request (as listed by the original author):
    # 0: 0-4
    # 1: 1-5
    # 2: 2-6
    # NOTE(review): the expected output below keeps draft token 5 for
    # request 1, so confirm these ranges against create_allowed_token_ids.
    num_allowed_token_ids = 5

    # Use token 15 as the sampler's choice whenever a draft token is rejected.
    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)

    batch_size = len(output_tokens)
    _, vocab_size = logits.size()
    mask = create_allowed_token_ids(
        batch_size=batch_size,
        vocab_size=vocab_size,
        num_allowed_token_ids=num_allowed_token_ids,
        device=logits.device,
    )
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[], [], []],
        spec_token_ids=spec_tokens,
        allowed_token_ids_mask=mask,
    )

    # The last entry of each output row is the bonus token; the bonus sampler
    # is mocked, so these ids are handed to it directly.
    bonus_token_tensor = torch.tensor(
        [tokens[-1] for tokens in output_tokens], device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
    mock_sampler_output(rejection_sampler, bonus_token_tensor)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        logits=logits,
        sampling_metadata=metadata,
    )

    # Request 0: the first draft token (1) is replaced by 15; the rest of the
    #            row is filled with -1.
    # Request 1: draft tokens 10 and 5 are kept; the third draft token (3)
    #            differs from the output token (10), so 10 is emitted and no
    #            bonus token follows.
    # Request 2: every draft token is kept and the bonus token 5 is appended.
    expected = torch.tensor(
        [[15, -1, -1, -1], [10, 5, 10, -1], [7, 10, 12, 5]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output.sampled_token_ids, expected)
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905


@pytest.mark.parametrize("batch_size", [1, 100])
@pytest.mark.parametrize("vocab_size", [100, 8192, 10000])
@pytest.mark.parametrize("max_spec_len", [1, 3])
@pytest.mark.parametrize("no_draft_probs", [True, False])
def test_sample_recovered_tokens(
    batch_size: int, vocab_size: int, max_spec_len: int, no_draft_probs: bool
):
    """Compare sample_recovered_tokens with native_sample_recovered_tokens.

    Both functions receive exactly the same arguments (with draft
    probabilities replaced by None when ``no_draft_probs`` is set) and must
    return tensors that are equal element for element.
    """
    total_tokens = batch_size * max_spec_len

    # Random draft distribution: softmax over uniform noise.
    draft_noise = torch.rand(
        total_tokens, vocab_size, dtype=torch.float32, device=DEVICE
    )
    draft_probs = F.softmax(draft_noise, dim=-1)

    # Random target distribution, built the same way. The raw logits are kept
    # because the spec-decode metadata helper below takes them.
    target_logits = torch.rand(
        total_tokens, vocab_size, dtype=torch.float32, device=DEVICE
    )
    target_probs = F.softmax(target_logits, dim=-1)

    # Draw one draft token per row from the draft distribution.
    sampled_drafts = torch.multinomial(draft_probs, num_samples=1)
    draft_token_ids = sampled_drafts.to(torch.int32)

    # One generator per request, seeded with the request index.
    generators = {
        req_idx: torch.Generator(device=DEVICE).manual_seed(req_idx)
        for req_idx in range(batch_size)
    }
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
    sampling_metadata = create_sampling_metadata(
        all_greedy=False, temperature=temperature, generators=generators
    )

    per_request_drafts = draft_token_ids.reshape(batch_size, max_spec_len).tolist()
    spec_decode_metadata = create_spec_decode_metadata(
        per_request_drafts, target_logits
    )

    # Positional arguments shared by the two implementations under comparison.
    shared_args = (
        max_spec_len,
        spec_decode_metadata.num_draft_tokens,
        spec_decode_metadata.cu_num_draft_tokens,
        draft_token_ids,
        None if no_draft_probs else draft_probs,
        target_probs,
        sampling_metadata,
    )

    # The native version runs first, then the version under test.
    expected_ids = native_sample_recovered_tokens(*shared_args, device=DEVICE)
    actual_ids = sample_recovered_tokens(*shared_args, device=DEVICE)
    assert torch.equal(actual_ids, expected_ids)