test_sampler.py 28.6 KB
Newer Older
1
import itertools
2
import random
3
from dataclasses import dataclass
4
from typing import Dict, List, Optional, Tuple
5
from unittest.mock import Mock, patch
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
import pytest
8
import torch
9
from transformers import GenerationConfig, GenerationMixin
10

11
import vllm.envs as envs
12
from vllm.model_executor.layers.sampler import Sampler
13
from vllm.model_executor.sampling_metadata import SamplingMetadata
14
from vllm.model_executor.utils import set_random_seed
15
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
16
from vllm.utils import Counter, is_pin_memory_available
17
18
19
20


class MockLogitsSampler(Sampler):
    """``Sampler`` subclass used as the sampler under test in this file.

    It remembers the logits tensor the test built (``fake_logits``);
    sampling itself is delegated unchanged to ``Sampler.forward``.
    """

    def __init__(self, fake_logits: torch.Tensor):
        super().__init__()
        # Kept only as a reference; nothing in this class reads it back.
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        # Pure pass-through to the real sampler implementation.
        parent_forward = super().forward
        return parent_forward(*args, **kwargs)


def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
    """Build the shared fixtures for a sampler test.

    Returns, in order:
      * a random float16 tensor of shape ``(batch_size, 1024)``,
      * a ``(batch_size, VOCAB_SIZE)`` logits tensor filled with ``1e-2``
        in that same dtype,
      * a ``MockLogitsSampler`` wrapping the logits tensor.

    No device is passed, so both tensors are created on torch's current
    default device (the tests set it via ``torch.set_default_device``).
    """
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    logits_shape = (batch_size, VOCAB_SIZE)
    fake_logits = torch.full(logits_shape, 1e-2, dtype=input_tensor.dtype)
    return input_tensor, fake_logits, MockLogitsSampler(fake_logits)


VOCAB_SIZE = 32_000
RANDOM_SEEDS = [*range(128)]
# Parametrize over a single GPU when exactly one is visible; in every other
# case parametrize over the first two device indices.
CUDA_DEVICES = (["cuda:0"]
                if torch.cuda.device_count() == 1 else ["cuda:0", "cuda:1"])


def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    sampling_params: SamplingParams,
    device: str,
):
    """Run ``sampler`` on ``input_tensor`` for a batch of identical prompts.

    Builds ``batch_size`` prompt sequence groups (ids ``test_0`` ...), each
    holding one sequence with token ids ``[1, 2, 3]`` and the shared
    ``sampling_params``. It then prepares ``SamplingMetadata`` on ``device``
    and returns whatever the sampler returns when ``input_tensor`` is passed
    as its ``logits``.
    """
    seq_group_metadata_list: List[SequenceGroupMetadata] = [
        SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=True,
            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
            sampling_params=sampling_params,
            block_tables={0: [1]},
        ) for idx in range(batch_size)
    ]
    seq_lens: List[int] = [
        group.seq_data[0].get_len() for group in seq_group_metadata_list
    ]

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())
    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """With temperature 0, every sample of a row must be that row's argmax."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    greedy_params = SamplingParams(temperature=0)
    outputs = _do_sample(batch_size, fake_logits, sampler, greedy_params,
                         device)

    # The reference argmax is read from the logits after sampling has run.
    argmax_ids = torch.argmax(fake_logits, dim=-1).tolist()
    for row_idx, row_output in enumerate(outputs):
        assert all(sample.output_token == argmax_ids[row_idx]
                   for sample in row_output.samples)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Unseeded random sampling must pick the single dominant token.

    Row ``i`` gets a logit of ``1e2`` at token ``i`` while every other entry
    stays at ``1e-2``; each of the ``n`` samples for row ``i`` must be ``i``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Boost the (i, i) entries for every row in one indexed assignment.
    diagonal = torch.arange(batch_size)
    fake_logits[diagonal, diagonal] = 1e2

    num_samples = random.randint(1, 10)
    params = SamplingParams(temperature=1.0, n=num_samples)
    outputs = _do_sample(batch_size, fake_logits, sampler, params, device)

    for row_idx, row_output in enumerate(outputs):
        for sample in row_output.samples:
            assert sample.output_token == row_idx


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Seeded random sampling must also pick the single dominant token.

    Same setup as the unseeded variant (token ``i`` dominates row ``i``),
    but the request carries a per-request ``seed``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # The diagonal of a (batch_size, VOCAB_SIZE) tensor is exactly the
    # (i, i) entries for i < batch_size; fill it in place through the view.
    fake_logits.diagonal().fill_(1e2)

    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    params = SamplingParams(temperature=1.0,
                            n=num_samples,
                            seed=request_seed)
    outputs = _do_sample(batch_size, fake_logits, sampler, params, device)

    for row_idx, row_output in enumerate(outputs):
        assert all(sample.output_token == row_idx
                   for sample in row_output.samples)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Two runs with the same seeded SamplingParams must produce equal output."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    params = SamplingParams(temperature=1.0,
                            n=num_samples,
                            seed=request_seed)

    # Sample the identical batch twice, one run after the other.
    first_run, second_run = (
        _do_sample(batch_size, fake_logits, sampler, params, device)
        for _ in range(2))
    assert first_run == second_run


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    """Smoke test: a batch where every row uses beam search must not raise.

    There is intentionally no assertion on the output, because it is not
    clear how to derive the expected beam-search tokens here. The test only
    checks that the sampler handles an all-beam-search batch without
    exceptions.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    beam_params = SamplingParams(temperature=0,
                                 best_of=2,
                                 use_beam_search=True)
    _do_sample(batch_size, fake_logits, sampler, beam_params, device)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    """Stop/EOS token ids must be masked to -inf until min_tokens is met.

    Each test case lists, per logits row, whether that row is expected to
    have every id in its ``all_stop_token_ids`` set to -inf (and no other
    entry), or no -inf entries at all. For ``seed == 0`` a fixed list of
    hand-written edge cases is run; for any other seed a single randomly
    generated batch is checked.
    """
    # Seed BEFORE drawing the counter's start value so that the sequence ids
    # (like every other random choice below) are reproducible per ``seed``.
    set_random_seed(seed)
    seq_id_counter = Counter(start=random.randint(0, 100))
    torch.set_default_device(device)

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        # The EOS id is not passed to SamplingParams, so add it to the set
        # of ids that run_test_case later expects to be penalized.
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        seq_data = SequenceData.from_seqs(
            random.choices(range(0, VOCAB_SIZE), k=num_input))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list: List[SequenceGroupMetadata] = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data: Dict[int, SequenceData] = {}
            seq_group_penalization: List[bool] = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                # a row is penalized while fewer than min_tokens were output
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*, expected_penalization: List[bool],
                      seq_group_metadata_list: List[SequenceGroupMetadata]):
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        seq_lens: List[int] = []
        sampling_params_per_row: List[SamplingParams] = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                assert sgm.sampling_params is not None
                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a row in
                    # logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        assert len(expected_penalization) == batch_size, (
            "Invalid test case, expected_penalization does not match "
            "computed batch size")

        _, fake_logits, sampler = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else [1] * batch_size,
            device=device,
            pin_memory=is_pin_memory_available())
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids

            if should_penalize:
                for token_id in tokens_to_check:
                    # The message is parenthesized so both halves belong to
                    # the assert; previously the second half was a separate
                    # no-op statement and never appeared in failures.
                    penalized_logit = fake_logits[logits_idx, token_id]
                    assert penalized_logit == -float('inf'), (
                        f"Expected token {token_id} for logits row "
                        f"{logits_idx} to be penalized")
                # no other tokens should be set to -inf
                num_penalized = torch.count_nonzero(
                    fake_logits[logits_idx, :] == -float('inf'))
                assert num_penalized == len(tokens_to_check), (
                    f"Expected only {len(tokens_to_check)} tokens to be "
                    "penalized")
            else:
                # no tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] ==
                    -float('inf')) == 0, "No tokens should have been penalized"

    for test_case in test_cases:
        run_test_case(**test_case)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Sample a batch whose rows mix four sampling modes, twice.

    Each row is randomly greedy, unseeded random, seeded random or beam
    search. The batch is sampled once. Then every per-row list and the
    logits rows are permuted with the same seeded shuffle, and the batch is
    sampled again. On both passes:
      * greedy rows must return their argmax;
      * unseeded random rows must return one of their boosted tokens.
    Seeded rows must reproduce, on the second pass, exactly the tokens they
    produced on the first. Beam-search rows are not checked.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    def build_row(row: int) -> Tuple[SamplingParams, Optional[List[int]]]:
        """Draw a sampling mode for ``row``; return params and expected ids.

        Modes: 0 greedy, 1 unseeded random, 2 seeded random, 3 beam search.
        ``None`` means "record on the first pass" (seeded) or "not checked"
        (beam search).
        """
        mode = random.randint(0, 3)
        if mode == 0:
            argmax_id = int(torch.argmax(fake_logits[row], dim=-1).item())
            return SamplingParams(temperature=0), [argmax_id]
        if mode == 3:
            beam_params = SamplingParams(temperature=0,
                                         use_beam_search=True,
                                         best_of=2)
            return beam_params, None

        num_samples = random.randint(1, 10)
        params = SamplingParams(
            temperature=random.random() + 0.1,
            top_p=min(random.random() + 0.1, 1),
            top_k=random.randint(0, 10) or -1,
            n=num_samples,
            presence_penalty=random.randint(0, 1),
        )
        if mode == 2:
            params.seed = random.randint(0, 10000)
            return params, None

        # Unseeded: give `num_samples` consecutive tokens a dominant logit so
        # any of them is an acceptable draw.
        boosted = list(range(row, row + num_samples))
        for token_id in boosted:
            fake_logits[row, token_id] = 1e2
        return params, boosted

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    expected_tokens: List[Optional[List[int]]] = []
    seq_lens: List[int] = []
    for row in range(batch_size):
        row_params, row_expected = build_row(row)
        expected_tokens.append(row_expected)
        group = SequenceGroupMetadata(
            request_id=f"test_{row}",
            is_prompt=True,
            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
            sampling_params=row_params,
            block_tables={0: [1]},
        )
        seq_group_metadata_list.append(group)
        seq_lens.append(group.seq_data[0].get_len())

    # Shared by both passes; handed to SamplingMetadata.prepare each time.
    generators: Dict[str, torch.Generator] = {}

    def run_and_check():
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available(),
            generators=generators)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for row, (row_output, group) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            params = group.sampling_params
            assert params is not None
            if params.use_beam_search:
                # Beam-search rows only need to run without raising.
                continue

            wanted = expected_tokens[row]
            if params.seed is not None and wanted is None:
                # First pass for a seeded row: remember what it produced so
                # the post-shuffle pass can be compared against it.
                expected_tokens[row] = [
                    sample.output_token for sample in row_output.samples
                ]
                continue

            assert wanted is not None
            must_match_exactly = (params.temperature == 0
                                  or params.seed is not None)
            for sample_idx, sample in enumerate(row_output.samples):
                if must_match_exactly:
                    # Greedy and seeded rows must reproduce token-for-token.
                    assert sample.output_token == wanted[sample_idx]
                else:
                    # Unseeded rows only need to pick one of the boosted
                    # high-logit tokens.
                    assert sample.output_token in wanted

    # First pass over the batch in its original order.
    run_and_check()

    # Apply one permutation to every per-row structure: each list is the
    # same length and is shuffled by a fresh Random(seed).
    permutation = list(range(batch_size))
    for per_row in (permutation, seq_group_metadata_list, expected_tokens,
                    seq_lens):
        random.Random(seed).shuffle(per_row)
    row_index = torch.tensor(permutation)
    input_tensor.data = input_tensor.index_select(0, row_index)
    fake_logits.data = fake_logits.index_select(0, row_index)

    # Second pass: seeded rows are now compared against the tokens their
    # pre-shuffle counterparts produced.
    run_and_check()

577

@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """Compare vLLM's top-k/top-p filtered probs with Hugging Face's warpers.

    ``_sample`` is patched to capture the probability tensor the sampler
    would draw from. ``flashinfer_top_k_top_p_sampling`` is patched to
    ``None`` so the non-kernel top-k/top-p path runs. The captured
    probabilities must match a softmax over the HF-processed logits, within
    ``atol=1e-5``, and must be exactly zero in the same positions.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    # Reuse the module-wide vocabulary size rather than repeating 32000.
    vocab_size = VOCAB_SIZE
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    # Snapshot the logits BEFORE sampling: the sampler can modify its logits
    # in place, and the HF reference must see the original values.
    reference_logits = fake_logits.clone()
    sampler = MockLogitsSampler(fake_logits)

    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)

    @dataclass
    class MockConfig:
        is_encoder_decoder: bool = False

    generation_model.config = MockConfig()  # needed by the following method
    generation_model._prepare_special_tokens(generation_config, device=device)
    processors = generation_model._get_logits_processor(generation_config,
                                                        None,
                                                        None,
                                                        None, [],
                                                        device=device)
    assert len(processors) == 2  # top_p and top_k

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    seq_lens: List[int] = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())

    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        # Capture the filtered probabilities and return a minimal result
        # (top-1 index per row) so the sampler can finish its forward pass.
        nonlocal sample_probs
        sample_probs = probs
        return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                 for prob in probs], None)

    # top-k and top-p are only applied on this path when the flashinfer
    # kernel is not available, so patch the kernel out.
    with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
         patch("vllm.model_executor.layers.sampler."
               "flashinfer_top_k_top_p_sampling", None):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

    assert sample_probs is not None

    hf_probs = processors(torch.zeros_like(reference_logits),
                          reference_logits)
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
    """Output must be unchanged when the FlashInfer kernel is stubbed out.

    The stub returns ``None`` plus an all-zero int32 tensor of length
    ``batch_size``. Presumably the sampler treats the zeros as per-row
    failure flags and falls back (TODO confirm against the sampler). The
    stubbed run must equal a normal run with the same seeded parameters.
    """
    if not envs.VLLM_USE_FLASHINFER_SAMPLER:
        pytest.skip("Flashinfer sampler is disabled")

    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    def failing_flashinfer_sampling(*_args, **_kwargs):
        return None, torch.zeros(batch_size, device=device, dtype=torch.int32)

    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    params = SamplingParams(temperature=1.0,
                            n=num_samples,
                            seed=request_seed)

    reference = _do_sample(batch_size, fake_logits, sampler, params, device)

    kernel_path = ("vllm.model_executor.layers.sampler."
                   "flashinfer_top_k_top_p_sampling")
    with patch(kernel_path, failing_flashinfer_sampling):
        fallback = _do_sample(batch_size, fake_logits, sampler, params,
                              device)

    assert reference == fallback


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """A row's sampled token must not depend on its position in the batch.

    One row is greedy with ``repetition_penalty=2.0``; the other is seeded
    sampling with ``top_k=1`` and no repetition penalty. The batch is run in
    both orders, and the first token of each configuration must be the same
    in both runs.
    """
    vocab_size = 8

    def sample_first_tokens(sampling_params: List[SamplingParams]):
        # Two single-sequence prompt groups (token ids [1, 2, 3]), one per
        # entry of ``sampling_params``.
        groups: List[SequenceGroupMetadata] = [
            SequenceGroupMetadata(
                request_id=f"test_{row}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=sampling_params[row],
                block_tables={0: [1]},
            ) for row in range(2)
        ]
        lens: List[int] = [group.seq_data[0].get_len() for group in groups]

        sampling_metadata = SamplingMetadata.prepare(
            groups,
            lens,
            query_lens=lens,
            device=device,
            pin_memory=is_pin_memory_available())

        logits = torch.full((2, vocab_size),
                            1e-2,
                            device=device,
                            dtype=torch.float16)
        # Token 1 gets the largest logit and token 5 the second largest.
        logits[:, 5] = 1.1e-2
        logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(logits)
        outputs = sampler(logits=logits, sampling_metadata=sampling_metadata)
        return [row_output.samples[0].output_token for row_output in outputs]

    # greedy, with repetition_penalty
    rep_params = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )
    # random sampling (seeded, top_k=1), without repetition_penalty
    sample_params = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    rep_first = sample_first_tokens([rep_params, sample_params])
    sample_first = sample_first_tokens([sample_params, rep_params])

    # Swapping the rows must swap the tokens.
    assert (rep_first[0], rep_first[1]) == (sample_first[1], sample_first[0])


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
    """Check the GPU probs tensors and the in-place greedy-probs helper.

    With ``include_gpu_probs_tensor`` enabled, the output must carry
    non-None ``sampled_token_probs``, ``logprobs`` and ``sampled_token_ids``.
    With ``should_modify_greedy_probs_inplace`` disabled,
    ``_modify_greedy_probs_inplace`` must never be called.
    """
    set_random_seed(42)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = False

    greedy_params = SamplingParams(temperature=0)

    inplace_spy = Mock()
    helper_path = (
        "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace")
    with patch(helper_path, inplace_spy):
        output = _do_sample(batch_size, fake_logits, sampler, greedy_params,
                            device)
        inplace_spy.assert_not_called()

    for field_name in ("sampled_token_probs", "logprobs",
                       "sampled_token_ids"):
        assert getattr(output, field_name) is not None