test_sampler.py 28 KB
Newer Older
1
import itertools
2
import random
3
from typing import Dict, List, Optional, Tuple
4
from unittest.mock import Mock, patch
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
import pytest
7
import torch
8
from transformers import GenerationConfig, GenerationMixin
9

10
import vllm.envs as envs
11
from vllm.model_executor.layers.sampler import Sampler
12
from vllm.model_executor.sampling_metadata import SamplingMetadata
13
from vllm.model_executor.utils import set_random_seed
14
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
15
from vllm.utils import Counter, is_pin_memory_available
16
17
18
19


class MockLogitsSampler(Sampler):
    """Sampler subclass used by the tests in this module.

    It remembers the logits tensor it was built with and otherwise defers
    entirely to ``Sampler``.
    """

    def __init__(self, fake_logits: torch.Tensor):
        """Initialise the base sampler and keep a reference to the logits."""
        super().__init__()
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        """Forward all arguments unchanged to ``Sampler.forward``."""
        result = super().forward(*args, **kwargs)
        return result
26
27
28


def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
    """Build the tensors and sampler shared by most tests in this module.

    Returns a tuple of:
      * a random float16 tensor of shape ``(batch_size, 1024)``,
      * a float16 logits tensor of shape ``(batch_size, VOCAB_SIZE)`` with
        every entry set to ``1e-2``,
      * a ``MockLogitsSampler`` constructed from that logits tensor.
    """
    half = torch.float16
    hidden = torch.rand((batch_size, 1024), dtype=half)
    # Uniform logits, so a test can single out tokens by raising entries.
    logits = torch.full(size=(batch_size, VOCAB_SIZE),
                        fill_value=1e-2,
                        dtype=hidden.dtype)
    return hidden, logits, MockLogitsSampler(logits)
37
38


39
# Width of the fake logits tensors used throughout the tests.
VOCAB_SIZE = 32000
# Every seeded test is parametrized over seeds 0..127.
RANDOM_SEEDS = [*range(128)]
# One device when exactly one CUDA device is visible, otherwise two.
CUDA_DEVICES = [
    f"cuda:{index}"
    for index in range(2 if torch.cuda.device_count() != 1 else 1)
]
44
45


Nick Hill's avatar
Nick Hill committed
46
47
48
49
50
def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    sampling_params: SamplingParams,
    device: str,
):
    """Run ``sampler`` over a batch of identical single-sequence prompts.

    ``batch_size`` prompt groups are created, each holding one sequence with
    the token ids ``[1, 2, 3]`` and sharing the same ``sampling_params``.
    ``input_tensor`` is handed to the sampler as its ``logits`` argument.

    Returns whatever the sampler call returns.
    """
    groups: List[SequenceGroupMetadata] = [
        SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=True,
            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
            sampling_params=sampling_params,
            block_tables={0: [1]},
        ) for idx in range(batch_size)
    ]
    lengths: List[int] = [group.seq_data[0].get_len() for group in groups]

    # The same list is used for both the sequence and the query lengths.
    metadata = SamplingMetadata.prepare(
        groups,
        lengths,
        query_lens=lengths,
        device=device,
        pin_memory=is_pin_memory_available())
    return sampler(logits=input_tensor, sampling_metadata=metadata)
Nick Hill's avatar
Nick Hill committed
73
74
75
76
77
78
79
80


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """With temperature 0, every sample must equal the argmax of its row."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    greedy_params = SamplingParams(temperature=0)
    outputs = _do_sample(batch_size, fake_logits, sampler, greedy_params,
                         device)

    # The reference is computed after sampling, from the same tensor.
    argmax_tokens = fake_logits.argmax(dim=-1)
    for row, group_output in enumerate(outputs):
        want = argmax_tokens[row].item()
        for sample in group_output.samples:
            assert sample.output_token == want


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Random sampling must pick the one strongly boosted token per row."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Row r gets a large logit at column r, so token r is expected for row r.
    diagonal = torch.arange(batch_size)
    fake_logits[diagonal, diagonal] = 1e2

    num_samples = random.randint(1, 10)
    random_params = SamplingParams(temperature=1.0, n=num_samples)
    outputs = _do_sample(batch_size, fake_logits, sampler, random_params,
                         device)

    for row, group_output in enumerate(outputs):
        for sample in group_output.samples:
            assert sample.output_token == row


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Seeded random sampling must pick the boosted token of each row."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Row r gets a large logit at column r, so token r is expected for row r.
    diagonal = torch.arange(batch_size)
    fake_logits[diagonal, diagonal] = 1e2

    # Drawn in this order so the random stream matches: first n, then seed.
    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    seeded_params = SamplingParams(temperature=1.0,
                                   n=num_samples,
                                   seed=request_seed)
    outputs = _do_sample(batch_size, fake_logits, sampler, seeded_params,
                         device)

    for row, group_output in enumerate(outputs):
        for sample in group_output.samples:
            assert sample.output_token == row


Nick Hill's avatar
Nick Hill committed
139
140
141
142
143
144
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Two runs with the same seeded params must give equal outputs."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Drawn in this order so the random stream matches: first n, then seed.
    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    seeded_params = SamplingParams(temperature=1.0,
                                   n=num_samples,
                                   seed=request_seed)

    runs = []
    for _ in range(2):
        runs.append(
            _do_sample(batch_size, fake_logits, sampler, seeded_params,
                       device))
    first_sampler_output, second_sampler_output = runs

    assert first_sampler_output == second_sampler_output


161
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    """Smoke test: an all-beam-search batch must sample without raising."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    beam_params = SamplingParams(temperature=0,
                                 best_of=2,
                                 use_beam_search=True)
    # Deliberately no assertion on the result: it is unclear how to decide
    # whether beam-search outputs are the expected ones, so this only checks
    # that the sampler handles an all-beam-search batch without an exception.
    _do_sample(batch_size, fake_logits, sampler, beam_params, device)


181
182
183
184
185
186
187
188
189
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    """Check which logits rows get their stop tokens masked by ``min_tokens``.

    Each test case is a dict with two keys:

    * ``seq_group_metadata_list`` - the batch handed to the sampler;
    * ``expected_penalization`` - one bool per logits row saying whether all
      stop tokens (including the EOS token) of that row must be ``-inf``
      after sampling.

    Seed 0 runs a fixed set of hand-written edge cases; every other seed runs
    one randomly generated case.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    # Drawn after seeding, like every other random value in this test, so the
    # starting sequence id does not depend on earlier use of `random`.
    seq_id_counter = Counter(start=random.randint(0, 100))

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        """Build SamplingParams whose stop-token set also contains the EOS."""
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        """Build SequenceData with random prompt and output token ids."""
        seq_data = SequenceData.from_seqs(
            random.choices(range(0, VOCAB_SIZE), k=num_input))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        """Generate one random test case (all prompts or all decodes)."""
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list: List[SequenceGroupMetadata] = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data: Dict[int, SequenceData] = {}
            seq_group_penalization: List[bool] = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                # A row is penalized while fewer than min_tokens tokens
                # have been generated for its sequence.
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    # With prompt_logprobs the prompt contributes one logits row per prompt
    # token; only the final row is expected to be penalized.
    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*, expected_penalization: List[bool],
                      seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Sample one test case and verify the -inf pattern of every row."""
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        seq_lens: List[int] = []
        sampling_params_per_row: List[SamplingParams] = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                assert sgm.sampling_params is not None
                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a row in
                    # logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        assert len(expected_penalization) == batch_size, (
            "Invalid test case, expected_penalization does not match computed "
            "batch size")

        _, fake_logits, sampler = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else None,
            device=device,
            pin_memory=is_pin_memory_available())
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids

            if should_penalize:
                for token_id in tokens_to_check:
                    # The message is parenthesised so both parts are
                    # concatenated into a single assertion message.
                    assert fake_logits[logits_idx, token_id] == -float(
                        'inf'), (
                            f"Expected token {token_id} for logits row "
                            f"{logits_idx} to be penalized")
                # no other tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] == -float('inf')) == len(
                        tokens_to_check
                    ), f"Expected only {len(tokens_to_check)} to be penalized"
            else:
                # no tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] ==
                    -float('inf')) == 0, "No tokens should have been penalized"

    for test_case in test_cases:
        run_test_case(**test_case)


468
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Sample a batch mixing greedy, random, seeded-random and beam requests.

    The batch is sampled once, then shuffled and sampled again; seeded
    requests must reproduce in the second pass the tokens recorded in the
    first one.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    expected_tokens: List[Optional[List[int]]] = []
    seq_lens: List[int] = []
    for row in range(batch_size):
        row_expected: Optional[List[int]] = None
        kind = random.randint(0, 3)
        if kind == 0:
            # Greedy: the expected token is the argmax of the row.
            params = SamplingParams(temperature=0)
            row_expected = [int(torch.argmax(fake_logits[row], dim=-1).item())]
        elif kind == 3:
            # Beam search: nothing is checked for these rows.
            params = SamplingParams(temperature=0,
                                    use_beam_search=True,
                                    best_of=2)
        else:
            # Random (kind 1) or seeded random (kind 2). The draws below stay
            # in this exact order to keep the random stream unchanged.
            n = random.randint(1, 10)
            temperature = random.random() + 0.1
            top_p = min(random.random() + 0.1, 1)
            top_k = random.randint(0, 10) or -1
            presence_penalty = random.randint(0, 1)
            params = SamplingParams(
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                n=n,
                presence_penalty=presence_penalty,
            )
            if kind == 2:
                params.seed = random.randint(0, 10000)
            else:
                # Boost n consecutive tokens; any of them is acceptable.
                fake_logits[row, row:row + n] = 1e2
                row_expected = list(range(row, row + n))

        expected_tokens.append(row_expected)
        group = SequenceGroupMetadata(
            request_id=f"test_{row}",
            is_prompt=True,
            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
            sampling_params=params,
            block_tables={0: [1]},
        )
        seq_group_metadata_list.append(group)
        seq_lens.append(group.seq_data[0].get_len())

    generators: Dict[str, torch.Generator] = {}

    def sample_and_check():
        """Sample the current batch and compare against expected_tokens."""
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available(),
            generators=generators)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        pairs = zip(sampler_output, seq_group_metadata_list)
        for idx, (group_output, group_meta) in enumerate(pairs):
            request_params = group_meta.sampling_params
            assert request_params is not None

            if request_params.use_beam_search:
                continue

            is_seeded = request_params.seed is not None
            if is_seeded and expected_tokens[idx] is None:
                # First pass for a seeded request: remember its tokens so the
                # second pass can be compared against them.
                expected_tokens[idx] = [
                    sample.output_token for sample in group_output.samples
                ]
                continue

            wanted = expected_tokens[idx]
            assert wanted is not None

            for pos, sample in enumerate(group_output.samples):
                assert group_meta.sampling_params is not None

                exact = (group_meta.sampling_params.temperature == 0
                         or group_meta.sampling_params.seed is not None)
                if exact:
                    # Greedy and seeded-random results must match exactly.
                    assert sample.output_token == wanted[pos]
                else:
                    # Unseeded random must land on one of the boosted tokens.
                    assert sample.output_token in wanted

    # First pass over the batch as built.
    sample_and_check()

    # Shuffle the batch and resample. Every list is shuffled in place with a
    # freshly seeded generator, so all of them receive the same permutation.
    target_index = list(range(batch_size))
    for shuffled in (target_index, seq_group_metadata_list, expected_tokens,
                     seq_lens):
        random.Random(seed).shuffle(shuffled)
    target_index = torch.tensor(target_index)
    input_tensor.data = input_tensor.index_select(0, target_index)
    fake_logits.data = fake_logits.index_select(0, target_index)

    # Second pass: seeded random samples are now compared with the tokens
    # recorded for the same request before the shuffle.
    sample_and_check()
Simon Mo's avatar
Simon Mo committed
575

576

577
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """Compare the sampler's top-k/top-p probabilities with the HF warpers.

    The probabilities passed to the patched ``_sample`` are captured and
    compared with the softmax of the logits processed by the Hugging Face
    logits warpers for the same ``top_k`` / ``top_p``.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)

    hf_config = GenerationConfig(top_k=top_k, top_p=top_p, do_sample=True)
    warpers = GenerationMixin()._get_logits_warper(hf_config, device)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list: List[SequenceGroupMetadata] = [
        SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=True,
            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
            sampling_params=SamplingParams(
                temperature=1,
                top_k=top_k,
                top_p=top_p,
            ),
            block_tables={0: [1]},
        ) for idx in range(batch_size)
    ]
    seq_lens: List[int] = [
        group.seq_data[0].get_len() for group in seq_group_metadata_list
    ]

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())

    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        """Capture ``probs`` and return the top-1 token of each row."""
        nonlocal sample_probs
        sample_probs = probs
        picks = []
        for prob in probs:
            picks.append([prob.topk(1, dim=-1).indices.tolist(), [0]])
        return (picks, None)

    # top-k and top-p is only calculated when flashinfer kernel is not available
    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
        with patch(
                "vllm.model_executor.layers.sampler."
                "flashinfer_top_k_top_p_sampling", None):
            sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

    assert sample_probs is not None

    warped_logits = warpers(torch.zeros_like(fake_logits),
                            fake_logits.clone())
    hf_probs = torch.softmax(warped_logits, dim=-1, dtype=torch.float)
    torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
    # The same set of tokens must have been zeroed out on both sides.
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
646
647


648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
    """Sampling with a patched flashinfer kernel must match the normal run.

    Skipped when ``envs.VLLM_USE_FLASHINFER_SAMPLER`` is falsy.
    """
    if not envs.VLLM_USE_FLASHINFER_SAMPLER:
        pytest.skip("Flashinfer sampler is disabled")

    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    def failing_flashinfer_sampling(*_args, **_kwargs):
        # Returns no samples plus an all-zero int32 tensor
        # (presumably the per-row success flags - confirm against sampler).
        flags = torch.zeros(batch_size, device=device, dtype=torch.int32)
        return None, flags

    # Drawn in this order so the random stream matches: first n, then seed.
    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    sampling_params = SamplingParams(temperature=1.0,
                                     n=num_samples,
                                     seed=request_seed)

    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    patch_target = ("vllm.model_executor.layers.sampler."
                    "flashinfer_top_k_top_p_sampling")
    with patch(patch_target, failing_flashinfer_sampling):
        fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                             sampling_params, device)

    assert sampler_output == fallback_sampler_output


679
680
681
682
683
684
685
686
687
688
689
690
691
692
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """A request's token must not depend on its position in the batch.

    A greedy request with a repetition penalty and a seeded top-k=1 request
    are sampled together in both orders; each request must yield the same
    token in either position.
    """
    vocab_size = 8

    def test_sampling_params(sampling_params: List[SamplingParams]):
        """Sample a two-request batch; return the first token of each."""
        seq_group_metadata_list: List[SequenceGroupMetadata] = [
            SequenceGroupMetadata(
                request_id=f"test_{idx}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=sampling_params[idx],
                block_tables={0: [1]},
            ) for idx in range(2)
        ]
        seq_lens: List[int] = [
            group.seq_data[0].get_len() for group in seq_group_metadata_list
        ]

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        fake_logits = torch.full((2, vocab_size),
                                 1e-2,
                                 device=device,
                                 dtype=torch.float16)
        # Tokens 5 and 1 get slightly higher logits; token 1 is the highest.
        fake_logits[:, 5] = 1.1e-2
        fake_logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(fake_logits)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        return [output.samples[0].output_token for output in sampler_output]

    # one configuration is greedy with repetition_penalty
    sampling_params_rep = SamplingParams(temperature=0.0,
                                         repetition_penalty=2.0)

    # other configuration is sampling w/o repetition_penalty
    sampling_params_sample = SamplingParams(temperature=1.0,
                                            top_k=1,
                                            seed=42)

    rep_first = test_sampling_params(
        [sampling_params_rep, sampling_params_sample])
    sample_first = test_sampling_params(
        [sampling_params_sample, sampling_params_rep])

    assert rep_first[0] == sample_first[1]
    assert rep_first[1] == sample_first[0]
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
    """With include_gpu_probs_tensor set, the output must carry its tensors.

    Also checks that ``_modify_greedy_probs_inplace`` is not called when
    ``should_modify_greedy_probs_inplace`` is False.
    """
    set_random_seed(42)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = False

    greedy_params = SamplingParams(temperature=0)

    inplace_spy = Mock()
    patch_target = (
        "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace")
    with patch(patch_target, inplace_spy):
        sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                    greedy_params, device)
        inplace_spy.assert_not_called()

    for tensor in (sampler_output.sampled_token_probs,
                   sampler_output.logprobs,
                   sampler_output.sampled_token_ids):
        assert tensor is not None