test_rejection_sampler.py 24.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Tests for rejection sampling."""

5
import pytest
6
7
8
9
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
10
from vllm.model_executor.utils import set_random_seed
zhuwenwen's avatar
zhuwenwen committed
11
from vllm.platforms import current_platform
12

13
14
15
16
17
18
19
20

@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Pin the engine version for every test in this module.

    Everything below exercises V0 internals, so the ``VLLM_USE_V1``
    environment variable is set to ``"0"`` through ``monkeypatch``.
    """
    env_name, env_value = "VLLM_USE_V1", "0"
    monkeypatch.setenv(env_name, env_value)

# Run device-parametrized tests on at most two CUDA devices. Using min()
# keeps the list empty on a host with no visible GPU (the previous
# "1 if count == 1 else 2" produced "cuda:0"/"cuda:1" even then), so
# pytest reports those parametrized tests as having an empty parameter set
# instead of pointing them at devices that do not exist.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))
]


def mock_causal_accepted_tensor(
        k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
    """Generate an "accepted" tensor which should yield causally-accepted tokens
    up to last accepted indices.

    Args:
        k: Number of draft positions per sequence.
        last_accepted_indices: 1-D tensor with one entry per sequence: the
            index of the last causally-accepted position (``-1`` means no
            position is accepted).

    Returns:
        Boolean tensor of shape ``(batch_size, k)``. Positions
        ``0..last_accepted_indices`` are True and position
        ``last_accepted_indices + 1`` is False. Positions after
        ``last_accepted_indices + 1`` are set randomly; they may be
        "accepted" but are never causally accepted.
    """
    batch_size = last_accepted_indices.shape[0]

    # (batch_size, k) grid of position indices, compared against a
    # (batch_size, 1) column of per-sequence thresholds via broadcasting.
    positions = torch.arange(k).expand(batch_size, k)
    last = last_accepted_indices.unsqueeze(-1)

    accepted = positions <= last

    # Sprinkle accepted values after the contiguous initial accepted values.
    # This replicates the behavior of rejection sampling, which may "accept"
    # a token that cannot be accepted because of causality. The slot at
    # last + 1 is excluded so the first rejection stays in place.
    sprinkle_candidates = positions > last + 1
    sprinkle = torch.rand(batch_size, k) > 0.5

    accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
    return accepted


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize(
    "which_tokens_accepted",
    ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize(
    "use_flashinfer",
    [True, False] if not current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
                               device: str, use_flashinfer: bool):
    """Verify the output has correct format given predetermined accepted matrix.

    ``_create_output`` receives a mocked ``accepted`` mask plus random
    recovered, draft and bonus token ids. Its ``k + 1`` column output is
    checked per scenario:

    * all accepted: the draft tokens followed by the bonus token;
    * none accepted: the first recovered token, then ``-1`` everywhere;
    * some accepted: the first rejected slot holds a recovered (or bonus)
      token and every later slot is ``-1``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)

    batch_size = 10
    k = 5
    vocab_size = 3000

    if which_tokens_accepted == "all_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "no_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "some_tokens_accepted":
        last_accepted_indices = torch.randint(low=-1,
                                              high=k,
                                              size=(batch_size, ))
        accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
    else:
        raise AssertionError()

    recovered_token_ids = torch.randint(low=0,
                                        high=vocab_size,
                                        size=(batch_size, k),
                                        dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)

    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)
    output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
        accepted,
        recovered_token_ids,
        draft_token_ids,
        bonus_token_ids,
    )

    expected_bonus_token_ids = bonus_token_ids.clone()

    if which_tokens_accepted == "all_tokens_accepted":
        # Expect all tokens to be equal to draft tokens.
        assert torch.equal(output_token_ids[:, :-1], draft_token_ids)

        # Expect all bonus tokens to be included.
        assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
    elif which_tokens_accepted == "no_tokens_accepted":
        # Expect first token to be equal to recovered tokens.
        assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])

        # Expect everything else to be -1.
        assert torch.equal(output_token_ids[:, 1:],
                           torch.ones_like(output_token_ids[:, 1:]) * -1)
    elif which_tokens_accepted == "some_tokens_accepted":
        recovered_plus_bonus = torch.cat(
            (recovered_token_ids, expected_bonus_token_ids), dim=-1)
        # Assert first rejected token is a recovered token or bonus token.
        assert torch.equal(
            recovered_plus_bonus[torch.arange(0, batch_size),
                                 last_accepted_indices + 1],
            output_token_ids[torch.arange(0, batch_size),
                             last_accepted_indices + 1])

        # Assert every subsequent token is -1.
        subsequent_mask = torch.arange(0, k + 1).expand(
            batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
        assert torch.all(output_token_ids[subsequent_mask] == -1)


@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize(
    "use_flashinfer",
    [True, False] if not current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                    device: str, use_flashinfer: bool):
    """Smoke test: the sampler runs on random inputs across many shapes.

    The probability tensors come straight from ``torch.rand`` and are not
    normalized. Only the absence of an exception is checked.
    """
    torch.set_default_device(device)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                      draft_token_ids)


@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.parametrize("use_flashinfer", [True, False])
# Not testing FlashInfer now, since 0.2.3 API removed the ability
# to pass in uniform samples.
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                   frac_seeded: float, n_rep: int, device: str,
                                   use_flashinfer: bool):
    """Seeded sequences must give identical outputs across repeated calls.

    About ``frac_seeded`` of the batch rows get a generator seeded with their
    row index, rebuilt fresh on every repetition. Unseeded rows are not
    checked.
    """
    torch.set_default_device(device)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded

    results = []
    for _ in range(n_rep):
        seeded_seqs = {
            i: torch.Generator(device=device).manual_seed(i)
            for i in range(batch_size) if seeded_mask[i]
        }
        results.append(
            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                              draft_token_ids, seeded_seqs))

    for i in range(batch_size):
        if seeded_mask[i]:
            for j in range(1, n_rep):
                assert torch.equal(results[j][i], results[0][i])


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Consistent with NV.")
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.parametrize("use_flashinfer", [True, False])
# Not testing FlashInfer now, since 0.2.3 API removed the ability
# to pass in uniform samples.
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
                            device: str, use_flashinfer: bool):
    """A batch mixing seeded and unseeded rows must match per-row runs.

    Row 0 is unseeded and every other row ``i`` is seeded with ``i``. The
    batched output of each row is compared against running that row alone,
    with the same seed and a freshly reset global RNG.
    """
    torch.set_default_device(device)
    set_random_seed(0)
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    # Per-row copies, each with a leading batch dimension of 1.
    single_batches = [(draft_probs[i].clone().unsqueeze(0),
                       draft_token_ids[i].clone().unsqueeze(0),
                       target_probs[i].clone().unsqueeze(0),
                       bonus_token_ids[i].clone().unsqueeze(0))
                      for i in range(batch_size)]

    set_random_seed(0)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    seeded_seqs = {
        i: torch.Generator(device=device).manual_seed(i)
        for i in range(1, batch_size)  # 0 is seed None
    }
    batch_result = rejection_sampler(target_probs.clone(),
                                     bonus_token_ids.clone(),
                                     draft_probs.clone(),
                                     draft_token_ids.clone(), seeded_seqs)

    set_random_seed(0)

    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    results = []
    for i, (seq_draft_probs, seq_draft_token_ids, seq_target_probs,
            seq_bonus_token_ids) in enumerate(single_batches):
        # In a single-row call the row index is 0, so the seed is keyed on 0.
        request_seeded_seqs = {
            0: torch.Generator(device=device).manual_seed(i)
        } if seeded_seqs.get(i) is not None else None
        results.append(
            rejection_sampler(seq_target_probs, seq_bonus_token_ids,
                              seq_draft_probs, seq_draft_token_ids,
                              request_seeded_seqs))

    for i in range(batch_size):
        assert torch.equal(batch_result[i], results[i].squeeze(0))


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Rocm platform does not support flashinfer.")
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
                                       batch_size: int, device: str):
    """
    Test the flashinfer and nonflashinfer backend generate
    the same output metrics.
    """

    # Everything below is currently unreachable. It is kept so the
    # comparison can be re-enabled once FlashInfer accepts uniform samples
    # again.
    pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed "
                "the ability to pass in uniform samples.")

    torch.set_default_device(device)
    torch.manual_seed(0)
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    num_accepted_tokens = []
    num_emitted_tokens = []
    num_draft_tokens = []

    def get_seeded_seqs():
        return {
            i: torch.Generator(device=device).manual_seed(i)
            for i in range(batch_size)
        }

    for use_flashinfer in ([True, False]
                           if not current_platform.is_rocm() else [False]):
        rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
        rejection_sampler.init_gpu_tensors(device=device)
        # We use seeded sequences to ensure the same tokens are accepted
        # for both flashinfer and nonflashinfer backends.
        seeded_seqs = get_seeded_seqs()
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids, seeded_seqs)
        num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
        num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
        num_draft_tokens.append(rejection_sampler.num_draft_tokens)

    assert num_accepted_tokens[0] == num_accepted_tokens[1]
    assert num_emitted_tokens[0] == num_emitted_tokens[1]
    assert num_draft_tokens[0] == num_draft_tokens[1]


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize(
    "use_flashinfer",
    [True, False] if not current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str, device: str,
                               use_flashinfer: bool):
    """With ``strict_mode=True``, an out-of-vocabulary token id must raise.

    One entry of either the bonus or the draft token ids is replaced with
    ``vocab_size + 1`` ("above") or ``-1`` ("below"). The sampler call is
    then expected to raise ``AssertionError``.
    """
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)

    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
                                         strict_mode=True)
    rejection_sampler.init_gpu_tensors(device=device)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    oob_token_ids = None
    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError()

    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError()

    # Mutates bonus_token_ids / draft_token_ids in place via the alias.
    oob_token_ids[0][0] = rogue_token_id

    with pytest.raises(AssertionError):
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids)


@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@pytest.mark.parametrize(
    "use_flashinfer",
    [True, False] if not current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
        seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.

    When draft_and_target_probs_equal=True, the draft and target
    probabilities are exactly equal. Rejection sampling should
    still work without any NaNs or exceptions.
    """
    torch.set_default_device("cpu")
    set_random_seed(seed)
    helper = _CorrectnessTestHelper(
        vocab_size=10,
        rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
    )

    draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
        draft_and_target_probs_equal)

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference: list[float] = []
    distance_wrt_target: list[float] = []

    for num_samples in sample_sizes:
        (reference_vs_rejsample_dist,
         target_vs_rejsample_dist) = helper.run_and_compare_distributions(
             draft_probs,
             target_probs,
             reference_probs,
             num_samples,
         )

        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

        # Computed per iteration only for the progress printout below.
        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
              f"{reference_vs_rejsample_dist=:.05f}")
        print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
              f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    # The shrinkage of the distance to the target must beat the shrinkage of
    # the distance to the random references by at least this factor.
    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target
            > relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return ``elements[0] / elements[-1]``.

    Used to measure how much a distance shrank between the first and the
    last sample size. Raises ``IndexError`` on an empty list and
    ``ZeroDivisionError`` if the last element is 0.
    """
    return elements[0] / elements[-1]


class _CorrectnessTestHelper:
    """Class that packages together logic required for the unit-level
    rejection sampling correctness test.
    """

    def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
        self.rejection_sampler = rejection_sampler
        self.vocab_size = vocab_size
        self.vocab_range = (0, vocab_size)

493
        self.rejection_sampler.init_gpu_tensors(device=0)
494
495
496
497
498
499
500
501
502
503

        # Keep test simple, use k=1
        self.k = 1

        # Bonus tokens not used, but rejection sampler requires
        # correct shape.
        self.num_bonus_tokens = 1

    def generate_probs_for_test(
        self, draft_and_target_probs_equal: bool
504
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
505
506
507
508
        draft_probs, target_probs = (F.softmax(
            torch.rand(self.vocab_size, dtype=torch.float32),
            dim=-1,
        ) for _ in range(2))
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525

        num_reference_probs = 100
        reference_probs = F.softmax(
            torch.rand(num_reference_probs,
                       self.vocab_size,
                       dtype=torch.float32),
            dim=-1,
        )

        if draft_and_target_probs_equal:
            target_probs = draft_probs.clone()

        return draft_probs, target_probs, reference_probs

    def run_and_compare_distributions(self, draft_probs: torch.Tensor,
                                      target_probs: torch.Tensor,
                                      reference_probs: torch.Tensor,
526
                                      num_samples: int) -> tuple[float, float]:
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
        # Sample using rejection sampling.
        rej_sample_probs = self._estimate_rejection_sampling_pdf(
            draft_probs, target_probs, num_samples)

        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs,
            rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()

        return reference_vs_rejsample_dist, target_vs_rejsample_dist

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        # Repeat draft probs num_samples times.
        draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
            num_samples, 1, 1)

550
        # Repeat target probs num_samples * (k + 1) times.
551
552
        # Rejection sampler requires bonus token probs, but they aren't used.
        target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
553
            num_samples, self.k + 1, 1)
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569

        # Randomly sample draft token ids from draft probs.
        draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                            num_samples=1,
                                            replacement=True).reshape(
                                                num_samples, self.k)

        # Bonus tokens not used but required.
        bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
                                      dtype=torch.int64,
                                      device="cuda").repeat(num_samples, 1)

        # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                  bonus_token_ids.to("cuda"),
                                                  draft_probs.to("cuda"),
570
                                                  draft_token_ids.to("cuda"))
571
572
573
574
575
576
577
578
579
580
581
582

        # Remove bonus tokens
        output_token_ids = output_token_ids[:, :-1].flatten()

        # Estimate probability density function
        hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                   device="cpu"),
                               bins=self.vocab_size,
                               range=self.vocab_range,
                               density=True)

        return hist.hist