test_rejection_sampler.py 32.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any
4
from unittest.mock import Mock
5
6
7

import pytest
import torch
8
import torch.nn.functional as F
9

10
from tests.v1.sample.utils import create_allowed_token_ids
11
from vllm.platforms import current_platform
12
from vllm.v1.sample.logits_processor import LogitsProcessors
13
from vllm.v1.sample.metadata import SamplingMetadata
14
15
16
17
18
from vllm.v1.sample.rejection_sampler import (
    PLACEHOLDER_TOKEN_ID,
    RejectionSampler,
    sample_recovered_tokens,
)
19
from vllm.v1.sample.sampler import Sampler, SamplerOutput
20
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
21

22
DEVICE_TYPE = current_platform.device_type
23

24
25

@pytest.fixture
def rejection_sampler():
    """Provide a ``RejectionSampler`` wrapping a mocked ``Sampler``.

    The mock is spec'd on ``Sampler`` and reports ``"raw_logprobs"`` as its
    logprobs mode; tests decide what it returns via ``mock_sampler_output``.
    """
    sampler_stub = Mock(spec=Sampler)
    sampler_stub.logprobs_mode = "raw_logprobs"
    return RejectionSampler(sampler_stub)


def mock_sampler_output(
    rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor
):
    """Make the mocked inner sampler return ``bonus_token_ids``.

    Every subsequent call of ``rejection_sampler.sampler`` yields a
    ``SamplerOutput`` carrying ``bonus_token_ids`` and no logprobs.
    """
    canned_output = SamplerOutput(
        sampled_token_ids=bonus_token_ids, logprobs_tensors=None
    )
    rejection_sampler.sampler.return_value = canned_output


def create_spec_decode_metadata(
    spec_tokens: list[list[int]], logits: torch.Tensor
) -> SpecDecodeMetadata:
    """Build dummy spec-decode metadata for ``spec_tokens`` over ``logits``.

    ``target_logits_indices`` is set to ``0 .. logits.shape[0] - 1`` so that
    each row of ``logits`` is addressed, in order, as a target row.
    ``bonus_logits_indices`` is left empty because the bonus tokens are
    supplied by the mocked sampler.

    All index tensors are created on ``logits.device`` (the same device
    ``make_dummy`` is given), so the helper does not depend on torch's
    global default device.
    """
    device = logits.device
    metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=device)
    metadata.target_logits_indices = torch.arange(logits.shape[0], device=device)
    # Output bonus token ids are mocked, so the bonus logit indices should
    # be empty.
    metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32, device=device)
    return metadata


def create_logits_tensor(
    output_token_ids: list[list[int]],
    vocab_size: int = 100,
    token_idx_to_override: int | None = None,
) -> torch.Tensor:
    """Create logits whose row-wise argmax reproduces the given tokens.

    For every request, each token except the last one (the bonus token,
    which is supplied by the mocked sampler) gets one row. The row is
    -100.0 everywhere and 100.0 at the requested token id. Rows of all
    requests are concatenated in request order.

    If ``token_idx_to_override`` is given, that column is set to 99.0 in
    every row. It then becomes the argmax of any row whose 100.0 entry is
    pushed below 99.0 (e.g. by a penalty or a bad-words mask in the tests).

    Returns:
        Float tensor of shape ``(num_total_tokens, vocab_size)`` on
        ``DEVICE_TYPE``.
    """
    token_ids = [tokens[:-1] for tokens in output_token_ids]
    num_total_tokens = sum(len(tokens) for tokens in token_ids)
    logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE_TYPE)

    start_loc = 0
    for tokens in token_ids:
        for j, token_id in enumerate(tokens):
            logits[start_loc + j, token_id] = 100.0
        start_loc += len(tokens)

    # Compare against None explicitly: token id 0 is a valid override target,
    # but it is falsy and would be silently skipped by a truthiness test.
    if token_idx_to_override is not None:
        logits[:, token_idx_to_override] = 99.0

    return logits


def create_sampling_metadata(
    all_greedy: bool,
    output_token_ids: list[list[int]] | None = None,
    prompt_token_ids: torch.Tensor | None = None,
    spec_token_ids: list[list[int]] | None = None,
    temperature: torch.Tensor | None = None,
    top_k: torch.Tensor | None = None,
    top_p: torch.Tensor | None = None,
    generators: dict[int, Any] | None = None,
    frequency_penalties: list[float] | None = None,
    presence_penalties: list[float] | None = None,
    repetition_penalties: list[float] | None = None,
    bad_words_token_ids: dict[int, list[list[int]]] | None = None,
    allowed_token_ids_mask: torch.Tensor | None = None,
) -> SamplingMetadata:
    """Create a v1 ``SamplingMetadata`` for tests.

    Either every request is greedy or every request is randomly sampled:

    * ``all_greedy=True``: ``temperature`` is ignored and set to ``None``.
    * ``all_greedy=False``: ``temperature`` must be provided.

    Penalties are enabled as soon as any of the three penalty lists is
    non-empty. In that case ``output_token_ids`` is required (its length is
    the number of requests), and any penalty list left as ``None`` is filled
    with its neutral value for every request: 0.0 for frequency and presence,
    1.0 for repetition. Callers therefore only need to pass the penalty they
    exercise.
    """
    generators = generators or {}

    if all_greedy:
        temperature = None
    else:
        assert temperature is not None

    if any([frequency_penalties, presence_penalties, repetition_penalties]):
        no_penalties = False
        assert output_token_ids, "penalties require output_token_ids"
        num_requests = len(output_token_ids)

        def _penalty_tensor(values: list[float] | None, neutral: float) -> torch.Tensor:
            # Fill unspecified penalties with a no-op value instead of
            # crashing in torch.tensor(None).
            if values is None:
                values = [neutral] * num_requests
            return torch.tensor(values, device=DEVICE_TYPE)

        frequency = _penalty_tensor(frequency_penalties, 0.0)
        presence = _penalty_tensor(presence_penalties, 0.0)
        repetition = _penalty_tensor(repetition_penalties, 1.0)
    else:
        no_penalties = True
        frequency = torch.tensor([])
        presence = torch.tensor([])
        repetition = torch.tensor([])

    return SamplingMetadata(
        temperature=temperature,
        all_greedy=all_greedy,
        all_random=not all_greedy,
        top_p=top_p,
        top_k=top_k,
        generators=generators,
        max_num_logprobs=None,
        no_penalties=no_penalties,
        prompt_token_ids=prompt_token_ids,
        frequency_penalties=frequency,
        presence_penalties=presence,
        repetition_penalties=repetition,
        output_token_ids=[] if output_token_ids is None else output_token_ids,
        spec_token_ids=[] if spec_token_ids is None else spec_token_ids,
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids={} if bad_words_token_ids is None else bad_words_token_ids,
        logitsprocs=LogitsProcessors(),
    )


########################### Tests for Greedy Sampling ###################
def test_perfect_match(rejection_sampler):
    """Every draft token equals the target argmax, so all of them are
    accepted and the bonus token is appended."""
    drafts = [[1, 2, 3]]
    targets = [[1, 2, 3, 4]]  # trailing 4 is the bonus token

    greedy_meta = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(targets)
    dev = target_logits.device
    bonus_ids = torch.tensor([seq[-1] for seq in targets], device=dev)
    draft_meta = create_spec_decode_metadata(drafts, target_logits)

    mock_sampler_output(rejection_sampler, bonus_ids)
    result = rejection_sampler(
        draft_meta,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=greedy_meta,
    )

    want = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=dev)
    assert torch.equal(result.sampled_token_ids, want)


def test_early_mismatch(rejection_sampler):
    """The second draft token disagrees with the target argmax: the target's
    token is emitted there and every later slot is a placeholder."""
    drafts = [[1, 2, 3]]
    targets = [[1, 5, 3, 4]]  # position 1 mismatches (5 vs. drafted 2)

    greedy_meta = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(targets)
    dev = target_logits.device
    bonus_ids = torch.tensor([seq[-1] for seq in targets], device=dev)
    draft_meta = create_spec_decode_metadata(drafts, target_logits)

    mock_sampler_output(rejection_sampler, bonus_ids)
    result = rejection_sampler(
        draft_meta,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=greedy_meta,
    )

    want = torch.tensor(
        [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=dev,
    )
    assert torch.equal(result.sampled_token_ids, want)


def test_multiple_sequences(rejection_sampler):
    """Two requests with different draft lengths; the shorter one is padded
    with a placeholder in the output."""
    drafts = [[1, 2], [3]]
    targets = [[1, 2, 5], [3, 4]]  # bonus tokens are 5 and 4

    greedy_meta = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(targets)
    dev = target_logits.device
    bonus_ids = torch.tensor([seq[-1] for seq in targets], device=dev)
    draft_meta = create_spec_decode_metadata(drafts, target_logits)

    mock_sampler_output(rejection_sampler, bonus_ids)
    result = rejection_sampler(
        draft_meta,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=greedy_meta,
    )

    want = torch.tensor(
        [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=dev
    )
    assert torch.equal(result.sampled_token_ids, want)


def test_single_token_sequence(rejection_sampler):
    """A single accepted draft token followed by the bonus token."""
    drafts = [[1]]
    targets = [[1, 2]]  # trailing 2 is the bonus token

    greedy_meta = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(targets)
    dev = target_logits.device
    bonus_ids = torch.tensor([seq[-1] for seq in targets], device=dev)
    draft_meta = create_spec_decode_metadata(drafts, target_logits)

    mock_sampler_output(rejection_sampler, bonus_ids)
    result = rejection_sampler(
        draft_meta,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=greedy_meta,
    )

    want = torch.tensor([[1, 2]], dtype=torch.int, device=dev)
    assert torch.equal(result.sampled_token_ids, want)


def test_empty_sequence(rejection_sampler):
    """A request without draft tokens yields only its bonus token."""
    drafts: list[list[int]] = [[]]
    targets = [[5]]  # nothing but the bonus token

    greedy_meta = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(targets)
    dev = target_logits.device
    bonus_ids = torch.tensor([seq[-1] for seq in targets], device=dev)
    draft_meta = create_spec_decode_metadata(drafts, target_logits)

    mock_sampler_output(rejection_sampler, bonus_ids)
    result = rejection_sampler(
        draft_meta,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=greedy_meta,
    )

    want = torch.tensor([[5]], dtype=torch.int, device=dev)
    assert torch.equal(result.sampled_token_ids, want)


def test_multiple_mismatches(rejection_sampler):
    """Both requests mismatch, at different positions; each is cut off right
    after its first mismatching (target) token."""
    drafts = [[1, 2, 3], [4, 5, 6]]
    targets = [[1, 2, 7, 6], [4, 8, 6, 9]]  # mismatches in both requests

    greedy_meta = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(targets)
    dev = target_logits.device
    bonus_ids = torch.tensor([seq[-1] for seq in targets], device=dev)
    draft_meta = create_spec_decode_metadata(drafts, target_logits)

    mock_sampler_output(rejection_sampler, bonus_ids)
    result = rejection_sampler(
        draft_meta,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=greedy_meta,
    )

    want = torch.tensor(
        [
            [1, 2, 7, PLACEHOLDER_TOKEN_ID],
            [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
        ],
        dtype=torch.int,
        device=dev,
    )
    assert torch.equal(result.sampled_token_ids, want)


@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),  # Perfect match with bonus
        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),  # First mismatch
        (
            [[1, 2], [3, 4]],
            [[1, 5, 6], [3, 4, 7]],
            [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]],
        ),  # Mixed matches
    ],
)
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expected):
    """Greedy verification over a table of (drafts, targets, expected) cases."""
    greedy_meta = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(output_tokens)
    dev = target_logits.device
    bonus_ids = torch.tensor([seq[-1] for seq in output_tokens], device=dev)
    draft_meta = create_spec_decode_metadata(spec_tokens, target_logits)

    mock_sampler_output(rejection_sampler, bonus_ids)
    result = rejection_sampler(
        draft_meta,
        draft_probs=None,
        logits=target_logits,
        sampling_metadata=greedy_meta,
    )

    want = torch.tensor(expected, dtype=torch.int, device=dev)
    assert torch.equal(result.sampled_token_ids, want)


########################### Tests for Random Sampling ###################
@pytest.mark.parametrize("k", [1, 3, 5])
@pytest.mark.parametrize("vocab_size", [1000])
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(
    rejection_sampler,
    k: int,
    vocab_size: int,
    batch_size: int,
    frac_seeded: float,
    n_rep: int,
):
    """Requests that own a seeded generator must produce identical samples on
    every repetition; unseeded requests are not checked."""
    num_tokens = batch_size * k

    # The global-RNG draws below happen in a fixed order on purpose.
    draft_probs = F.softmax(
        torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE_TYPE),
        dim=-1,
    )
    target_logits = torch.rand_like(draft_probs)
    bonus_token_ids = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, 1),
        dtype=torch.int64,
        device=DEVICE_TYPE,
    )
    draft_token_ids = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, k),
        dtype=torch.int64,
        device=DEVICE_TYPE,
    )

    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
    seeded_rows = [row for row in range(batch_size) if seeded_mask[row]]

    runs = []
    for _ in range(n_rep):
        # Fresh generators with the same seeds on every repetition.
        per_request_generators = {
            row: torch.Generator(device=DEVICE_TYPE).manual_seed(row)
            for row in seeded_rows
        }
        sampling_metadata = create_sampling_metadata(
            all_greedy=False,
            temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE),
            generators=per_request_generators,
        )
        spec_decode_metadata = create_spec_decode_metadata(
            draft_token_ids.tolist(), target_logits
        )

        mock_sampler_output(rejection_sampler, bonus_token_ids)
        run_output = rejection_sampler(
            spec_decode_metadata,
            draft_probs=None,
            logits=target_logits,
            sampling_metadata=sampling_metadata,
        )
        runs.append(run_output.sampled_token_ids)

    for row in seeded_rows:
        reference = runs[0][row]
        for later in runs[1:]:
            assert torch.equal(later[row], reference)


def test_rejection_sampling_approximates_target_distribution():
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.
    """
    # Scope the default device to this test with a context manager. A bare
    # torch.set_default_device() is process-global and would stay in effect
    # for every test that runs afterwards.
    with torch.device(DEVICE_TYPE):
        vocab_size = 10
        k = 2
        num_reference_probs = 100

        # Prepare draft, target, and reference probability distributions
        draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), dim=-1)
        target_logits = torch.rand(vocab_size, dtype=torch.float32)
        target_probs = F.softmax(target_logits, dim=-1)
        reference_probs = F.softmax(
            torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
            dim=-1,
        )

        sample_sizes = [10, 100, 1_000, 10_000, 100_000]
        distance_wrt_reference: list[float] = []
        distance_wrt_target: list[float] = []

        for num_samples in sample_sizes:
            # Sample using rejection sampling.
            rej_sample_probs = estimate_rejection_sampling_pdf(
                draft_probs, target_logits, k, vocab_size, num_samples
            )
            rej_sample_probs = rej_sample_probs.to(DEVICE_TYPE)

            # Distance to the reference distributions: the L2 norm of the
            # (broadcast) difference against all of them, divided by their
            # count.
            reference_vs_rejsample_dist = (
                torch.dist(reference_probs, rej_sample_probs).item()
                / reference_probs.shape[0]
            )
            target_vs_rejsample_dist = torch.dist(target_probs, rej_sample_probs).item()

            distance_wrt_reference.append(reference_vs_rejsample_dist)
            distance_wrt_target.append(target_vs_rejsample_dist)

            # Progress so far, for the log only.
            relative_change_in_distance_wrt_target = get_ratio_first_to_last(
                distance_wrt_target
            )
            relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
                distance_wrt_reference
            )

            print(
                f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
                f"{reference_vs_rejsample_dist=:.05f}"
            )
            print(
                f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
                f"{relative_change_in_distance_wrt_reference=:.02f}"
            )

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target
    )
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference
    )

    # The distance to the target must shrink at least this many times more
    # than the distance to the unrelated reference distributions.
    expected_improvement_multiplier = 20
    assert (
        relative_change_in_distance_wrt_target
        > relative_change_in_distance_wrt_reference * expected_improvement_multiplier
    )


def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return how many times larger the first value is than the last one,
    i.e. ``elements[0] / elements[-1]``."""
    first = elements[0]
    last = elements[-1]
    return first / last


def estimate_rejection_sampling_pdf(
    draft_probs: torch.Tensor,
    target_logits: torch.Tensor,
    k: int,
    vocab_size: int,
    num_samples: int,
) -> torch.Tensor:
    """Empirically estimate the output token distribution of rejection sampling.

    ``num_samples`` requests are built, each with ``k`` draft tokens drawn
    from the single draft distribution ``draft_probs``. They are verified
    against the same ``target_logits`` at every position, and the emitted
    (non-bonus) token ids are binned into a normalized histogram. Rejected
    positions hold a placeholder id; assuming it lies outside
    ``[0, vocab_size)``, the histogram ignores it.

    Args:
        draft_probs: Draft probability distribution (``vocab_size`` entries).
        target_logits: Target logits (``vocab_size`` entries).
        k: Number of draft tokens per request.
        vocab_size: Vocabulary size; also the number of histogram bins.
        num_samples: Number of samples (requests) to draw.

    Returns:
        Estimated probability distribution of the output tokens
        (``vocab_size`` bins, computed on CPU).
    """
    sampler_stub = Mock(spec=Sampler)
    sampler_stub.logprobs_mode = "raw_logprobs"
    verifier = RejectionSampler(sampler_stub)

    num_tokens = num_samples * k
    # Every draft position shares one draft distribution ...
    tiled_draft = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
    # ... and every target position shares one set of target logits.
    tiled_target = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)

    drafted_ids = torch.multinomial(
        tiled_draft[:, 0, :], num_samples=k, replacement=True
    ).reshape(num_samples, k)
    flat_draft = tiled_draft.view(num_tokens, vocab_size)

    # The bonus column is dropped below, but one id per request is required.
    bonus_ids = torch.zeros((num_samples, 1), dtype=torch.int64, device=DEVICE_TYPE)

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(num_samples, dtype=torch.float32, device=DEVICE_TYPE),
    )
    spec_decode_metadata = create_spec_decode_metadata(
        drafted_ids.tolist(), tiled_target
    )

    mock_sampler_output(verifier, bonus_ids)
    emitted = verifier(
        spec_decode_metadata,
        draft_probs=flat_draft,
        logits=tiled_target,
        sampling_metadata=sampling_metadata,
    ).sampled_token_ids

    non_bonus = emitted[:, :-1].flatten().to(dtype=torch.float, device="cpu")
    histogram = torch.histogram(
        non_bonus,
        bins=vocab_size,
        range=(0, vocab_size),
        density=True,
    )
    return histogram.hist


538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
def native_sample_recovered_tokens(
    max_spec_len: int,
    num_draft_tokens: list[int],
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
    draft_token_ids: torch.Tensor,  # [num_tokens]
    draft_probs: torch.Tensor | None,  # [num_tokens, vocab_size]
    target_probs: torch.Tensor,  # [num_tokens, vocab_size]
    sampling_metadata: SamplingMetadata,
    device: torch.device,
) -> torch.Tensor:
    """Loop-based pure-PyTorch reference for sampling "recovered" tokens.

    For every draft position, picks a token from an unnormalized residual
    distribution:

    * ``draft_probs is None``: the target probabilities with the draft
      token's entry zeroed;
    * otherwise: ``max(target_probs - draft_probs, 0)``.

    The pick is ``argmax(prob / q)`` with ``q ~ Exp(1)`` noise (the
    exponential-race trick). One row of ``q`` is drawn per request and is
    shared by all of that request's positions. Requests that have a generator
    in ``sampling_metadata.generators`` draw their row from that generator,
    whose state is then restored.

    NOTE(review): presumably used to cross-check the kernel-backed
    ``sample_recovered_tokens`` imported above; confirm against the caller.

    Returns:
        Tensor shaped like ``draft_token_ids`` with one recovered token id per
        draft position.
    """
    batch_size = len(num_draft_tokens)
    vocab_size = target_probs.shape[-1]

    # Exp(1) noise for every request from the global RNG; rows of requests
    # that own a generator are overwritten below.
    q = torch.empty(
        (batch_size, vocab_size),
        dtype=torch.float32,
        device=device,
    )
    q.exponential_()

    # Snapshot each per-request generator so it can be rewound after use.
    states = {
        i: generator.get_state()
        for i, generator in sampling_metadata.generators.items()
    }
    for i, generator in sampling_metadata.generators.items():
        # Do not generate random numbers for requests with no draft tokens.
        # This can be important for reproducibility.
        if num_draft_tokens[i] > 0:
            q[i].exponential_(generator=generator)

        # In order to generate the same exponential later, reset the
        # generator's RNG state because it advances after each call.
        generator.set_state(states[i])

    # argmax(prob / q) == argmax(prob * (1 / q)); precompute the reciprocal.
    inv_q = q.reciprocal()

    out = torch.empty_like(draft_token_ids)

    for req_idx in range(batch_size):
        # cu_num_draft_tokens is used as an inclusive prefix sum: request
        # req_idx owns flattened positions [cu[req_idx - 1], cu[req_idx]).
        start_idx = 0 if req_idx == 0 else int(cu_num_draft_tokens[req_idx - 1].item())
        end_idx = int(cu_num_draft_tokens[req_idx].item())
        num_tokens = end_idx - start_idx

        for pos in range(max_spec_len):
            # Skip the padding slots of requests shorter than max_spec_len.
            # NOTE(review): positions at or beyond max_spec_len are never
            # written and stay uninitialized in ``out``; presumably callers
            # guarantee num_tokens <= max_spec_len.
            if pos >= num_tokens:
                continue
            token_idx = start_idx + pos

            if draft_probs is None:
                # prob is target_probs[token_idx] except draft_token_id is zeroed
                prob = target_probs[token_idx].clone()
                draft_token_id = draft_token_ids[token_idx]
                prob[draft_token_id] = 0.0
            else:
                # Residual max(target - draft, 0); argmax needs no normalization.
                prob = (target_probs[token_idx] - draft_probs[token_idx]).clamp_min_(
                    0.0
                )

            score = prob * inv_q[req_idx]
            recovered_id = torch.argmax(score, dim=-1)
            out[token_idx] = recovered_id
    return out


def _test_masked_logits(
    rejection_sampler,
    batch_size: int,
    num_draft_tokens: int,
    vocab_size: int,
    target_logits: torch.Tensor,
    unmasked_indices: torch.Tensor,
    sampling_metadata: SamplingMetadata,
):
    """Run random-sampling verification and check every emitted token.

    Random draft probabilities and draft tokens are generated for
    ``batch_size`` requests of ``num_draft_tokens`` each. Every non-placeholder
    token the sampler emits (bonus column excluded) must be contained in
    ``unmasked_indices`` for its flattened draft position.
    """
    num_tokens = batch_size * num_draft_tokens

    # Random draft distribution, one row per draft token.
    draft_probs = F.softmax(
        torch.rand((num_tokens, vocab_size), dtype=torch.float32, device=DEVICE_TYPE),
        dim=-1,
    )

    # One draft token per row, regrouped per request.
    draft_token_ids = (
        torch.multinomial(draft_probs, num_samples=1)
        .reshape(batch_size, num_draft_tokens)
        .tolist()
    )

    # The bonus column is not checked, but one id per request is required.
    bonus_token_ids = torch.zeros(
        (batch_size, 1),
        dtype=torch.int64,
        device=DEVICE_TYPE,
    )

    spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)

    mock_sampler_output(rejection_sampler, bonus_token_ids)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        logits=target_logits,
        sampling_metadata=sampling_metadata,
    )

    # Drop the bonus column and flatten to one entry per draft position.
    emitted = output.sampled_token_ids[:, :-1].flatten().tolist()

    for position in range(num_tokens):
        token_id = emitted[position]
        if token_id != PLACEHOLDER_TOKEN_ID:
            assert token_id in unmasked_indices[position]


@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
    """Only tokens inside each position's top-k set may ever be emitted."""
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # A random set of ``top_k`` favoured token ids for every draft position.
    top_k_indices = torch.stack(
        [
            torch.randperm(vocab_size, device=DEVICE_TYPE)[:top_k]
            for _ in range(num_tokens)
        ]
    )

    # Start from flat (all-zero) logits and nudge the favoured ids up by 0.1.
    # If top-k masking works, the other ids are never sampled despite the
    # tiny logit gap.
    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE_TYPE)
    for row, favoured in enumerate(top_k_indices):
        target_logits[row, favoured] += 0.1

    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
    per_request_top_k = torch.tensor(
        [top_k] * batch_size, device=DEVICE_TYPE, dtype=torch.int64
    )
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_k=per_request_top_k,
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_k_indices,
        sampling_metadata=sampling_metadata,
    )


@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
    """Test rejection sampling with top-p sampling.

    The expected nucleus (top-p set) of every draft position is computed here
    independently. Every token the rejection sampler emits must fall inside it.
    """
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # Random standard-normal logits, so the nucleus size varies per row.
    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE_TYPE)
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)

    # ``temperature`` has one entry per request, while ``target_logits`` has
    # one row per draft token (request-major). Expand it to one value per row
    # and divide row-wise. Dividing by the raw (batch_size,) tensor would
    # broadcast over the vocab axis instead, and would only run at all because
    # batch_size == vocab_size here.
    row_temperature = temperature.repeat_interleave(num_draft_tokens).unsqueeze(-1)
    rescaled_logits = target_logits / row_temperature

    # Reference top-p: sort ascending and drop the low-probability tail whose
    # cumulative mass is <= 1 - top_p, always keeping the most likely token.
    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
    probs_sum = logits_sort.softmax(dim=-1).cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - top_p
    # at least one
    top_p_mask[:, -1] = False

    # Get the top-p indices.
    top_p_indices = [
        logits_idx[i][~top_p_mask[i]].tolist() for i in range(num_tokens)
    ]

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_p=torch.tensor(
            [top_p] * batch_size,
            device=DEVICE_TYPE,
            dtype=torch.float32,
        ),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_p_indices,
        sampling_metadata=sampling_metadata,
    )


########################### Tests for Logit Processors ###################
def test_frequency_penalties(rejection_sampler):
    """Test rejection sampling with frequency penalties.

    Every target row prefers token 1 (logit 100.0), with token 15 as the
    runner-up (99.0). The expected rows are consistent with the penalty
    lowering token 1's logit by ``penalty * count`` as earlier draft positions
    repeat it:

    * request 0 (penalty 1.5) flips to 15 at the second position;
    * request 2 (penalty 0.7) flips to 15 only at the third position.
    """
    spec_tokens = [[1, 1, 1], [], [1, 1, 1]]
    output_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]]  # 1, 7 and 1 are the bonus tokens
    num_requests = len(spec_tokens)

    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=spec_tokens,
        prompt_token_ids=torch.tensor(
            [[5, 6, 7], [6, 7, 8], [7, 8, 9]],
            device=DEVICE_TYPE,
        ),
        frequency_penalties=[1.5, 1.5, 0.7],
        presence_penalties=[0.0] * num_requests,
        repetition_penalties=[1.0] * num_requests,
    )
    bonus_token_tensor = torch.tensor(
        [tokens[-1] for tokens in output_tokens], device=logits.device
    )
    # Same metadata helper as every other test in this file: target row i is
    # logits row i.
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        logits=logits,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
        [
            [1, 15, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
            [7, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
            [1, 1, 15, PLACEHOLDER_TOKEN_ID],
        ],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output.sampled_token_ids, expected)


def test_bad_words(rejection_sampler):
    """Test rejection sampling with bad words constraints.

    This test applies bad words to non-consecutive requests (0 and 2, but not 1)
    to verify correct logit indexing when iterating over requests with bad words.
    """
    spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
    output_tokens = [[1, 2, 3, 4], [1, 15, 3, 4], [1, 2, 3, 4]]

    # Token 15 is what the target model picks when a draft token is rejected.
    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=spec_tokens,
        bad_words_token_ids={
            0: [[2]],
            # Request 1 has no bad words (to test non-consecutive request handling)
            2: [[2]],
        },
    )
    bonus_token_tensor = torch.tensor(
        [tokens[-1] for tokens in output_tokens], device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
    mock_sampler_output(rejection_sampler, bonus_token_tensor)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        logits=logits,
        sampling_metadata=metadata,
    )

    # Request 0: the single-token bad word [2] bans token 2, so draft token 2
    #   is rejected and replaced by 15.
    # Request 1: no bad words, all draft tokens match -> [1, 15, 3, 4].
    # Request 2: same as request 0 -> token 2 is rejected and replaced by 15.
    expected = torch.tensor(
        [[1, 15, -1, -1], [1, 15, 3, 4], [1, 15, -1, -1]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output.sampled_token_ids, expected)


def test_allowed_token_ids(rejection_sampler):
    """Test rejection sampling with allowed token ids."""
    spec_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]]
    output_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]]
    # Disallowed tokens per request, as implied by the expected output below:
    #   request 0: tokens 0-4 (so its first draft token 1 is rejected)
    #   request 1: no restriction (its draft token 5 is accepted)
    #   request 2: tokens 2-6 (none of its draft tokens fall in this range)
    # NOTE(review): the actual mask layout comes from create_allowed_token_ids;
    # confirm against its definition if the expectations change.
    num_allowed_token_ids = 5

    # Token 15 is what the target model picks when a draft token is rejected.
    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)

    batch_size = len(output_tokens)
    _, vocab_size = logits.size()
    mask = create_allowed_token_ids(
        batch_size=batch_size,
        vocab_size=vocab_size,
        num_allowed_token_ids=num_allowed_token_ids,
        device=logits.device,
    )
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[], [], []],
        spec_token_ids=spec_tokens,
        allowed_token_ids_mask=mask,
    )
    bonus_token_tensor = torch.tensor(
        [tokens[-1] for tokens in output_tokens], device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
    mock_sampler_output(rejection_sampler, bonus_token_tensor)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        logits=logits,
        sampling_metadata=metadata,
    )

    expected = torch.tensor(
        [[15, -1, -1, -1], [10, 5, 10, -1], [7, 10, 12, 5]],
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output.sampled_token_ids, expected)


@pytest.mark.parametrize("batch_size", [1, 100])
@pytest.mark.parametrize("vocab_size", [100, 8192, 10000])
@pytest.mark.parametrize("max_spec_len", [1, 3])
@pytest.mark.parametrize("no_draft_probs", [True, False])
def test_sample_recovered_tokens(
    batch_size: int, vocab_size: int, max_spec_len: int, no_draft_probs: bool
):
    """Check sample_recovered_tokens against native_sample_recovered_tokens.

    Every request gets exactly ``max_spec_len`` draft tokens. Both
    implementations receive the same inputs and sampling metadata (with
    per-request seeded generators), with or without draft probabilities,
    and must return identical recovered token ids.
    """
    num_tokens = batch_size * max_spec_len

    # Create random draft probabilities.
    draft_probs = torch.rand(
        num_tokens,
        vocab_size,
        dtype=torch.float32,
        device=DEVICE_TYPE,
    )
    draft_probs = F.softmax(draft_probs, dim=-1)

    # Create random target probabilities.
    target_logits = torch.rand(
        num_tokens, vocab_size, dtype=torch.float32, device=DEVICE_TYPE
    )
    target_probs = F.softmax(target_logits, dim=-1)

    # Randomly sample draft token ids from draft probs.
    draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)

    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
    generators = {
        i: torch.Generator(device=DEVICE_TYPE).manual_seed(i) for i in range(batch_size)
    }
    sampling_metadata = create_sampling_metadata(
        all_greedy=False, temperature=temperature, generators=generators
    )

    spec_decode_metadata = create_spec_decode_metadata(
        draft_token_ids.reshape(batch_size, max_spec_len).tolist(), target_logits
    )

    # Both implementations receive exactly the same draft distribution (or
    # none at all when ``no_draft_probs`` is set).
    maybe_draft_probs = None if no_draft_probs else draft_probs

    ref_recovered_token_ids = native_sample_recovered_tokens(
        max_spec_len,
        spec_decode_metadata.num_draft_tokens,
        spec_decode_metadata.cu_num_draft_tokens,
        draft_token_ids,
        maybe_draft_probs,
        target_probs,
        sampling_metadata,
        device=DEVICE_TYPE,
    )
    recovered_token_ids = sample_recovered_tokens(
        max_spec_len,
        spec_decode_metadata.num_draft_tokens,
        spec_decode_metadata.cu_num_draft_tokens,
        draft_token_ids,
        maybe_draft_probs,
        target_probs,
        sampling_metadata,
        device=DEVICE_TYPE,
    )
    assert torch.equal(recovered_token_ids, ref_recovered_token_ids)