test_sampler.py 24.1 KB
Newer Older
1
import itertools
2
import random
3
from typing import List, Optional, Tuple
4
5
from unittest.mock import patch

Woosuk Kwon's avatar
Woosuk Kwon committed
6
import pytest
7
import torch
8
from transformers import GenerationConfig, GenerationMixin
9
10

from vllm.model_executor.layers.sampler import Sampler
11
from vllm.model_executor.sampling_metadata import SamplingMetadata
12
13
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
14
from vllm.utils import Counter
15
from vllm.worker.model_runner import ModelRunner
16
17
18
19


class MockLogitsSampler(Sampler):
    """``Sampler`` subclass used as the sampler under test in this module.

    It keeps a reference to the fake logits tensor it was constructed with and
    otherwise adds no behaviour: ``forward`` is inherited from ``Sampler``
    unchanged, so sampling always operates on the ``logits`` argument that the
    caller passes in.
    """

    def __init__(self, fake_logits: torch.Tensor):
        super().__init__()
        # Kept for reference only; nothing in this class reads it back.
        self.fake_logits = fake_logits
26
27
28
29


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
    """Create the fixtures shared by the sampler tests.

    All tensors are allocated on the current default torch device.

    Returns:
        ``(input_tensor, fake_logits, sampler, model_runner)`` where
        ``input_tensor`` is a random ``(batch_size, 1024)`` float16 tensor,
        ``fake_logits`` is a ``(batch_size, VOCAB_SIZE)`` tensor of the same
        dtype filled with ``1e-2``, ``sampler`` is a ``MockLogitsSampler``
        holding ``fake_logits`` and ``model_runner`` is a ``ModelRunner``
        built with every config set to ``None``.
    """
    hidden = torch.rand((batch_size, 1024), dtype=torch.float16)
    logits_shape = (batch_size, VOCAB_SIZE)
    logits = torch.full(logits_shape, 1e-2, dtype=hidden.dtype)
    mock_sampler = MockLogitsSampler(logits)
    # No real configs are needed: callers in this module only forward the
    # runner's ``pin_memory`` flag to ``SamplingMetadata.prepare``.
    runner = ModelRunner(
        model_config=None,
        parallel_config=None,
        scheduler_config=None,
        device_config=None,
        load_config=None,
        lora_config=None,
    )
    return hidden, logits, mock_sampler, runner
43
44


45
# Vocabulary width of every fake logits tensor created by these tests.
VOCAB_SIZE = 32000
# Every parametrized test below runs once per seed in this list.
RANDOM_SEEDS = [*range(128)]
# Use a single device when exactly one GPU is visible; otherwise target the
# first two device ordinals.
CUDA_DEVICES = (["cuda:0"]
                if torch.cuda.device_count() == 1 else ["cuda:0", "cuda:1"])
50
51


Nick Hill's avatar
Nick Hill committed
52
53
54
55
56
57
def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    model_runner: ModelRunner,
    sampling_params: SamplingParams,
    device: str,
):
    """Run ``sampler`` once over ``batch_size`` identical prompt groups.

    Group ``i`` has request id ``test_{i}``, is a prompt holding the single
    sequence ``[1, 2, 3]``, and uses the shared ``sampling_params``.
    ``input_tensor`` is passed to the sampler as ``logits`` (the callers in
    this module pass their fake logits here). Returns the sampler's result.
    """
    groups = [
        SequenceGroupMetadata(
            request_id=f"test_{row}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=sampling_params,
            block_tables={0: [1]},
        ) for row in range(batch_size)
    ]
    # The same lengths are passed as both seq_lens and query_lens.
    lengths = [group.seq_data[0].get_len() for group in groups]

    metadata = SamplingMetadata.prepare(
        groups,
        lengths,
        query_lens=lengths,
        device=device,
        pin_memory=model_runner.pin_memory,
    )
    return sampler(logits=input_tensor, sampling_metadata=metadata)
Nick Hill's avatar
Nick Hill committed
80
81
82
83
84
85
86
87
88
89
90
91


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """Greedy sampling (temperature 0) must return the argmax of every row."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    sampling_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
                                sampling_params, device)

    # Without this check the loop below would pass vacuously if the sampler
    # returned fewer sequence-group outputs than it was given.
    assert len(sampler_output) == batch_size

    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()

    del model_runner

101
102

@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Unseeded random sampling must pick the dominant logit of each row.

    Row ``i`` gets a logit of ``1e2`` at column ``i``. The rest of the row
    keeps the ``1e-2`` fill from ``_prepare_test``. Every sample drawn for
    sequence group ``i`` is therefore expected to be token ``i``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
                                sampling_params, device)

    # Guard against a vacuous pass of the loop below on a short output.
    assert len(sampler_output) == batch_size
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Seeded random sampling must still pick the dominant logit of each row.

    Same setup as the unseeded variant (row ``i`` has a ``1e2`` logit at
    column ``i``), but the request carries a per-request ``seed``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
                                sampling_params, device)

    # Guard against a vacuous pass of the loop below on a short output.
    assert len(sampler_output) == batch_size
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner

153

Nick Hill's avatar
Nick Hill committed
154
155
156
157
158
159
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Seeded random sampling over identical inputs must be reproducible.

    The same batch is sampled twice with the same ``SamplingParams`` (which
    carry a per-request seed). The two sampler outputs must compare equal.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    # Draw n before the request seed to keep the original draw order.
    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    sampling_params = SamplingParams(temperature=1.0,
                                     n=num_samples,
                                     seed=request_seed)

    outputs = [
        _do_sample(batch_size, fake_logits, sampler, model_runner,
                   sampling_params, device) for _attempt in range(2)
    ]
    assert outputs[0] == outputs[1]

    del model_runner


178
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    """Smoke test: a batch made only of beam-search requests must sample.

    No output is checked, because there is no simple expected value for beam
    search here. The test passes as long as the sampler raises no exception.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    fake_logits, sampler, model_runner = _prepare_test(batch_size)[1:]

    beam_params = SamplingParams(
        temperature=0,
        best_of=2,
        use_beam_search=True,
    )
    _do_sample(batch_size, fake_logits, sampler, model_runner, beam_params,
               device)

    del model_runner
198
199


200
201
202
203
204
205
206
207
208
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    """Check that ``min_tokens`` masks a row's stop tokens with ``-inf``.

    Seed 0 runs a fixed list of hand-written edge cases. Every other seed
    runs one randomly generated batch. Each case pairs a list of
    ``SequenceGroupMetadata`` with ``expected_penalization``, one bool per
    logits row. For a ``True`` row, every id in the row's
    ``all_stop_token_ids`` (including the EOS id) must be ``-inf`` and
    nothing else may be. For a ``False`` row, no entry may be ``-inf``.
    """
    # Seed *before* drawing the counter start so the sequence ids, and every
    # random draw after them, are reproducible for a given ``seed``.
    set_random_seed(seed)
    seq_id_counter = Counter(start=random.randint(0, 100))
    torch.set_default_device(device)

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        """Return SamplingParams whose stop set also contains eos_token_id."""
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        """Return SequenceData with random prompt and optional output ids."""
        seq_data = SequenceData(
            random.choices(range(0, VOCAB_SIZE), k=num_input))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        """Build one random test case of all-prompt or all-decode groups."""
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data = {}
            seq_group_penalization = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                # A row is penalized while fewer than min_tokens tokens have
                # been generated.
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    # With prompt_logprobs every prompt token gets its own logits row; only
    # the last row is the one that samples the first generated token.
    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*,
                      expected_penalization=None,
                      seq_group_metadata_list=None):
        """Sample one case and check the -inf mask row by row."""
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        seq_lens = []
        sampling_params_per_row = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a row in
                    # logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        assert len(expected_penalization) == batch_size, (
            "Invalid test case, expected_penalization does not match computed "
            "batch size")

        _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else None,
            device=device,
            pin_memory=model_runner.pin_memory)
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids

            if should_penalize:
                for token_id in tokens_to_check:
                    # The message is parenthesized so both halves are part of
                    # the assert message (previously the second half was a
                    # separate, discarded expression statement).
                    assert fake_logits[logits_idx, token_id] == -float(
                        'inf'), (f"Expected token {token_id} for logits row "
                                 f"{logits_idx} to be penalized")
                # no other tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] == -float('inf')) == len(
                        tokens_to_check
                    ), f"Expected only {len(tokens_to_check)} to be penalized"
            else:
                # no tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] ==
                    -float('inf')) == 0, "No tokens should have been penalized"

        del model_runner

    for test_case in test_cases:
        run_test_case(**test_case)


489
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Sample one batch that mixes greedy, random, seeded and beam requests.

    The batch is sampled once, then shuffled (rows of ``fake_logits``
    together with the per-group lists) and sampled again.
    - Greedy rows must match their argmax.
    - Unseeded random rows must pick one of their boosted tokens.
    - Seeded random rows must reproduce the tokens recorded on the first pass.
    - Beam-search rows are not checked.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    seq_group_metadata_list = []
    expected_tokens: List[Optional[List[int]]] = []
    seq_lens = []
    for i in range(batch_size):
        expected: Optional[List[int]] = None
        sampling_type = random.randint(0, 3)
        if sampling_type == 0:
            # Greedy.
            sampling_params = SamplingParams(temperature=0)
            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
        elif sampling_type in (1, 2):
            # 1: unseeded random, 2: seeded random.
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 2:
                # Expected tokens are recorded on the first sampling pass.
                sampling_params.seed = random.randint(0, 10000)
            else:
                # Boost tokens i..i+n-1 so any random draw should hit one.
                for idx in range(n):
                    fake_logits[i, i + idx] = 1e2
                expected = list(range(i, i + n))
        else:
            # Beam search; its outputs are skipped by the checks below.
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        expected_tokens.append(expected)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    def test_sampling(model_runner: ModelRunner):
        """Sample the current batch order and check every non-beam group."""
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=model_runner.pin_memory)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        # zip() below would silently truncate a short output.
        assert len(sampler_output) == len(seq_group_metadata_list)

        for i, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            if metadata.sampling_params.use_beam_search:
                continue

            if (metadata.sampling_params.seed is not None
                    and expected_tokens[i] is None):
                # Record seeded random result to compare with results of
                # second invocation
                expected_tokens[i] = [
                    nth_output.output_token
                    for nth_output in sequence_output.samples
                ]
                continue

            for n, nth_output in enumerate(sequence_output.samples):
                if (metadata.sampling_params.temperature == 0
                        or metadata.sampling_params.seed is not None):
                    # Ensure exact matches for greedy or random with seed
                    assert nth_output.output_token == expected_tokens[i][n]
                else:
                    # For non-seeded random check that one of the high-logit
                    # tokens were chosen
                    assert nth_output.output_token in expected_tokens[i]

    # Test batch
    test_sampling(model_runner)

    # Shuffle the batch and resample. Every list is shuffled with a fresh
    # Random(seed), so all of them receive the same permutation.
    target_index = list(range(batch_size))
    for list_to_shuffle in (target_index, seq_group_metadata_list,
                            expected_tokens, seq_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    target_index = torch.tensor(target_index)
    fake_logits.data = fake_logits.index_select(0, target_index)

    # This time, results of seeded random samples will be compared with
    # the corresponding sample in the pre-shuffled batch
    test_sampling(model_runner)

    del model_runner

590

591
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """Compare vLLM's top-k/top-p probabilities with HF's logits warpers.

    ``vllm.model_executor.layers.sampler._sample`` is patched out so the
    probabilities the sampler would draw from can be captured. They are
    compared with HF's top-k + top-p warpers applied to the same logits,
    followed by a softmax.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, VOCAB_SIZE),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)
    model_runner = ModelRunner(model_config=None,
                               parallel_config=None,
                               scheduler_config=None,
                               device_config=None,
                               load_config=None,
                               lora_config=None)

    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = generation_model._get_logits_warper(generation_config)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list = []
    seq_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=model_runner.pin_memory)

    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        """Capture the probabilities and return a minimal fake result."""
        nonlocal sample_probs
        sample_probs = probs
        return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                 for prob in probs], None)

    # Snapshot the logits before sampling: the sampler may update its logits
    # argument in place, and the HF reference must see the original values.
    reference_logits = fake_logits.clone()
    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
    assert sample_probs is not None, "patched _sample was never called"

    hf_probs = warpers(torch.zeros_like(reference_logits), reference_logits)
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

    del model_runner