test_sampler.py 28.3 KB
Newer Older
1
import itertools
2
import random
3
from array import array
4
from typing import Dict, List, Optional, Tuple
5
from unittest.mock import Mock, patch
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
import pytest
8
import torch
9
from transformers import GenerationConfig, GenerationMixin
10

11
import vllm.envs as envs
12
from vllm.model_executor.layers.sampler import Sampler
13
from vllm.model_executor.sampling_metadata import SamplingMetadata
14
from vllm.model_executor.utils import set_random_seed
15
16
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
                           SequenceData, SequenceGroupMetadata)
17
from vllm.utils import Counter, is_pin_memory_available
18
19
20
21


class MockLogitsSampler(Sampler):
    """``Sampler`` subclass that keeps a handle on the fake logits it wraps.

    The ``fake_logits`` attribute is only stored; ``forward`` adds no
    behaviour and simply delegates to ``Sampler.forward``.
    """

    def __init__(self, fake_logits: torch.Tensor):
        super().__init__()
        # Kept so tests can reach the tensor through the sampler object.
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        # Pure pass-through to the real sampler implementation.
        result = super().forward(*args, **kwargs)
        return result
28
29
30


def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
    """Create random inputs, constant fake logits and a sampler for them.

    Returns ``(input_tensor, fake_logits, sampler)``:
    ``input_tensor`` is a ``(batch_size, 1024)`` ``torch.rand`` tensor,
    ``fake_logits`` is a ``(batch_size, VOCAB_SIZE)`` tensor filled with
    ``1e-2``, both float16 and created on the current default device, and
    ``sampler`` is a ``MockLogitsSampler`` wrapping ``fake_logits``.
    """
    half = torch.float16
    input_tensor = torch.rand((batch_size, 1024), dtype=half)
    fake_logits = torch.full((batch_size, VOCAB_SIZE), 1e-2, dtype=half)
    return input_tensor, fake_logits, MockLogitsSampler(fake_logits)
39
40


41
# Vocabulary width of every fake logits tensor built in this file.
VOCAB_SIZE = 32_000
# Each value becomes its own parametrized test case.
RANDOM_SEEDS = [*range(128)]
# Run on a single GPU when exactly one is visible; otherwise (including
# when none is visible) parametrize over two devices.
CUDA_DEVICES = (["cuda:0"]
                if torch.cuda.device_count() == 1 else ["cuda:0", "cuda:1"])
46
47


Nick Hill's avatar
Nick Hill committed
48
49
50
51
52
def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    sampling_params: SamplingParams,
    device: str,
):
    """Feed ``input_tensor`` as logits for ``batch_size`` identical prompts.

    Every prompt is the token sequence ``[1, 2, 3]`` in its own sequence
    group (``request_id`` ``test_<idx>``) and shares ``sampling_params``.
    Returns whatever ``sampler`` returns for the prepared metadata.
    """
    seq_group_metadata_list: List[SequenceGroupMetadata] = [
        SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=True,
            seq_data={
                0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
            },
            sampling_params=sampling_params,
            block_tables={0: [1]},
        ) for idx in range(batch_size)
    ]
    seq_lens: List[int] = [
        group.seq_data[0].get_len() for group in seq_group_metadata_list
    ]

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available(),
    )
    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
Nick Hill's avatar
Nick Hill committed
77
78
79
80
81
82
83
84


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """Greedy sampling (temperature 0) must return each row's argmax."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    greedy_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                greedy_params, device)

    # Computed after sampling, exactly like the reference comparison.
    argmax_tokens = torch.argmax(fake_logits, dim=-1)
    for row, sequence_output in enumerate(sampler_output):
        want = argmax_tokens[row].item()
        for sample in sequence_output.samples:
            assert sample.output_token == want


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Random sampling must pick the single dominant token of every row."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Give token ``i`` a logit of 1e2 in row ``i`` (all others stay 1e-2).
    diagonal = torch.arange(batch_size)
    fake_logits[diagonal, diagonal] = 1e2

    num_samples = random.randint(1, 10)
    sampling_params = SamplingParams(temperature=1.0, n=num_samples)
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for row, sequence_output in enumerate(sampler_output):
        tokens = [sample.output_token for sample in sequence_output.samples]
        assert tokens == [row] * len(tokens)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Seeded random sampling must still pick each row's dominant token."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Give token ``i`` a logit of 1e2 in row ``i`` (all others stay 1e-2).
    diagonal = torch.arange(batch_size)
    fake_logits[diagonal, diagonal] = 1e2

    # Drawn in this order: number of samples first, then the request seed.
    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    sampling_params = SamplingParams(temperature=1.0,
                                     n=num_samples,
                                     seed=request_seed)
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for row, sequence_output in enumerate(sampler_output):
        tokens = [sample.output_token for sample in sequence_output.samples]
        assert tokens == [row] * len(tokens)


Nick Hill's avatar
Nick Hill committed
143
144
145
146
147
148
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Sampling twice with the same seeded params gives identical output."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    sampling_params = SamplingParams(temperature=1.0,
                                     n=num_samples,
                                     seed=request_seed)

    first, second = (
        _do_sample(batch_size, fake_logits, sampler, sampling_params, device)
        for _ in range(2))

    assert first == second


165
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    """Smoke test: an all-beam-search batch must sample without raising.

    The sampled tokens are not checked (it is unclear what the expected
    outputs would be); the test only verifies that the sampler handles a
    batch in which every request uses beam search.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    beam_params = SamplingParams(
        temperature=0,
        best_of=2,
        use_beam_search=True,
    )
    _do_sample(batch_size, fake_logits, sampler, beam_params, device)


185
186
187
188
189
190
191
192
193
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    """Check that stop tokens are masked to -inf until ``min_tokens``.

    Seed 0 runs a fixed list of hand-written edge cases; every other seed
    runs one randomly generated batch. Each test case is a dict holding the
    ``seq_group_metadata_list`` fed to the sampler and an
    ``expected_penalization`` flag for every row of the logits tensor.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    # Drawn after seeding so that the sequence ids are reproducible per seed.
    seq_id_counter = Counter(start=random.randint(0, 100))

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        """Build SamplingParams whose stop set also contains eos_token_id."""
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        """Random prompt of num_input tokens plus num_generated outputs."""
        seq_data = SequenceData(
            array(VLLM_TOKEN_ID_ARRAY_TYPE,
                  random.choices(range(0, VOCAB_SIZE), k=num_input)))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        """Build a random batch of at most 128 sequences.

        A row is expected to be penalized when its sequence has generated
        fewer tokens than its group's min_tokens.
        """
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list: List[SequenceGroupMetadata] = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data: Dict[int, SequenceData] = {}
            seq_group_penalization: List[bool] = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*, expected_penalization: List[bool],
                      seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Sample one test case and check which logits rows were masked."""
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        seq_lens: List[int] = []
        sampling_params_per_row: List[SamplingParams] = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a row
                    # in logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        # Implicitly concatenated literals: keep the separating space.
        assert len(expected_penalization) == batch_size, (
            "Invalid test case, expected_penalization does not match computed "
            "batch size")

        _, fake_logits, sampler = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else None,
            device=device,
            pin_memory=is_pin_memory_available())
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        neg_inf = -float('inf')
        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids
            num_penalized = torch.count_nonzero(
                fake_logits[logits_idx, :] == neg_inf)

            if should_penalize:
                for token_id in tokens_to_check:
                    # Parenthesized so the whole message belongs to the
                    # assert instead of trailing as a no-op statement.
                    assert fake_logits[logits_idx, token_id] == neg_inf, (
                        f"Expected token {token_id} for logits row "
                        f"{logits_idx} to be penalized")
                # no other tokens should be set to -inf
                assert num_penalized == len(tokens_to_check), (
                    f"Expected only {len(tokens_to_check)} tokens to be "
                    "penalized")
            else:
                # no tokens should be set to -inf
                assert num_penalized == 0, \
                    "No tokens should have been penalized"

    for test_case in test_cases:
        run_test_case(**test_case)


472
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Sample a batch mixing greedy, random, seeded and beam-search rows.

    Each row independently gets one of four sampling types. The batch is
    sampled once, then every per-row list is shuffled with one permutation
    and the batch is sampled again; seeded rows must reproduce the tokens
    recorded for them during the first, unshuffled pass.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    expected_tokens: List[Optional[List[int]]] = []
    seq_lens: List[int] = []
    for row in range(batch_size):
        expected: Optional[List[int]] = None
        sampling_type = random.randint(0, 3)
        if sampling_type == 0:
            # Greedy: expect the argmax of this row.
            sampling_params = SamplingParams(temperature=0)
            expected = [int(torch.argmax(fake_logits[row], dim=-1).item())]
        elif sampling_type == 3:
            # Beam search: outputs are skipped by the checks below.
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        else:
            # Random sampling; type 1 is unseeded, type 2 is seeded. The
            # random draws happen in the same order as the keyword list.
            n = random.randint(1, 10)
            temperature = random.random() + 0.1
            top_p = min(random.random() + 0.1, 1)
            top_k = random.randint(0, 10) or -1
            presence_penalty = random.randint(0, 1)
            sampling_params = SamplingParams(
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                n=n,
                presence_penalty=presence_penalty,
            )
            if sampling_type == 1:
                # Boost tokens row .. row+n-1; any of them is acceptable.
                fake_logits[row, row:row + n] = 1e2
                expected = list(range(row, row + n))
            else:
                # Expected tokens get recorded during the first pass.
                sampling_params.seed = random.randint(0, 10000)
        expected_tokens.append(expected)

        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{row}",
                is_prompt=True,
                seq_data={
                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
                },
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    generators: Dict[str, torch.Generator] = {}

    def test_sampling():
        """Sample the current batch and check every non-beam row."""
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available(),
            generators=generators)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for idx, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            params = metadata.sampling_params
            if params.use_beam_search:
                continue

            sampled = [s.output_token for s in sequence_output.samples]
            if params.seed is not None and expected_tokens[idx] is None:
                # First pass: remember the seeded result for the second.
                expected_tokens[idx] = sampled
                continue

            expected_item = expected_tokens[idx]
            assert expected_item is not None

            if params.temperature == 0 or params.seed is not None:
                # Greedy or seeded random: exact, position-by-position match.
                for n, token in enumerate(sampled):
                    assert token == expected_item[n]
            else:
                # Unseeded random: any of the boosted tokens may be chosen.
                for token in sampled:
                    assert token in expected_item

    # First pass over the batch in its original order.
    test_sampling()

    # A fresh Random(seed) per list yields the same permutation for every
    # equal-length list, so rows stay aligned after shuffling.
    permutation = list(range(batch_size))
    for list_to_shuffle in (permutation, seq_group_metadata_list,
                            expected_tokens, seq_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    perm_index = torch.tensor(permutation)
    input_tensor.data = input_tensor.index_select(0, perm_index)
    fake_logits.data = fake_logits.index_select(0, perm_index)

    # Second pass: seeded rows are compared with their first-pass tokens.
    test_sampling()
Simon Mo's avatar
Simon Mo committed
577

578

579
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """Compare the sampler's top-k/top-p probabilities with HF warpers.

    ``_sample`` is patched to capture the probabilities handed to it, and
    the flashinfer kernel is patched to ``None`` so the non-flashinfer
    top-k/top-p path runs. The captured probabilities must match the
    softmax of HuggingFace's warped logits.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    half = torch.float16
    # Still drawn (it consumes torch RNG before the logits are generated).
    input_tensor = torch.rand((batch_size, 1024), device=device, dtype=half)
    fake_logits = torch.normal(0, 5, size=(batch_size, VOCAB_SIZE),
                               device=input_tensor.device, dtype=half)
    sampler = MockLogitsSampler(fake_logits)

    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = GenerationMixin()._get_logits_warper(generation_config, device)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list: List[SequenceGroupMetadata] = [
        SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=True,
            seq_data={
                0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
            },
            sampling_params=SamplingParams(temperature=1,
                                           top_k=top_k,
                                           top_p=top_p),
            block_tables={0: [1]},
        ) for idx in range(batch_size)
    ]
    seq_lens: List[int] = [
        group.seq_data[0].get_len() for group in seq_group_metadata_list
    ]

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())

    captured = {}

    def mock_sample(probs, *args, **kwargs):
        # Record the probabilities, then fake a top-1 pick per row.
        captured["probs"] = probs
        picks = [[prob.topk(1, dim=-1).indices.tolist(), [0]]
                 for prob in probs]
        return (picks, None)

    # top-k and top-p is only calculated when flashinfer kernel is not available
    with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
         patch("vllm.model_executor.layers.sampler."
               "flashinfer_top_k_top_p_sampling", None):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

    sample_probs = captured.get("probs")
    assert sample_probs is not None

    hf_logits = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_logits, dim=-1, dtype=torch.float)
    torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
650
651


652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
    """A patched, "failing" flashinfer kernel must not change the output.

    The patched ``flashinfer_top_k_top_p_sampling`` returns ``None`` plus an
    all-zero int32 tensor; judging by its name this simulates every row
    failing (NOTE(review): confirm against the sampler's fallback logic).
    Seeded outputs with and without the patch must be equal.
    """
    if not envs.VLLM_USE_FLASHINFER_SAMPLER:
        pytest.skip("Flashinfer sampler is disabled")

    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    def failing_flashinfer_sampling(*_args, **_kwargs):
        success = torch.zeros(batch_size, device=device, dtype=torch.int32)
        return None, success

    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    sampling_params = SamplingParams(temperature=1.0,
                                     n=num_samples,
                                     seed=request_seed)

    def run_sampler():
        return _do_sample(batch_size, fake_logits, sampler, sampling_params,
                          device)

    sampler_output = run_sampler()
    with patch(
            "vllm.model_executor.layers.sampler."
            "flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
        fallback_sampler_output = run_sampler()

    assert sampler_output == fallback_sampler_output


683
684
685
686
687
688
689
690
691
692
693
694
695
696
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """A per-request repetition penalty must follow its request, not its row.

    A greedy request with ``repetition_penalty=2.0`` and a seeded top-1
    request are sampled as a two-row batch, once in each order; each
    request must get the same token in both orders.
    """
    vocab_size = 8

    def test_sampling_params(sampling_params: List[SamplingParams]):
        """Sample one token per row for ``sampling_params[0]`` and ``[1]``."""
        seq_group_metadata_list: List[SequenceGroupMetadata] = [
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={
                    0:
                    SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
                },
                sampling_params=sampling_params[i],
                block_tables={0: [1]},
            ) for i in range(2)
        ]
        seq_lens: List[int] = [
            group.seq_data[0].get_len() for group in seq_group_metadata_list
        ]

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        fake_logits = torch.full((2, vocab_size),
                                 1e-2,
                                 device=device,
                                 dtype=torch.float16)
        # Token 1 (which also appears in the prompt) gets the highest logit,
        # token 5 the second highest.
        fake_logits[:, 5] = 1.1e-2
        fake_logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(fake_logits)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        return [output.samples[0].output_token for output in sampler_output]

    # one configuration is greedy with repetition_penalty
    sampling_params_rep = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # other configuration is sampling w/o repetition_penalty
    sampling_params_sample = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    tokens1 = test_sampling_params(
        [sampling_params_rep, sampling_params_sample])
    tokens2 = test_sampling_params(
        [sampling_params_sample, sampling_params_rep])

    # Swapping the rows must swap the results.
    assert tokens1[0] == tokens2[1]
    assert tokens1[1] == tokens2[0]
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
    """Greedy output carries probs, logprobs and ids when requested.

    With ``include_gpu_probs_tensor`` set and
    ``should_modify_greedy_probs_inplace`` cleared, the patched
    ``_modify_greedy_probs_inplace`` must never be called and the output
    must expose ``sampled_token_probs``, ``logprobs`` and
    ``sampled_token_ids``.
    """
    set_random_seed(42)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = False

    greedy_params = SamplingParams(temperature=0)

    mock_inplace = Mock()
    target = "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace"
    with patch(target, mock_inplace):
        sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                    greedy_params, device)
        mock_inplace.assert_not_called()

    for attr in ("sampled_token_probs", "logprobs", "sampled_token_ids"):
        assert getattr(sampler_output, attr) is not None