# test_sampler.py
import itertools
import random
from typing import List, Optional, Tuple
from unittest.mock import patch

import pytest
import torch
from transformers import GenerationConfig, GenerationMixin

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import Counter, is_pin_memory_available


class MockLogitsSampler(Sampler):
    """``Sampler`` subclass used by these tests.

    It stores the ``fake_logits`` tensor it is constructed with and otherwise
    behaves exactly like ``Sampler``: ``forward`` delegates unchanged.
    """

    def __init__(self, fake_logits: torch.Tensor):
        super(MockLogitsSampler, self).__init__()
        # Stored for reference only; nothing in this class reads it.
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        """Forward every argument untouched to ``Sampler.forward``."""
        return Sampler.forward(self, *args, **kwargs)


def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
    """Create the input tensor, fake logits and sampler shared by the tests.

    Both tensors are created on torch's current default device.

    Args:
        batch_size: number of rows in both returned tensors.

    Returns:
        ``(input_tensor, fake_logits, sampler)``: a float16
        ``(batch_size, 1024)`` tensor from ``torch.rand``, a
        ``(batch_size, VOCAB_SIZE)`` tensor of the same dtype filled with
        ``1e-2``, and a ``MockLogitsSampler`` built around ``fake_logits``.
    """
    noise = torch.rand((batch_size, 1024), dtype=torch.float16)
    # Uniform logits: every token starts out equally likely.
    uniform_logits = torch.full((batch_size, VOCAB_SIZE),
                                1e-2,
                                dtype=noise.dtype)
    return noise, uniform_logits, MockLogitsSampler(uniform_logits)


# Vocabulary size (number of columns) of the fake logits tensors.
VOCAB_SIZE = 32000
# Every test below is parametrized over these seeds: 0, 1, ..., 127.
RANDOM_SEEDS = list(range(128))
# A single device when exactly one GPU is visible, otherwise cuda:0 and
# cuda:1. NOTE: with zero visible GPUs this still lists two devices.
CUDA_DEVICES = [
    f"cuda:{idx}"
    for idx in range(2 if torch.cuda.device_count() != 1 else 1)
]


def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    sampling_params: SamplingParams,
    device: str,
):
    """Run ``sampler`` over ``batch_size`` identical single-sequence prompts.

    Group ``i`` is the prompt request ``test_<i>`` with token ids
    ``[1, 2, 3]``, block table ``{0: [1]}`` and the shared
    ``sampling_params``.

    Args:
        batch_size: number of sequence groups to build.
        input_tensor: tensor handed to the sampler as ``logits``.
        sampler: the sampler under test.
        sampling_params: parameters shared by every group.
        device: device passed to ``SamplingMetadata.prepare``.

    Returns:
        Whatever ``sampler(...)`` returns.
    """
    groups = [
        SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=sampling_params,
            block_tables={0: [1]},
        ) for idx in range(batch_size)
    ]
    # The same length list serves as both seq_lens and query_lens.
    lens = [group.seq_data[0].get_len() for group in groups]
    metadata = SamplingMetadata.prepare(groups,
                                        lens,
                                        query_lens=lens,
                                        device=device,
                                        pin_memory=is_pin_memory_available())
    return sampler(logits=input_tensor, sampling_metadata=metadata)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """Every sample of a greedy (temperature 0) batch is its row's argmax."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    greedy = SamplingParams(temperature=0)
    outputs = _do_sample(batch_size, fake_logits, sampler, greedy, device)

    # The argmax is taken after sampling, from the same tensor the sampler
    # was given as logits.
    row_argmax = torch.argmax(fake_logits, dim=-1)
    for row, seq_output in enumerate(outputs):
        for sample in seq_output.samples:
            assert sample.output_token == row_argmax[row].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Unseeded random sampling picks the one boosted token in each row.

    Row ``r`` gets logit ``1e2`` at token ``r`` while every other entry stays
    at ``1e-2``, so each of the ``n`` samples for that row must be ``r``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Boost the diagonal: entry (r, r) for every row r.
    diag = torch.arange(batch_size)
    fake_logits[diag, diag] = 1e2

    params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    outputs = _do_sample(batch_size, fake_logits, sampler, params, device)

    for row, seq_output in enumerate(outputs):
        assert all(sample.output_token == row
                   for sample in seq_output.samples)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Seeded random sampling still picks the one boosted token per row."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    for row in range(batch_size):
        # Token ``row`` dominates row ``row``.
        fake_logits[row, row] = 1e2

    # ``n`` is drawn before ``seed``; keyword arguments evaluate in order.
    seeded_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    outputs = _do_sample(batch_size, fake_logits, sampler, seeded_params,
                         device)

    for row, seq_output in enumerate(outputs):
        tokens = [sample.output_token for sample in seq_output.samples]
        for token in tokens:
            assert token == row


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Sampling twice with the same seeded params gives identical output."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    seeded_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    # Two back-to-back runs over the same logits, sampler and params.
    runs = [
        _do_sample(batch_size, fake_logits, sampler, seeded_params, device)
        for _ in range(2)
    ]

    assert runs[0] == runs[1]


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    """Smoke test: a batch made only of beam-search requests samples cleanly.

    The result is not inspected. It is unclear how to decide what the
    expected beam-search outputs are, so this test only checks that the
    sampler raises no exception for an all-beam-search batch.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    beam_params = SamplingParams(
        temperature=0,
        best_of=2,
        use_beam_search=True,
    )
    _do_sample(batch_size, fake_logits, sampler, beam_params, device)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    """Check that stop/EOS tokens are masked to -inf until ``min_tokens``.

    For every logits row the sampler processes, exactly the request's
    ``all_stop_token_ids`` must be ``-inf`` when fewer than ``min_tokens``
    tokens have been generated. Otherwise no entry may be ``-inf``.
    Seed 0 runs a fixed set of hand-written edge cases; every other seed
    runs one randomly generated batch.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    # Draw the starting sequence id only after seeding, so the ids (and every
    # later random draw) depend on ``seed`` alone and not on test order.
    seq_id_counter = Counter(start=random.randint(0, 100))

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        """Build SamplingParams whose stop set also contains eos_token_id."""
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        """Random prompt of num_input tokens plus num_generated outputs."""
        seq_data = SequenceData(
            random.choices(range(0, VOCAB_SIZE), k=num_input))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        """Build a random batch and the expected per-row penalization."""
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data = {}
            seq_group_penalization = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*,
                      expected_penalization=None,
                      seq_group_metadata_list=None):
        """Sample one batch and check which logits rows were penalized.

        ``expected_penalization`` holds one bool per logits row. Its length
        must equal the number of rows the batch produces: one per sequence,
        or one per prompt token for a prompt that requests prompt_logprobs.
        """
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        seq_lens = []
        sampling_params_per_row = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a row
                    # in logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        # Adjacent string literals are concatenated, so the trailing space
        # is what separates "computed" from "batch size".
        assert len(expected_penalization) == batch_size, (
            "Invalid test case, expected_penalization does not match computed "
            "batch size")

        _, fake_logits, sampler = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else None,
            device=device,
            pin_memory=is_pin_memory_available())
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids

            if should_penalize:
                for token_id in tokens_to_check:
                    penalized_logit = fake_logits[logits_idx, token_id]
                    # The whole message is parenthesized so that it stays
                    # part of the assert statement.
                    assert penalized_logit == -float('inf'), (
                        f"Expected token {token_id} for logits row "
                        f"{logits_idx} to be penalized")
                # no other tokens should be set to -inf
                num_penalized = torch.count_nonzero(
                    fake_logits[logits_idx, :] == -float('inf'))
                assert num_penalized == len(tokens_to_check), (
                    f"Expected only {len(tokens_to_check)} tokens to be "
                    "penalized")
            else:
                # no tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] ==
                    -float('inf')) == 0, "No tokens should have been penalized"

    for test_case in test_cases:
        run_test_case(**test_case)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Sample a batch mixing greedy, random, seeded-random and beam requests.

    The batch is sampled twice: once as built, then again after shuffling the
    rows with a seed-derived permutation. The checks are:
      * greedy rows must return their row's argmax;
      * unseeded random rows must return one of their boosted tokens;
      * seeded random rows must, on the second pass, reproduce exactly the
        tokens they returned on the first pass.
    Beam-search rows are not checked.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    seq_group_metadata_list = []
    expected_tokens: List[Optional[List[int]]] = []
    seq_lens = []
    for row in range(batch_size):
        row_expected: Optional[List[int]] = None
        kind = random.randint(0, 3)
        if kind == 0:
            # Greedy: only this row's argmax is acceptable.
            params = SamplingParams(temperature=0)
            row_expected = [torch.argmax(fake_logits[row], dim=-1).item()]
        elif kind == 3:
            params = SamplingParams(temperature=0,
                                    use_beam_search=True,
                                    best_of=2)
        else:
            num_samples = random.randint(1, 10)
            params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=num_samples,
                presence_penalty=random.randint(0, 1),
            )
            if kind == 2:
                # Seeded random: expected tokens are recorded on pass one.
                params.seed = random.randint(0, 10000)
            else:
                # Unseeded random: boost ``num_samples`` consecutive tokens
                # starting at ``row``; any of them is an acceptable draw.
                for offset in range(num_samples):
                    fake_logits[row, row + offset] = 1e2
                row_expected = list(range(row, row + num_samples))
        expected_tokens.append(row_expected)

        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{row}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    def check_batch():
        """Sample the current batch and verify every non-beam row."""
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available())
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for idx, (group_output, group) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            group_params = group.sampling_params
            if group_params.use_beam_search:
                continue

            if group_params.seed is not None and expected_tokens[idx] is None:
                # First pass for a seeded row: remember what was drawn so
                # the shuffled second pass can be compared against it.
                expected_tokens[idx] = [
                    sample.output_token for sample in group_output.samples
                ]
                continue

            # Greedy and seeded rows must match exactly; unseeded random
            # rows only need to land on one of the boosted tokens.
            must_match = (group_params.temperature == 0
                          or group_params.seed is not None)
            for k, sample in enumerate(group_output.samples):
                if must_match:
                    assert sample.output_token == expected_tokens[idx][k]
                else:
                    assert sample.output_token in expected_tokens[idx]

    # First pass over the batch as constructed.
    check_batch()

    # Apply the same permutation to every per-row structure (a fresh
    # Random(seed) per list), then permute the tensor rows to match.
    permutation = list(range(batch_size))
    for per_row in (permutation, seq_group_metadata_list, expected_tokens,
                    seq_lens):
        random.Random(seed).shuffle(per_row)
    permutation = torch.tensor(permutation)
    input_tensor.data = input_tensor.index_select(0, permutation)
    fake_logits.data = fake_logits.index_select(0, permutation)

    # Second pass: seeded rows are now compared with their first-pass tokens.
    check_batch()

@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """Compare the sampler's top-k/top-p probabilities with HF's warpers.

    ``_sample`` is patched to capture its ``probs`` argument instead of
    sampling. The captured tensor must match, within ``atol=1e-5``, the
    softmax of the logits after HuggingFace's top-k and top-p warpers, and
    both must zero out exactly the same entries.
    """
    set_random_seed(seed)
    # Pin the default device as every other test here does. It is
    # process-global state, so otherwise it is whatever an earlier test set.
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, VOCAB_SIZE),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)

    # Reference: HuggingFace's top-k and top-p logits warpers.
    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = generation_model._get_logits_warper(generation_config)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list = []
    seq_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())

    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        # Record the probabilities handed to _sample, then return a stub
        # result built from each row's top-1 token.
        nonlocal sample_probs
        sample_probs = probs
        return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                 for prob in probs], None)

    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

    # Fail with a clear message instead of an opaque error inside
    # torch.allclose when the patched _sample was never reached.
    assert sample_probs is not None, "Sampler did not call _sample"

    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))