test_sampler.py 23.5 KB
Newer Older
1
import itertools
2
import random
3
from typing import List, Optional, Tuple
4
5
from unittest.mock import patch

Woosuk Kwon's avatar
Woosuk Kwon committed
6
import pytest
7
import torch
8
from transformers import GenerationConfig, GenerationMixin
9
10
11
12

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
13
from vllm.utils import Counter
14
from vllm.worker.model_runner import ModelRunner
15
16
17
18


class MockLogitsSampler(Sampler):
    """Sampler test double that keeps a handle to a fixed logits tensor.

    All sampling work is delegated unchanged to :class:`Sampler`; the
    stored ``fake_logits`` tensor is what the tests pass in via ``logits=``.
    """

    def __init__(self, fake_logits: torch.Tensor):
        super().__init__()
        # Only stored for the tests' convenience; ``forward`` never reads it.
        self.fake_logits = fake_logits

    def forward(self, *forward_args, **forward_kwargs):
        """Pass every argument straight through to ``Sampler.forward``."""
        return super().forward(*forward_args, **forward_kwargs)
25
26
27
28


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
    """Create the fixtures shared by the sampler tests.

    Returns ``(input_tensor, fake_logits, sampler, model_runner)``:

    * a random ``(batch_size, 1024)`` float16 tensor,
    * a ``(batch_size, VOCAB_SIZE)`` tensor filled with ``1e-2`` in the same
      dtype,
    * a :class:`MockLogitsSampler` wrapping those logits,
    * a ``ModelRunner`` constructed with all-``None`` arguments.

    No device is passed, so tensors land on torch's current default device.
    """
    hidden = torch.rand((batch_size, 1024), dtype=torch.float16)
    logits = torch.full((batch_size, VOCAB_SIZE), 1e-2, dtype=hidden.dtype)
    return (
        hidden,
        logits,
        MockLogitsSampler(logits),
        ModelRunner(None, None, None, None, None),
    )
37
38


39
# Vocabulary size used for the fake logits and for random token ids.
VOCAB_SIZE = 32000
# Seeds every parametrized test runs with.
RANDOM_SEEDS = list(range(128))
# Exactly one visible GPU -> ["cuda:0"]; any other count -> cuda:0 and cuda:1.
# NOTE(review): with no GPU at all this still lists two CUDA devices.
CUDA_DEVICES = [
    f"cuda:{i}"
    for i in range(2 if torch.cuda.device_count() != 1 else 1)
]
44
45


Nick Hill's avatar
Nick Hill committed
46
47
48
49
50
51
52
def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    model_runner: ModelRunner,
    sampling_params: SamplingParams,
):
    """Run ``sampler`` over ``batch_size`` identical single-sequence prompts.

    Each request (ids ``test_0`` ... ``test_{batch_size - 1}``) holds
    sequence 0 with prompt token ids ``[1, 2, 3]``, block table ``{0: [1]}``
    and the shared ``sampling_params``. Sampling metadata comes from
    ``model_runner._prepare_sample`` and the sampler's output is returned.

    Callers pass their logits tensor as ``input_tensor``; it is handed to
    the sampler as ``logits``.
    """
    metadata_list = [
        SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=sampling_params,
            block_tables={0: [1]},
        ) for idx in range(batch_size)
    ]
    lengths = [group.seq_data[0].get_len() for group in metadata_list]

    sampling_metadata = model_runner._prepare_sample(metadata_list,
                                                     lengths,
                                                     subquery_lens=lengths)
    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
Nick Hill's avatar
Nick Hill committed
70
71
72
73
74
75
76
77
78
79
80
81


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """At temperature 0 every sample must equal its row's argmax token."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    greedy_params = SamplingParams(temperature=0)
    outputs = _do_sample(batch_size, fake_logits, sampler, model_runner,
                         greedy_params)

    # Computed after sampling, exactly like the reference expectation.
    argmax_tokens = torch.argmax(fake_logits, dim=-1)
    for row, seq_output in enumerate(outputs):
        for sample in seq_output.samples:
            assert sample.output_token == argmax_tokens[row].item()

    del model_runner

91
92

@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Unseeded random sampling picks each row's dominant token.

    Token ``i`` of row ``i`` is raised to ``1e2`` while the rest of the row
    stays at ``1e-2``. Every one of the ``n`` samples drawn at temperature
    1.0 for row ``i`` is therefore expected to be token ``i``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    # Sets entry (i, i) for every row i (the logits have more columns than
    # rows, so the main diagonal covers each row exactly once).
    fake_logits.fill_diagonal_(1e2)

    random_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    outputs = _do_sample(batch_size, fake_logits, sampler, model_runner,
                         random_params)

    for row, seq_output in enumerate(outputs):
        sampled = [sample.output_token for sample in seq_output.samples]
        assert sampled == [row] * len(sampled)

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Seeded random sampling also picks each row's dominant token.

    Same setup as the unseeded test (token ``i`` of row ``i`` at ``1e2``),
    but every request carries a per-request ``seed``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    # Sets entry (i, i) for every row i.
    fake_logits.fill_diagonal_(1e2)

    seeded_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    outputs = _do_sample(batch_size, fake_logits, sampler, model_runner,
                         seeded_params)

    for row, seq_output in enumerate(outputs):
        sampled = [sample.output_token for sample in seq_output.samples]
        assert sampled == [row] * len(sampled)

    del model_runner

143

Nick Hill's avatar
Nick Hill committed
144
145
146
147
148
149
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Two runs with the same per-request seed must give equal outputs."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    seeded_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    runs = [
        _do_sample(batch_size, fake_logits, sampler, model_runner,
                   seeded_params) for _ in range(2)
    ]
    assert runs[0] == runs[1]

    del model_runner


168
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    """Smoke test: a batch where every request uses beam search.

    Only checks that the sampler handles the all-beam-search batch without
    raising; the produced tokens are not asserted on.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    beam_params = SamplingParams(
        temperature=0,
        best_of=2,
        use_beam_search=True,
    )
    _do_sample(batch_size, fake_logits, sampler, model_runner, beam_params)
    # TODO: assert on the beam-search outputs once it is clear how to
    # determine the expected results.
    del model_runner
187
188


189
190
191
192
193
194
195
196
197
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    """Check that ``min_tokens`` masks EOS/stop tokens until it is reached.

    For every logits row whose sequence has generated fewer than
    ``min_tokens`` tokens, the sampler must set the EOS token and every stop
    token to ``-inf`` and leave all other tokens alone. Rows that are not
    penalized must contain no ``-inf`` at all. Seed 0 runs a fixed set of
    hand-written edge cases; every other seed runs one randomly generated
    batch.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    # Draw the counter start only *after* seeding. Drawing it before
    # ``set_random_seed`` made the sequence ids (and thus the test case)
    # depend on the random state left behind by earlier tests.
    seq_id_counter = Counter(start=random.randint(0, 100))

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        """Build ``SamplingParams`` with ``min_tokens`` and an EOS id."""
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        # Assigned on the params object after construction.
        sampling_params.eos_token_id = eos_token_id
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        """Random prompt of ``num_input`` tokens (+ ``num_generated`` outputs)."""
        seq_data = SequenceData(
            random.choices(range(0, VOCAB_SIZE), k=num_input))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        """Randomly build seq groups plus the expected per-row penalization."""
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data = {}
            seq_group_penalization = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                # penalized while fewer than min_tokens outputs exist
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*,
                      expected_penalization=None,
                      seq_group_metadata_list=None):
        """Sample one test case and check the ``-inf`` pattern of each row."""
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        prompt_lens = []
        sampling_params_per_row = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                prompt_lens.append(prompt_len)

                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs every prompt token gets its own
                    # row in logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        assert len(expected_penalization) == batch_size, (
            "Invalid test case, expected_penalization does not match "
            "computed batch size")

        _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
        sampling_metadata = model_runner._prepare_sample(
            seq_group_metadata_list,
            prompt_lens=prompt_lens if prompt_lens else None,
            subquery_lens=prompt_lens if prompt_lens else None)
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        neg_inf = -float('inf')
        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = [sampling_params.eos_token_id]
            if sampling_params.stop_token_ids:
                tokens_to_check.extend(sampling_params.stop_token_ids)
            tokens_to_check = set(tokens_to_check)

            num_neg_inf = torch.count_nonzero(
                fake_logits[logits_idx, :] == neg_inf)
            if should_penalize:
                for token_id in tokens_to_check:
                    # The whole message is parenthesized so both halves are
                    # part of the assertion message.
                    assert fake_logits[logits_idx, token_id] == neg_inf, (
                        f"Expected token {token_id} for logits row "
                        f"{logits_idx} to be penalized")
                # no other tokens should be set to -inf
                assert num_neg_inf == len(tokens_to_check), (
                    f"Expected only {len(tokens_to_check)} tokens to be "
                    "penalized")
            else:
                # no tokens should be set to -inf
                assert num_neg_inf == 0, "No tokens should have been penalized"

        del model_runner

    for test_case in test_cases:
        run_test_case(**test_case)


479
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Sample a batch mixing greedy, random, seeded-random and beam requests.

    Each row picks a request kind via ``random.randint(0, 3)``:

    * 0 - greedy; the expected token is the row's argmax.
    * 1 - unseeded random; tokens ``row .. row + n - 1`` are boosted to
      ``1e2`` and every sample must be one of them.
    * 2 - seeded random; the first pass records the tokens and the second
      pass must reproduce them exactly.
    * 3 - beam search; not checked.

    The batch is sampled, shuffled (metadata, expectations, prompt lengths
    and logits rows with the same permutation), and sampled again.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    expected_tokens: List[Optional[List[int]]] = []
    prompt_lens = []
    for row in range(batch_size):
        row_expected: Optional[List[int]] = None
        sampling_type = random.randint(0, 3)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
            row_expected = [torch.argmax(fake_logits[row], dim=-1).item()]
        elif sampling_type == 3:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        else:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 1:
                # Make tokens row .. row + n - 1 overwhelmingly likely.
                boosted = list(range(row, row + n))
                for token_id in boosted:
                    fake_logits[row, token_id] = 1e2
                row_expected = boosted
            else:
                # Expected tokens are recorded during the first pass.
                sampling_params.seed = random.randint(0, 10000)
        expected_tokens.append(row_expected)

        group = SequenceGroupMetadata(
            request_id=f"test_{row}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=sampling_params,
            block_tables={0: [1]},
        )
        seq_group_metadata_list.append(group)
        prompt_lens.append(group.seq_data[0].get_len())

    def sample_and_check(model_runner: ModelRunner):
        """Sample the current batch and verify every non-beam row."""
        sampling_metadata = model_runner._prepare_sample(
            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for idx, (seq_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            params = metadata.sampling_params
            if params.use_beam_search:
                continue

            seeded = params.seed is not None
            tokens = [sample.output_token for sample in seq_output.samples]
            if seeded and expected_tokens[idx] is None:
                # First pass: remember the seeded result for the second one.
                expected_tokens[idx] = tokens
                continue

            if params.temperature == 0 or seeded:
                # Greedy or seeded random: exact, position-wise match.
                for n, token in enumerate(tokens):
                    assert token == expected_tokens[idx][n]
            else:
                # Unseeded random: any of the boosted tokens is acceptable.
                for token in tokens:
                    assert token in expected_tokens[idx]

    # First pass, batch in its original order.
    sample_and_check(model_runner)

    # Shuffle every per-row structure in place with the same permutation:
    # a fresh Random(seed) per list gives identical shuffles for equal
    # lengths. In-place updates matter because sample_and_check closes over
    # these objects.
    permutation = list(range(batch_size))
    for per_row in (permutation, seq_group_metadata_list, expected_tokens,
                    prompt_lens):
        random.Random(seed).shuffle(per_row)
    row_order = torch.tensor(permutation)
    input_tensor.data = input_tensor.index_select(0, row_order)
    fake_logits.data = fake_logits.index_select(0, row_order)

    # Second pass: seeded rows are compared against the tokens they
    # produced before the shuffle.
    sample_and_check(model_runner)

    del model_runner

576

577
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """Compare vLLM's top-k/top-p filtered probabilities with HF's warpers.

    ``vllm.model_executor.layers.sampler._sample`` is patched to capture the
    probabilities the sampler would draw from. They must match the softmax
    of Hugging Face's top-k/top-p-warped logits for the same ``top_k`` and
    ``top_p``: within ``atol=1e-5``, and with an identical zero pattern.
    """
    set_random_seed(seed)
    # Like the other tests, pin the default device so tensors created inside
    # vLLM (sampling metadata, probabilities) live on the parametrized device
    # instead of whatever default a previous test left behind.
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, VOCAB_SIZE),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)
    model_runner = ModelRunner(None, None, None, None, None)

    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = generation_model._get_logits_warper(generation_config)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)

    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        # Capture the probabilities and return a dummy per-row result.
        nonlocal sample_probs
        sample_probs = probs
        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]

    # The sampler may modify ``logits`` in place, so build the HF reference
    # from an untouched copy taken before sampling.
    reference_logits = fake_logits.clone()
    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
    assert sample_probs is not None, "patched _sample was never called"

    hf_probs = warpers(torch.zeros_like(reference_logits), reference_logits)
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

    del model_runner