"vllm/vscode:/vscode.git/clone" did not exist on "5bd9b3042e0f9349cada41d9ffa8d19d209e2290"
test_sampler.py 12.7 KB
Newer Older
1
import random
from typing import Tuple, List
from typing import Optional
from unittest.mock import patch

import pytest
import torch
from transformers import GenerationConfig, GenerationMixin

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
14
15
16
17


class MockLogitsSampler(Sampler):
    """``Sampler`` subclass used as the sampler under test in this file.

    It keeps a reference to the logits tensor given at construction time and
    otherwise defers entirely to the base ``Sampler``.
    """

    def __init__(self, fake_logits: torch.Tensor):
        """Initialise the base sampler and remember ``fake_logits``.

        Args:
            fake_logits: Logits tensor stored on the instance as
                ``self.fake_logits``.
        """
        super().__init__()
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        """Pass every argument through unchanged to ``Sampler.forward``."""
        return super().forward(*args, **kwargs)
24
25
26
27


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
    """Build the fixtures shared by the sampler tests.

    Args:
        batch_size: Number of rows in the generated tensors.

    Returns:
        A tuple ``(input_tensor, fake_logits, sampler, model_runner)``:

        * ``input_tensor``: random float16 tensor of shape
          ``(batch_size, 1024)``.
        * ``fake_logits``: float16 tensor of shape ``(batch_size, 32000)``
          with every entry set to ``1e-2``.
        * ``sampler``: a ``MockLogitsSampler`` constructed from
          ``fake_logits``.
        * ``model_runner``: a ``ModelRunner`` constructed with ``None`` for
          all five positional arguments.
    """
    vocab_size = 32000
    # Created first so the torch RNG is consumed in a fixed order.
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)
    model_runner = ModelRunner(None, None, None, None, None)
    return input_tensor, fake_logits, sampler, model_runner
37
38
39


# Seeds over which every test in this module is parametrized.
RANDOM_SEEDS = [*range(128)]
# Device strings over which every test is parametrized: only "cuda:0" when
# torch reports exactly one CUDA device, otherwise "cuda:0" and "cuda:1".
CUDA_DEVICES = (["cuda:0"]
                if torch.cuda.device_count() == 1 else ["cuda:0", "cuda:1"])
43
44


Nick Hill's avatar
Nick Hill committed
45
46
47
48
49
50
51
def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    model_runner: ModelRunner,
    sampling_params: SamplingParams,
):
    """Run ``sampler`` over a batch of identical prompt sequence groups.

    Args:
        batch_size: Number of sequence groups to create.
        input_tensor: Tensor passed to the sampler as ``logits``.
        sampler: Sampler instance to invoke.
        model_runner: Runner whose ``_prepare_sample`` builds the sampling
            metadata.
        sampling_params: Sampling parameters shared by every sequence group.

    Returns:
        Whatever ``sampler(...)`` returns for the batch.
    """
    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        # Every group is a prompt with the same three-token sequence.
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=sampling_params,
            block_tables={0: [1]},
        )
        seq_group_metadata_list.append(seq_group_metadata)
        prompt_lens.append(seq_group_metadata.seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
Nick Hill's avatar
Nick Hill committed
69
70
71
72
73
74
75
76
77
78
79
80


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """Check that greedy sampling returns the argmax token for every row.

    All sequence groups use ``temperature=0``; every sampled token must equal
    the argmax of the corresponding row of the logits.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    sampling_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
                                sampling_params)

    # Convert once to Python ints instead of calling .item() per element.
    expected = torch.argmax(fake_logits, dim=-1).tolist()
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i]

    del model_runner

90
91

@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Check unseeded random sampling when one logit dominates each row.

    Row ``i`` of the logits gets a value of ``1e2`` at index ``i`` (all other
    entries are ``1e-2``), and every sample for row ``i`` must be token ``i``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    # Make token i overwhelmingly likely for row i.
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
                                sampling_params)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Check seeded random sampling when one logit dominates each row.

    Same setup as the unseeded random test, except that the sampling
    parameters carry an explicit per-request ``seed``. Every sample for row
    ``i`` must still be token ``i``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    # Make token i overwhelmingly likely for row i.
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
                                sampling_params)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner

142

Nick Hill's avatar
Nick Hill committed
143
144
145
146
147
148
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Check that seeded random sampling is reproducible.

    The logits are left uniform (every entry ``1e-2``), so the result is not
    forced by a dominant token. Sampling twice with the same seeded
    ``SamplingParams`` must give equal outputs.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                      model_runner, sampling_params)

    second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                       model_runner, sampling_params)

    assert first_sampler_output == second_sampler_output

    del model_runner


167
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    """Smoke-test the sampler on a batch that uses only beam search.

    There is deliberately no assertion on the output: it is unclear how to
    determine the expected tokens, so this only checks that the sampler
    raises no exception when every request uses beam search.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=0,
        best_of=2,
        use_beam_search=True,
    )
    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
    # No assertion here: the test passes as long as sampling does not raise.
    del model_runner
186
187
188


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Check a batch that mixes greedy, random, seeded and beam requests.

    Each request is randomly given one of four sampling types:

    * ``0``: greedy (``temperature=0``); the expected token is the row argmax.
    * ``1``: unseeded random; ``n`` consecutive logits starting at column
      ``i`` are raised to ``1e2`` and every sample must be one of them.
    * ``2``: seeded random; the tokens from the first run are recorded and
      must be reproduced exactly by the second run.
    * ``3``: beam search; the output is not checked.

    The batch is sampled once, shuffled, and sampled again, so seeded results
    are compared across different batch positions.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    # Per request: expected tokens, or None when not (yet) known.
    expected_tokens: List[Optional[List[int]]] = []
    prompt_lens = []
    for i in range(batch_size):
        expected: Optional[List[int]] = None
        # NOTE: the order of the random.* calls below determines the
        # generated parameters for a given seed; do not reorder them.
        sampling_type = random.randint(0, 3)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
        elif sampling_type in (1, 2):
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                # A drawn 0 is mapped to -1.
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 2:
                sampling_params.seed = random.randint(0, 10000)
            else:
                # Unseeded: make n tokens dominate so the result is checkable.
                for idx in range(n):
                    fake_logits[i, i + idx] = 1e2
                expected = list(range(i, i + n))
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        expected_tokens.append(expected)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    def test_sampling(model_runner: ModelRunner):
        """Sample the current batch and check or record each request."""
        sampling_metadata = model_runner._prepare_sample(
            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for i, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            if metadata.sampling_params.use_beam_search:
                continue

            if (metadata.sampling_params.seed is not None
                    and expected_tokens[i] is None):
                # Record seeded random result to compare with results of
                # second invocation
                expected_tokens[i] = [
                    nth_output.output_token
                    for nth_output in sequence_output.samples
                ]
                continue

            for n, nth_output in enumerate(sequence_output.samples):
                if (metadata.sampling_params.temperature == 0
                        or metadata.sampling_params.seed is not None):
                    # Ensure exact matches for greedy or random with seed
                    assert nth_output.output_token == expected_tokens[i][n]
                else:
                    # For non-seeded random check that one of the high-logit
                    # tokens were chosen
                    assert nth_output.output_token in expected_tokens[i]

    # Test batch
    test_sampling(model_runner)

    # Shuffle the batch and resample. A fresh Random(seed) is used for each
    # list, and all lists have batch_size elements, so every list receives
    # the same permutation and the entries stay aligned.
    target_index = list(range(batch_size))
    for list_to_shuffle in (target_index, seq_group_metadata_list,
                            expected_tokens, prompt_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    target_index = torch.tensor(target_index)
    # Apply the same permutation to the tensor rows.
    input_tensor.data = input_tensor.index_select(0, target_index)
    fake_logits.data = fake_logits.index_select(0, target_index)

    # This time, results of seeded random samples will be compared with
    # the corresponding sample in the pre-shuffled batch
    test_sampling(model_runner)

    del model_runner

285

286
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """Compare the sampler's top-k/top-p filtering with the HF warpers.

    The module-level ``_sample`` function of the sampler is patched to capture
    the probabilities it receives. Those are compared against a softmax over
    the logits processed by the warpers that ``GenerationMixin`` builds for
    the same ``top_k`` and ``top_p``.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000
    # Only used below as the source of device and dtype for fake_logits.
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)
    model_runner = ModelRunner(None, None, None, None, None)

    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = generation_model._get_logits_warper(generation_config)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)

    # Filled in by mock_sample with the probabilities the sampler computed.
    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        """Capture ``probs`` and return each row's top-1 token."""
        nonlocal sample_probs
        sample_probs = probs
        # NOTE(review): presumably this mirrors the return format the sampler
        # expects from _sample -- confirm against the sampler module.
        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]

    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
    # The two implementations must also zero out exactly the same tokens.
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

    del model_runner