"vscode:/vscode.git/clone" did not exist on "4f967dcc6c2d9f1ee6e9b11003428449d63acd2f"
test_sampler.py 11.5 KB
Newer Older
1
2
3
4
import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch
from transformers import GenerationConfig, GenerationMixin

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner


class MockLogitsSampler(Sampler):
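    """Sampler whose logit computation is replaced by a preset tensor."""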

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
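        # Patch out hidden-state pruning and the logit projection so the
        # sampler operates directly on the preset fake logits.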
        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
                   lambda x, y: x), patch(
                       "vllm.model_executor.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
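    """Create dummy hidden states, uniform fake logits, a mock sampler, and a
    bare ModelRunner for the given batch size."""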
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None)
    return input_tensor, fake_logits, sampler, model_runner


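# Each test runs over many seeds; the seed drives the random batch sizes
# and sampling parameters drawn below.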
RANDOM_SEEDS = list(range(128))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
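    """Greedy sampling (temperature 0) must return the argmax of the logits."""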
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
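    """With one heavily boosted logit per row, random sampling must pick it."""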
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

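    # Boost one logit per row so token i dominates sequence i's distribution.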
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
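    """Smoke test: a batch where every group uses beam search must not raise."""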
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler(embedding=None,
            hidden_states=input_tensor,
            sampling_metadata=sampling_metadata)
    # No assertion here: beam-search outputs are hard to verify directly,
    # so this test only checks that the sampler runs without raising
    # on an all-beam-search batch.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
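    """Mix greedy, random, and beam-search groups in one batch; every
    non-beam output must come from the boosted token set."""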
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    prompt_lens = []
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
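        # 0: greedy, 1: random sampling, 2: beam search.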
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
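        # Boost n tokens for this sequence; any of them is a valid sample.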
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output.samples:
            assert nth_output.output_token in expected_tokens


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
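    """Greedy sampling must obey a user-supplied logits processor."""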
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    # This sample logits processor gives infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_top_k_top_p(seed: int):
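    """The sampler's top-k/top-p filtering must match HF's logits warpers."""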
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None)

    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = generation_model._get_logits_warper(generation_config)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)

    sample_probs = None

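    # Stub out _sample to capture the post-filtering probabilities.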
    def mock_sample(probs, logprobs, sampling_metadata):
        nonlocal sample_probs
        sample_probs = probs
        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]

    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
        sampler(embedding=None,
                hidden_states=input_tensor,
                sampling_metadata=sampling_metadata)
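    # Apply HF's warpers to the same logits and compare the distributions.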
    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))