# Tests for vllm.model_executor.layers.sampler.Sampler (test_sampler.py).
1
import itertools
2
import random
3
from dataclasses import dataclass
4
from typing import Dict, List, Optional, Tuple
5
from unittest.mock import Mock, patch
6

# Third-party imports.
7
import pytest
8
import torch
9
from transformers import GenerationConfig, GenerationMixin
10

11
import vllm.envs as envs
12
from vllm.model_executor.layers.sampler import Sampler
13
from vllm.model_executor.sampling_metadata import SamplingMetadata
14
from vllm.model_executor.utils import set_random_seed
15
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
16
from vllm.utils import Counter, is_pin_memory_available
17
18
19
20


class MockLogitsSampler(Sampler):
    """Thin ``Sampler`` subclass used throughout these tests.

    It keeps a reference to ``fake_logits`` for the test's convenience;
    ``forward`` defers entirely to the base ``Sampler``, so callers still
    pass the logits explicitly via ``logits=``.
    """

    def __init__(self, fake_logits: torch.Tensor):
        super().__init__()
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
    """Build the tensors and sampler shared by most tests.

    Returns:
        input_tensor: random float16 tensor of shape ``(batch_size, 1024)``.
        fake_logits: tensor of shape ``(batch_size, VOCAB_SIZE)`` filled with
            1e-2, with the same dtype as ``input_tensor``.
        sampler: a ``MockLogitsSampler`` wrapping ``fake_logits``.

    No explicit device is passed, so both tensors follow torch's current
    default device (the tests set it with ``torch.set_default_device``).
    """
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, VOCAB_SIZE),
                             1e-2,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)
    return input_tensor, fake_logits, sampler


# Number of logit columns used by the shared test fixtures.
VOCAB_SIZE = 32000
# Seeded tests are parametrized over each of these seeds.
RANDOM_SEEDS = list(range(128))
# Use a single device when exactly one GPU is visible; otherwise parametrize
# over cuda:0 and cuda:1.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    sampling_params: SamplingParams,
    device: str,
):
    """Run ``sampler`` over a batch of identical prompt sequence groups.

    Builds ``batch_size`` prompt groups, each holding the single sequence
    ``[1, 2, 3]`` and sharing ``sampling_params``. ``input_tensor`` is handed
    to the sampler as the logits (callers pass their ``fake_logits`` here).

    Returns:
        Whatever the sampler returns for this batch.
    """
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    seq_lens: List[int] = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())
    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """Greedy sampling (temperature 0) must return each row's argmax."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    sampling_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Unseeded random sampling must pick each row's dominant logit.

    Row ``i`` gets a logit of 1e2 at column ``i`` (all other entries stay at
    1e-2), so every one of the ``n`` samples for row ``i`` should be token
    ``i``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Seeded random sampling must pick each row's dominant logit.

    Row ``i`` gets a logit of 1e2 at column ``i`` (all other entries stay at
    1e-2) and the request carries a per-request ``seed``; every one of the
    ``n`` samples for row ``i`` should be token ``i``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Two runs with identical seeded params must produce equal outputs."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                      sampling_params, device)

    second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                       sampling_params, device)

    assert first_sampler_output == second_sampler_output


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    """Stop/EOS tokens must be set to -inf until ``min_tokens`` is reached.

    For ``seed == 0`` a fixed set of hand-written edge cases is run; for every
    other seed one randomly generated batch is checked. A sequence is
    expected to be penalized while it has generated fewer than
    ``min_tokens`` tokens; penalized rows must have exactly their stop token
    ids set to -inf, all other rows must have no -inf entries.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    # Drawn after seeding so the generated sequence ids are reproducible.
    seq_id_counter = Counter(start=random.randint(0, 100))

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        """Build ``SamplingParams`` that also treat ``eos_token_id`` as stop."""
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        """Return sequence data with random prompt and output token ids."""
        seq_data = SequenceData.from_seqs(
            random.choices(range(0, VOCAB_SIZE), k=num_input))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        """Build a random batch of seq groups and its expected penalties."""
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list: List[SequenceGroupMetadata] = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data: Dict[int, SequenceData] = {}
            seq_group_penalization: List[bool] = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*, expected_penalization: List[bool],
                      seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Sample one batch and check each logits row's -inf entries."""
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        seq_lens: List[int] = []
        sampling_params_per_row: List[SamplingParams] = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                assert sgm.sampling_params is not None
                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a row in
                    # logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        assert len(expected_penalization) == batch_size, (
            "Invalid test case, expected_penalization does not match computed "
            "batch size")

        _, fake_logits, sampler = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else [1] * batch_size,
            device=device,
            pin_memory=is_pin_memory_available())
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        neg_inf = -float('inf')
        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids

            if should_penalize:
                for token_id in tokens_to_check:
                    # Parenthesized so both string parts form the message;
                    # previously the second part was a separate no-op line.
                    assert fake_logits[logits_idx, token_id] == neg_inf, (
                        f"Expected token {token_id} for logits row "
                        f"{logits_idx} to be penalized")
                # no other tokens should be set to -inf
                num_penalized = torch.count_nonzero(
                    fake_logits[logits_idx, :] == neg_inf)
                assert num_penalized == len(tokens_to_check), (
                    f"Expected only {len(tokens_to_check)} tokens to be "
                    "penalized")
            else:
                # no tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] == neg_inf) == 0, \
                    "No tokens should have been penalized"

    for test_case in test_cases:
        run_test_case(**test_case)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Sample a batch that mixes greedy, random and seeded-random requests.

    Greedy rows must return the argmax, and unseeded random rows must return
    one of their ``n`` boosted tokens. Seeded rows are recorded on the first
    pass. The batch is then shuffled and resampled, and each seeded row must
    reproduce its first-pass result regardless of its new position.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    expected_tokens: List[Optional[List[int]]] = []
    seq_lens: List[int] = []
    for i in range(batch_size):
        expected: Optional[List[int]] = None
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
            expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
        else:  # sampling_type is 1 (random) or 2 (seeded random)
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 2:
                sampling_params.seed = random.randint(0, 10000)
            else:
                # Boost n tokens so any of them is an acceptable sample.
                for idx in range(n):
                    fake_logits[i, i + idx] = 1e2
                expected = list(range(i, i + n))

        expected_tokens.append(expected)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    # Shared by both sampling passes below; passed to SamplingMetadata.prepare
    # (presumably to hold per-request seeded generators — TODO confirm).
    generators: Dict[str, torch.Generator] = {}

    def test_sampling():
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available(),
            generators=generators)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for i, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            params = metadata.sampling_params
            assert params is not None

            if params.seed is not None and expected_tokens[i] is None:
                # Record seeded random result to compare with results of
                # second invocation
                expected_tokens[i] = [
                    nth_output.output_token
                    for nth_output in sequence_output.samples
                ]
                continue

            expected_tokens_item = expected_tokens[i]
            assert expected_tokens_item is not None

            for n, nth_output in enumerate(sequence_output.samples):
                if params.temperature == 0 or params.seed is not None:
                    # Ensure exact matches for greedy or random with seed
                    assert nth_output.output_token == expected_tokens_item[n]
                else:
                    # For non-seeded random check that one of the high-logit
                    # tokens were chosen
                    assert nth_output.output_token in expected_tokens_item

    # Test batch
    test_sampling()

    # Shuffle the batch and resample. Each list is shuffled with a fresh
    # Random(seed), so all of them receive the same permutation.
    target_index = list(range(batch_size))
    for list_to_shuffle in (target_index, seq_group_metadata_list,
                            expected_tokens, seq_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    permutation = torch.tensor(target_index)
    fake_logits.data = fake_logits.index_select(0, permutation)

    # This time, results of seeded random samples will be compared with
    # the corresponding sample in the pre-shuffled batch
    test_sampling()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """vLLM's top-k/top-p filtered probabilities must match HF's.

    ``_sample`` is patched to capture the probabilities passed to it, and
    ``flashinfer_top_k_top_p_sampling`` is patched to ``None`` so the
    non-FlashInfer top-k/top-p path runs. The captured probabilities must
    match HF's logits processors (softmaxed) for the same ``top_k``/``top_p``,
    including exactly which entries are zero.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    # Only the device and dtype of input_tensor are used below.
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, VOCAB_SIZE),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)

    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)

    @dataclass
    class MockConfig:
        is_encoder_decoder: bool = False

    generation_model.config = MockConfig()  # needed by the following method
    generation_model._prepare_special_tokens(generation_config, device=device)
    processors = generation_model._get_logits_processor(generation_config,
                                                        None,
                                                        None,
                                                        None, [],
                                                        device=device)
    assert len(processors) == 2  # top_p and top_k

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    seq_lens: List[int] = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())

    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        # Capture the probabilities the sampler computed, and return a
        # stand-in result built from each row's top-1 index.
        nonlocal sample_probs
        sample_probs = probs
        return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                 for prob in probs], None)

    # top-k and top-p is only calculated when flashinfer kernel is not available
    with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
         patch("vllm.model_executor.layers.sampler."
               "flashinfer_top_k_top_p_sampling", None):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

    assert sample_probs is not None

    hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
    """A FlashInfer call that reports failure must fall back transparently.

    The patched kernel returns ``None`` plus an all-zero int32 tensor of
    length ``batch_size`` (presumably per-row success flags, i.e. every row
    failed — TODO confirm against the sampler). Sampling with it patched in
    must give exactly the output of an unpatched run with the same seeded
    params. Skipped when the FlashInfer sampler is disabled.
    """
    if not envs.VLLM_USE_FLASHINFER_SAMPLER:
        pytest.skip("Flashinfer sampler is disabled")

    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    def always_failing_kernel(*_args, **_kwargs):
        all_zero_flags = torch.zeros(batch_size,
                                     device=device,
                                     dtype=torch.int32)
        return None, all_zero_flags

    seeded_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    reference_output = _do_sample(batch_size, fake_logits, sampler,
                                  seeded_params, device)

    kernel_path = ("vllm.model_executor.layers.sampler."
                   "flashinfer_top_k_top_p_sampling")
    with patch(kernel_path, always_failing_kernel):
        fallback_output = _do_sample(batch_size, fake_logits, sampler,
                                     seeded_params, device)

    assert reference_output == fallback_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """Repetition-penalty and plain requests must not affect each other.

    The test uses two requests: greedy with ``repetition_penalty=2.0``, and
    seeded ``top_k=1`` sampling without a penalty. It samples them in both
    batch orders, and each request must yield the same token in either
    position.
    """

    vocab_size = 8

    def test_sampling_params(sampling_params: List[SamplingParams]):
        """Sample two prompts using ``sampling_params[0]`` and ``[1]``.

        Returns the first sampled token of each of the two rows.
        """
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        seq_lens: List[int] = []
        for i in range(2):
            seq_group_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{i}",
                    is_prompt=True,
                    seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                    sampling_params=sampling_params[i],
                    block_tables={0: [1]},
                ))
            seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        fake_logits = torch.full((2, vocab_size),
                                 1e-2,
                                 device=device,
                                 dtype=torch.float16)

        # Token 1 (which appears in the prompt) gets the highest logit and
        # token 5 the runner-up — presumably so a repetition penalty on
        # prompt tokens changes the greedy choice.
        fake_logits[:, 5] = 1.1e-2
        fake_logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(fake_logits)

        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        return [output.samples[0].output_token for output in sampler_output]

    # one configuration is greedy with repetition_penalty
    sampling_params_rep = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # other configuration is sampling w/o repetition_penalty
    sampling_params_sample = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    tokens1 = test_sampling_params(
        [sampling_params_rep, sampling_params_sample])

    tokens2 = test_sampling_params(
        [sampling_params_sample, sampling_params_rep])

    assert tokens1[0] == tokens2[1]
    assert tokens1[1] == tokens2[0]


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
    """Greedy sampling with ``include_gpu_probs_tensor`` enabled.

    ``should_modify_greedy_probs_inplace`` is turned off. The test asserts
    that ``_modify_greedy_probs_inplace`` is not called, and that the
    output's ``sampled_token_probs``, ``logprobs`` and ``sampled_token_ids``
    are all populated.
    """
    set_random_seed(42)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = False

    greedy_params = SamplingParams(temperature=0)

    inplace_spy = Mock()
    target = ("vllm.model_executor.layers.sampler."
              "_modify_greedy_probs_inplace")
    with patch(target, inplace_spy):
        sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                    greedy_params, device)
        inplace_spy.assert_not_called()

    for attr_name in ("sampled_token_probs", "logprobs", "sampled_token_ids"):
        assert getattr(sampler_output, attr_name) is not None