import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch
from transformers import GenerationConfig, GenerationMixin

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner


class MockLogitsSampler(Sampler):
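    # A Sampler whose forward() patches out hidden-state pruning and the
    # real logits computation, so every call scores tokens with the fixed
    # `fake_logits` tensor supplied by the test.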

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch(
                "vllm.model_executor.layers.sampler._prune_hidden_states",
                lambda x, y: x), patch(
                    "vllm.model_executor.layers.sampler.Sampler._get_logits",
                    lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
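    # Build dummy hidden states, uniform fake logits, a mock sampler, and a
    # bare ModelRunner for `batch_size` single-sequence groups.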
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None, None, None)
    return input_tensor, fake_logits, sampler, model_runner


RANDOM_SEEDS = list(range(128))
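# Exercise two CUDA devices when more than one is available, else cuda:0 only.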
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
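    # With temperature=0, every sample must be the argmax of the fake logits.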
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
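    # Random sampling with a single dominant logit per row is effectively
    # deterministic, so the sampled token can be checked exactly.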
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    # Plant a dominant logit at token id i for sequence i so that random
    # sampling is forced to return token i.
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
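    # Beam-search outputs are not validated here; see the note at the end of
    # this test.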
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    sampler(embedding=None,
            hidden_states=input_tensor,
            sampling_metadata=sampling_metadata)
    # No assertion here: the expected outputs of beam search are hard to
    # determine analytically, so this test only checks that the sampler
    # handles an all-beam-search batch without raising an exception.
    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
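    # Mix greedy, random, and beam-search requests in one batch and check the
    # non-beam outputs against the planted dominant logits.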
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    prompt_lens = []
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output.samples:
            assert nth_output.output_token in expected_tokens

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_logits_processors(seed: int, device: str):
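    # Verify that per-request logits processors are applied before sampling.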
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    # This sample logits processor gives infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
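    # Compare vLLM's top-k/top-p filtering against HuggingFace's logits
    # warpers on identical logits.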
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None, None, None)

    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = generation_model._get_logits_warper(generation_config)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)

    sample_probs = None

    def mock_sample(probs, logprobs, sampling_metadata):
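        # Capture the filtered probabilities for comparison against the HF
        # reference below, and return a dummy top-1 sample per sequence.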
        nonlocal sample_probs
        sample_probs = probs
        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]

    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
        sampler(embedding=None,
                hidden_states=input_tensor,
                sampling_metadata=sampling_metadata)
    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

    del model_runner