import random
from typing import List, Optional, Tuple
from unittest.mock import patch

import pytest
import torch
from transformers import GenerationConfig, GenerationMixin

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner


class MockLogitsSampler(Sampler):
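    """Sampler whose logits come from a fixed fake-logits tensor rather
    than being computed from the hidden states."""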

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
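        # Patch out hidden-state pruning and the real logits computation so
        # the sampler operates directly on `self.fake_logits`.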
        with patch(
                "vllm.model_executor.layers.sampler._prune_hidden_states",
                lambda x, y: x), patch(
                    "vllm.model_executor.layers.sampler.Sampler._get_logits",
                    lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
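    """Build a dummy input batch, uniform fake logits, a mock sampler, and
    a bare ModelRunner for the sampler tests."""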
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None, None, None)
    return input_tensor, fake_logits, sampler, model_runner


RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    model_runner: ModelRunner,
    sampling_params: SamplingParams,
):
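    """Run the sampler over a batch of prompt-only sequence groups that all
    share the same sampling parameters."""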
    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    return sampler(embedding=None,
                   hidden_states=input_tensor,
                   sampling_metadata=sampling_metadata)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    sampling_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, input_tensor, sampler,
                                model_runner, sampling_params)
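    # Greedy sampling (temperature=0) must always pick the argmax token.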
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

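    # Plant a single dominant logit per row so random sampling is
    # effectively deterministic: sequence i should always sample token i.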
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    sampler_output = _do_sample(batch_size, input_tensor, sampler,
                                model_runner, sampling_params)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

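    # As above, make token i dominant in row i so the expected output of
    # each seeded random sample is known.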
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, input_tensor, sampler,
                                model_runner, sampling_params)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
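    # With a fixed per-request seed, two identical invocations must yield
    # identical samples.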
    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
                                      model_runner, sampling_params)

    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
                                       model_runner, sampling_params)

    assert first_sampler_output == second_sampler_output

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=0,
        best_of=2,
        use_beam_search=True,
    )
    _do_sample(batch_size, input_tensor, sampler, model_runner,
               sampling_params)
    # There is no assertion here because it is hard to say what the expected
    # beam-search outputs should be; this only checks that the sampler
    # raises no exceptions when handling an all-beam-search batch.
    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    expected_tokens: List[Optional[List[int]]] = []
    prompt_lens = []
    for i in range(batch_size):
        expected: Optional[List[int]] = None
        sampling_type = random.randint(0, 3)
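        # 0: greedy, 1: random, 2: seeded random, 3: beam search.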
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
        elif sampling_type in (1, 2):
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 2:
                sampling_params.seed = random.randint(0, 10000)
            else:
                for idx in range(n):
                    fake_logits[i, i + idx] = 1e2
                expected = list(range(i, i + n))
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        expected_tokens.append(expected)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    def test_sampling(model_runner: ModelRunner):
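        # Run one sampling pass and check every sequence group against its
        # expected tokens (beam-search groups are skipped).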
        sampling_metadata = model_runner._prepare_sample(
            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
        sampler_output = sampler(embedding=None,
                                 hidden_states=input_tensor,
                                 sampling_metadata=sampling_metadata)

        for i, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            if metadata.sampling_params.use_beam_search:
                continue

            if (metadata.sampling_params.seed is not None
                    and expected_tokens[i] is None):
                # Record the seeded random result, to be compared against
                # the result of the second invocation below.
                expected_tokens[i] = [
                    nth_output.output_token
                    for nth_output in sequence_output.samples
                ]
                continue

            for n, nth_output in enumerate(sequence_output.samples):
                if (metadata.sampling_params.temperature == 0
                        or metadata.sampling_params.seed is not None):
                    # Expect exact matches for greedy or seeded random
                    # sampling.
                    assert nth_output.output_token == expected_tokens[i][n]
                else:
                    # For unseeded random sampling, check that one of the
                    # high-logit tokens was chosen.
                    assert nth_output.output_token in expected_tokens[i]

    # Test batch
    test_sampling(model_runner)

    # Shuffle the batch and resample
    target_index = list(range(batch_size))
    for list_to_shuffle in (target_index, seq_group_metadata_list,
                            expected_tokens, prompt_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    target_index = torch.tensor(target_index)
    input_tensor.data = input_tensor.index_select(0, target_index)
    fake_logits.data = fake_logits.index_select(0, target_index)

    # This time, the results of the seeded random samples are compared with
    # the corresponding samples in the pre-shuffled batch.
    test_sampling(model_runner)

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_logits_processors(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    # This sample logits processor gives infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None, None, None)

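    # Build HuggingFace's top-k/top-p logits warpers to use as a reference
    # implementation.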
    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = generation_model._get_logits_warper(generation_config)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)

    sample_probs = None

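    # Capture the probabilities seen by the sampling step instead of
    # actually sampling, so they can be compared against the HF-warped
    # distribution.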
    def mock_sample(probs, logprobs, sampling_metadata):
        nonlocal sample_probs
        sample_probs = probs
        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]

    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
        sampler(embedding=None,
                hidden_states=input_tensor,
                sampling_metadata=sampling_metadata)
    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

    del model_runner