test_rejection_sampler.py 17.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
from typing import Any, Optional
3
4
5

import pytest
import torch
6
import torch.nn.functional as F
7
8

from vllm.v1.sample.metadata import SamplingMetadata
9
10
11
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                              RejectionSampler)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
12

# Device on which the test tensors are allocated (these tests need a GPU).
DEVICE: str = "cuda"

@pytest.fixture
def rejection_sampler():
    """Provide a newly constructed RejectionSampler to each test."""
    sampler = RejectionSampler()
    return sampler


def create_logits_tensor(output_token_ids: list[list[int]],
                         vocab_size: int = 100) -> torch.Tensor:
    """Build target logits whose per-row argmax reproduces the given tokens.

    The last token of every sequence is treated as the bonus token and
    gets no row. Every other token gets one row, concatenated across
    sequences in order. Each row holds 100.0 at the token's id and
    -100.0 everywhere else.

    Args:
        output_token_ids: Per-sequence token ids, each ending with the
            bonus token.
        vocab_size: Number of columns in the returned logits.

    Returns:
        Tensor of shape (total non-bonus tokens, vocab_size) on DEVICE.
    """
    # Flatten every sequence minus its trailing bonus token into one row list.
    row_tokens = [tok for seq in output_token_ids for tok in seq[:-1]]
    logits = torch.full((len(row_tokens), vocab_size), -100.0, device=DEVICE)
    for row, tok in enumerate(row_tokens):
        logits[row, tok] = 100.0
    return logits


def create_sampling_metadata(
    all_greedy: bool,
    temperature: Optional[torch.Tensor] = None,
    generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
    """Build a minimal v1 SamplingMetadata for rejection-sampler tests.

    The batch is either entirely greedy or entirely random. With
    ``all_greedy=True``, any ``temperature`` passed in is discarded and
    replaced with ``None``. With ``all_greedy=False``, a ``temperature``
    tensor is required. top_p/top_k are left as ``None``; the penalty,
    bias and token-list fields are filled with empty placeholder values.

    Args:
        all_greedy: Whether every request uses greedy sampling.
        temperature: Per-request temperatures; required when not greedy.
        generators: Optional mapping from batch index to a
            ``torch.Generator``; any falsy value is replaced by ``{}``.

    Returns:
        The constructed SamplingMetadata.
    """
    if not generators:
        generators = {}

    if not all_greedy:
        assert temperature is not None
    else:
        temperature = None

    return SamplingMetadata(
        temperature=temperature,
        all_greedy=all_greedy,
        all_random=not all_greedy,
        top_p=None,
        top_k=None,
        min_p=torch.empty(1),
        generators=generators,
        max_num_logprobs=0,
        no_penalties=False,
        prompt_token_ids=None,
        frequency_penalties=torch.tensor([]),
        presence_penalties=torch.tensor([]),
        repetition_penalties=torch.tensor([]),
        output_token_ids=[],
        min_tokens={},
        logit_bias=[None],
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
    )


########################### Tests for Greedy Sampling ###################
def test_perfect_match(rejection_sampler):
    """All draft tokens equal the target argmax, so every draft is
    accepted and the bonus token is appended."""
    draft = [[1, 2, 3]]
    target = [[1, 2, 3, 4]]  # trailing 4 is the bonus token

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    dev = target_logits.device
    bonus = torch.tensor([target[0][-1]], device=dev)
    spec_md = SpecDecodeMetadata.make_dummy(draft, device=dev)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=dev)
    assert torch.equal(result, want)


def test_early_mismatch(rejection_sampler):
    """A mismatch at position 1 emits the target token there; every later
    position is filled with PLACEHOLDER_TOKEN_ID."""
    draft = [[1, 2, 3]]
    target = [[1, 5, 3, 4]]  # target disagrees with the draft at index 1

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    dev = target_logits.device
    bonus = torch.tensor([target[0][-1]], device=dev)
    spec_md = SpecDecodeMetadata.make_dummy(draft, device=dev)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor(
        [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=dev,
    )
    assert torch.equal(result, want)


def test_multiple_sequences(rejection_sampler):
    """Two fully-accepted sequences of different draft lengths; the
    shorter one is right-padded with PLACEHOLDER_TOKEN_ID."""
    draft = [[1, 2], [3]]
    # Bonus tokens are 5 and 4 respectively.
    target = [[1, 2, 5], [3, 4]]

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    dev = target_logits.device
    bonus = torch.tensor([seq[-1] for seq in target], device=dev)
    spec_md = SpecDecodeMetadata.make_dummy(draft, device=dev)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
                        dtype=torch.int,
                        device=dev)
    assert torch.equal(result, want)


def test_single_token_sequence(rejection_sampler):
    """A single accepted draft token followed by its bonus token."""
    draft = [[1]]
    target = [[1, 2]]  # bonus token is 2

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    dev = target_logits.device
    bonus = torch.tensor([target[0][-1]], device=dev)
    spec_md = SpecDecodeMetadata.make_dummy(draft, device=dev)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor([[1, 2]], dtype=torch.int, device=dev)
    assert torch.equal(result, want)


def test_empty_sequence(rejection_sampler):
    """With no draft tokens the output consists of the bonus token only."""
    draft: list[list[int]] = [[]]
    target = [[5]]  # only the bonus token

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    dev = target_logits.device
    bonus = torch.tensor([target[0][-1]], device=dev)
    spec_md = SpecDecodeMetadata.make_dummy(draft, device=dev)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor([[5]], dtype=torch.int, device=dev)
    assert torch.equal(result, want)


def test_multiple_mismatches(rejection_sampler):
    """Both sequences hit a mismatch (at index 2 and index 1); output is
    truncated after each mismatch and padded with PLACEHOLDER_TOKEN_ID."""
    draft = [[1, 2, 3], [4, 5, 6]]
    target = [[1, 2, 7, 6], [4, 8, 6, 9]]

    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(target)
    dev = target_logits.device
    bonus = torch.tensor([seq[-1] for seq in target], device=dev)
    spec_md = SpecDecodeMetadata.make_dummy(draft, device=dev)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor(
        [[1, 2, 7, PLACEHOLDER_TOKEN_ID],
         [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=dev,
    )
    assert torch.equal(result, want)


@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),  # Perfect match with bonus
        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),  # First mismatch
        ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
         [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]),  # Mixed matches
    ])
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
                            expected):
    """Table-driven greedy scenarios: match, early mismatch, and mixed."""
    sampling_md = create_sampling_metadata(all_greedy=True)
    target_logits = create_logits_tensor(output_tokens)
    dev = target_logits.device
    bonus = torch.tensor([seq[-1] for seq in output_tokens], device=dev)
    spec_md = SpecDecodeMetadata.make_dummy(spec_tokens, device=dev)

    result = rejection_sampler(
        spec_md,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus,
        sampling_metadata=sampling_md,
    )

    want = torch.tensor(expected, dtype=torch.int, device=dev)
    assert torch.equal(result, want)


########################### Tests for Random Sampling ###################
@pytest.mark.parametrize("k", [1, 3, 5])
@pytest.mark.parametrize("vocab_size", [1000])
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(
    rejection_sampler,
    k: int,
    vocab_size: int,
    batch_size: int,
    frac_seeded: float,
    n_rep: int,
):
    """Rows that carry a seeded generator must produce identical output
    across ``n_rep`` runs over the same random inputs."""
    num_tokens = batch_size * k
    draft_probs = F.softmax(
        torch.rand(num_tokens,
                   vocab_size,
                   dtype=torch.float32,
                   device=DEVICE),
        dim=-1,
    )
    target_logits = torch.rand_like(draft_probs)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device=DEVICE)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device=DEVICE)

    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
    seeded_rows = [row for row in range(batch_size) if seeded_mask[row]]

    def run_once() -> torch.Tensor:
        # Fresh generators per repetition, so each run restarts from the
        # same seed (the row index) for every seeded row.
        generators = {
            row: torch.Generator(device=DEVICE).manual_seed(row)
            for row in seeded_rows
        }
        temperature = torch.ones(batch_size,
                                 dtype=torch.float32,
                                 device=DEVICE)
        sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                     temperature=temperature,
                                                     generators=generators)
        spec_decode_metadata = SpecDecodeMetadata.make_dummy(
            draft_token_ids.tolist(), device=DEVICE)
        return rejection_sampler(
            spec_decode_metadata,
            draft_probs=draft_probs,
            target_logits=target_logits,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=sampling_metadata,
        )

    results = [run_once() for _ in range(n_rep)]

    for row in seeded_rows:
        reference = results[0][row]
        for later in results[1:]:
            assert torch.equal(later[row], reference)


def test_rejection_sampling_approximates_target_distribution():
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.

    The CUDA default device is scoped to this test with the
    ``torch.device`` context manager rather than set globally, so it
    does not leak into tests that run afterwards.
    """
    vocab_size = 10
    k = 2
    num_reference_probs = 100

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference: list[float] = []
    distance_wrt_target: list[float] = []

    # torch.set_default_device would change process-wide state and was
    # never restored; the context manager undoes it when the block exits.
    with torch.device(DEVICE):
        # Prepare draft, target, and reference probability distributions
        draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
                                dim=-1)
        target_logits = torch.rand(vocab_size, dtype=torch.float32)
        target_probs = F.softmax(target_logits, dim=-1)
        reference_probs = F.softmax(
            torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
            dim=-1,
        )

        for num_samples in sample_sizes:
            # Sample using rejection sampling.
            rej_sample_probs = estimate_rejection_sampling_pdf(
                draft_probs, target_logits, k, vocab_size, num_samples)
            rej_sample_probs = rej_sample_probs.to(DEVICE)

            # torch.dist broadcasts rej_sample_probs against all reference
            # rows and returns a single 2-norm over the whole difference;
            # divide by the row count to normalize by the number of refs.
            reference_vs_rejsample_dist = torch.dist(
                reference_probs,
                rej_sample_probs).item() / reference_probs.shape[0]
            target_vs_rejsample_dist = torch.dist(target_probs,
                                                  rej_sample_probs).item()

            distance_wrt_reference.append(reference_vs_rejsample_dist)
            distance_wrt_target.append(target_vs_rejsample_dist)

            relative_change_in_distance_wrt_target = get_ratio_first_to_last(
                distance_wrt_target)
            relative_change_in_distance_wrt_reference = (
                get_ratio_first_to_last(distance_wrt_reference))

            print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
                  f"{reference_vs_rejsample_dist=:.05f}")
            print(f"{num_samples=} "
                  f"{relative_change_in_distance_wrt_target=:.02f} "
                  f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    # The distance to the target must shrink far faster than the distance
    # to unrelated reference distributions.
    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target
            > relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return ``elements[0] / elements[-1]``.

    A single-element list yields 1.0. An empty list raises IndexError,
    and a zero final element raises ZeroDivisionError.
    """
    head, tail = elements[0], elements[-1]
    return head / tail


def estimate_rejection_sampling_pdf(
    draft_probs: torch.Tensor,
    target_logits: torch.Tensor,
    k: int,
    vocab_size: int,
    num_samples: int,
) -> torch.Tensor:
    """Estimate the probability distribution of the output tokens
    using rejection sampling.

    Args:
        draft_probs: Draft probability distribution; reshaped here to
            ``(1, 1, vocab_size)``, so it must hold ``vocab_size``
            elements. It is shared by every draft position.
        target_logits: Target logits; reshaped here to
            ``(1, vocab_size)`` and shared by every position.
        k: Number of draft tokens per sampled sequence.
        vocab_size: Vocabulary size (number of histogram bins).
        num_samples: Number of sequences to sample.

    Returns:
        Estimated probability distribution of the emitted (non-bonus,
        non-placeholder) output tokens: a density-normalized histogram
        with ``vocab_size`` unit-width bins, on CPU.
    """
    rejection_sampler = RejectionSampler()
    num_tokens = num_samples * k

    # Place the inputs on DEVICE explicitly so this helper does not rely on
    # the caller having set a CUDA default device (no-op if already there).
    draft_probs = draft_probs.to(DEVICE)
    target_logits = target_logits.to(DEVICE)

    # Repeat draft probs num_samples * k times: (num_samples, k, vocab_size).
    draft_probs = draft_probs.reshape(1, 1,
                                      vocab_size).repeat(num_samples, k, 1)

    # Repeat target logits num_tokens times: (num_tokens, vocab_size).
    target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)

    # Randomly sample draft token ids from draft probs; multinomial over a
    # (num_samples, vocab_size) input already returns (num_samples, k).
    draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                        num_samples=k,
                                        replacement=True)
    draft_probs = draft_probs.view(num_tokens, vocab_size)

    # Bonus tokens not used but required.
    bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
                                  device=DEVICE).repeat(num_samples, 1)

    temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
    sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                 temperature=temperature)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids.tolist(), device=bonus_token_ids.device)
    output_token_ids = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )

    # Drop the bonus-token column, then drop positions that were never
    # emitted (PLACEHOLDER_TOKEN_ID). Filtering explicitly avoids relying
    # on torch.histogram silently ignoring out-of-range values.
    output_token_ids = output_token_ids[:, :-1].flatten()
    output_token_ids = output_token_ids[
        output_token_ids != PLACEHOLDER_TOKEN_ID]

    hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                               device="cpu"),
                           bins=vocab_size,
                           range=(0, vocab_size),
                           density=True)

    return hist.hist