import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner


class MockLogitsSampler(Sampler):
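    """A Sampler whose forward() patches out hidden-state pruning and logits
    computation so that tests can inject arbitrary fake logits directly."""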

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
                   lambda x, y: x), patch(
                       "vllm.model_executor.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


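# Build the common test fixtures: a dummy hidden-state tensor, uniform fake
# logits, a MockLogitsSampler, and a ModelRunner constructed with dummy
# (None) configs - only its _prepare_sample() method is exercised here.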
def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None)
    return input_tensor, fake_logits, sampler, model_runner


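# Run each test under many fixed seeds so that the randomized batch sizes
# and sampling parameters are reproducible.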
RANDOM_SEEDS = list(range(128))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
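    # With temperature=0, sampling must reduce to a greedy argmax over the
    # logits for every sequence in the batch.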
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

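    # Give token i an overwhelming logit in row i so that even random
    # sampling (temperature=1.0) deterministically returns token i.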
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
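    # Every sample drawn for sequence i must be the boosted token i.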
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
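    # All requests in this batch use beam search (best_of=2).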
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler(embedding=None,
            hidden_states=input_tensor,
            sampling_metadata=sampling_metadata)
    # No assertion here: it is hard to determine what the expected beam
    # search outputs should be - in other words, this test only checks
    # that the sampler raises no exceptions when handling an
    # all-beam-search batch.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    prompt_lens = []
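    # Mix greedy (type 0), random-sampling (type 1), and beam-search
    # (type 2) requests in one batch, boosting the logits of the n tokens
    # each non-beam request is allowed to produce.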
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
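    # Beam-search outputs are skipped below; every other request must emit
    # one of the boosted tokens.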
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output.samples:
            assert nth_output.output_token in expected_tokens


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    # This sample logits processor gives infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx