test_model_runner.py 13.9 KB
Newer Older
youkaichao's avatar
youkaichao committed
1
import pytest
2
3
import torch

4
from vllm.config import ModelConfig, SchedulerConfig
5
from vllm.distributed.parallel_state import init_distributed_environment
6
from vllm.model_executor.sampling_metadata import SamplingMetadata
7
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
8
from vllm.utils import get_open_port
youkaichao's avatar
youkaichao committed
9
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
10
11


youkaichao's avatar
youkaichao committed
12
13
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
    """Check ``ModelRunner._prepare_prompt`` on a batch of prefill-only groups.

    Builds ``batch_size`` single-sequence prompt groups whose lengths cycle
    through ``1 .. block_size - 1`` (so every prompt fits in one block) and
    verifies the returned tokens, positions, slot mapping and attention
    metadata, plus the sampling indices produced by
    ``SamplingMetadata.prepare`` for the same batch.
    """
    scheduler_config = SchedulerConfig(100000,
                                       100000,
                                       100000,
                                       enable_chunked_prefill=False)
    model_runner = ModelRunner(model_config=None,
                               parallel_config=None,
                               scheduler_config=scheduler_config,
                               device_config=None,
                               load_config=None,
                               lora_config=None)
    model_runner.set_block_size(16)

    seq_lens = []
    seq_group_metadata_list = []
    block_tables = {0: [1]}
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(list(range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # For prefill, the sampler is expected to pick the last token of every
    # prompt in the flattened batch.
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for seq_len in seq_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               seq_len - 1)
        selected_token_start_idx += seq_len

    (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
     _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)

    # Verify input metadata is correct for prompts.
    device = model_runner.device
    assert attn_metadata.is_prompt is True
    assert torch.allclose(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_seq_len == max(seq_lens)

    # Test subquery start locs (exclusive prefix sums of the prompt lengths).
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        start_loc.append(start_idx)
    assert torch.allclose(
        attn_metadata.subquery_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Test seq start locs. Note that for normal prefill it is
    # equivalent to subquery_start_loc.
    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)

    assert torch.allclose(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

    # A plain prefill has no previously computed context.
    assert torch.allclose(
        attn_metadata.context_lens_tensor,
        torch.zeros(attn_metadata.context_lens_tensor.shape[0],
                    dtype=torch.int,
                    device=device))

    # Block tables are expected to be empty (one empty row per group) here.
    expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
                            dtype=torch.int32,
                            device=model_runner.device)
    assert torch.allclose(attn_metadata.block_tables, expected)
    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    # Token ids were built as range(seq_len), so they equal their positions.
    torch.testing.assert_close(input_tokens, input_positions)

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


youkaichao's avatar
youkaichao committed
124
125
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
    """Check ``ModelRunner._prepare_decode`` with CUDA graphs enabled.

    Builds ``batch_size`` single-sequence decode groups whose lengths cycle
    through ``1 .. block_size - 1``. It then checks that the prepared batch is
    padded to ``_get_graph_batch_size(batch_size)``, that the decode attention
    metadata is consistent, and that the metadata reports CUDA-graph use
    (``enforce_eager=False``).
    """
    model_config = ModelConfig(
        "facebook/opt-125m",
        "facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        enforce_eager=False,
    )
    scheduler_config = SchedulerConfig(100000,
                                       100000,
                                       100000,
                                       enable_chunked_prefill=False)
    model_runner = ModelRunner(model_config=model_config,
                               parallel_config=None,
                               scheduler_config=scheduler_config,
                               device_config=None,
                               load_config=None,
                               lora_config=None)
    model_runner.set_block_size(16)

    seq_lens = []
    seq_group_metadata_list = []
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(list(range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
        model_runner._prepare_decode(seq_group_metadata_list))
    assert len(slot_mapping) == len(input_tokens)

    # With CUDA graphs enabled, the decode batch is expected to be padded up
    # to the graph batch size for this many groups.
    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))

    # Verify input metadata is correct for decode.
    device = model_runner.device
    assert attn_metadata.is_prompt is False
    assert attn_metadata.seq_lens is None
    assert attn_metadata.subquery_start_loc is None
    assert attn_metadata.seq_start_loc is None
    assert attn_metadata.max_seq_len == max(seq_lens)
    # Only the first len(seq_lens) entries correspond to real sequences.
    assert torch.allclose(
        attn_metadata.seq_lens_tensor[:len(seq_lens)],
        torch.tensor(seq_lens, dtype=torch.int, device=device))

    # The block table has one row per (padded) input token; in decoding that
    # is one row per sequence.
    assert attn_metadata.block_tables.shape[0] == len(input_tokens)
    # Block table's second dim corresponds to each token's block number.
    # It is padded up to model_runner.get_max_block_per_batch().
    assert attn_metadata.block_tables.shape[1] == (
        model_runner.get_max_block_per_batch())
    # Cuda graph should be used for decode.
    assert attn_metadata.use_cuda_graph is True

    assert len(input_tokens) == expected_bs
    assert len(input_positions) == expected_bs
    # Token ids were built as range(seq_len), so each decode token id equals
    # its position (padding included).
    torch.testing.assert_close(input_tokens, input_positions)

    # Verify sampling: every decode group contributes exactly one token, so
    # the selected indices are simply 0 .. len(seq_lens) - 1.
    expected_selected_token_indices = list(range(len(seq_lens)))
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
213
214
215
216
217
218
219
220
221
222
223
224
225
226


def test_empty_seq_group():
    """Verify that preparing an empty batch yields empty outputs.

    Both ``_prepare_decode`` and ``_prepare_prompt`` are called with no
    sequence groups. Each must return empty tokens, positions and slot mapping,
    and ``None`` attention metadata. ``_prepare_prompt`` must also return
    empty sequence lengths.
    """
    model_config = ModelConfig(
        "facebook/opt-125m",
        "facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        enforce_eager=False,
    )
    model_runner = ModelRunner(model_config=model_config,
                               parallel_config=None,
                               scheduler_config=None,
                               device_config=None,
                               load_config=None,
                               lora_config=None)
    model_runner.set_block_size(16)
    seq_group_metadata_list = []

    # Decode path.
    input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
        model_runner._prepare_decode(seq_group_metadata_list))
    assert len(input_tokens) == 0
    assert len(input_positions) == 0
    assert attn_metadata is None
    assert len(slot_mapping) == 0

    # Prompt path.
    (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
     _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
    assert len(input_tokens) == 0
    assert len(input_positions) == 0
    assert attn_metadata is None
    assert len(slot_mapping) == 0
    assert len(return_seq_lens) == 0
249
250


251
252
253
254
255
256
257
@pytest.fixture
def distributed_init():
    """Set up a single-rank distributed environment for the requesting test.

    The process group is initialized over TCP on localhost, using a free port
    obtained from ``get_open_port()``.

    NOTE(review): nothing tears this down after the test; confirm that
    ``init_distributed_environment`` tolerates being called once per
    parametrized case.
    """
    init_method = f"tcp://127.0.0.1:{get_open_port()}"
    init_distributed_environment(world_size=1,
                                 rank=0,
                                 distributed_init_method=init_method,
                                 local_rank=0)
258
259


260
261
262
def _assert_attn_metadata_equal(expected, actual):
    """Assert two attention-metadata objects hold identical field values.

    Fields are matched by name through ``vars()``. Tensors are compared with
    ``torch.equal`` (same shape and elements); every other value uses ``==``.
    """
    assert type(expected) is type(actual)
    expected_fields = vars(expected)
    actual_fields = vars(actual)
    assert expected_fields.keys() == actual_fields.keys()
    for name, expected_value in expected_fields.items():
        actual_value = actual_fields[name]
        if isinstance(expected_value, torch.Tensor):
            assert isinstance(actual_value, torch.Tensor), name
            assert torch.equal(expected_value, actual_value), name
        else:
            assert expected_value == actual_value, name


@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
    """Check ``prepare_input_tensors`` on a mixed prefill + decode batch.

    The first ``batch_size // 2`` groups are prompts; the rest are decodes.
    Chunked prefill is enabled so both kinds are prepared together. The
    combined attention metadata must report the right prefill/decode counts.
    Its per-stage prefill and decode metadata must match what
    ``_prepare_prompt`` / ``_prepare_decode`` produce for each half alone.
    """
    model_config = ModelConfig(
        "facebook/opt-125m",
        "facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        enforce_eager=enforce_eager,
    )
    scheduler_config = SchedulerConfig(100000,
                                       100000,
                                       100000,
                                       enable_chunked_prefill=True)
    model_runner = ModelRunner(model_config=model_config,
                               parallel_config=None,
                               scheduler_config=scheduler_config,
                               device_config=None,
                               load_config=None,
                               lora_config=None,
                               is_driver_worker=True)
    model_runner.set_block_size(16)

    # Add prefill requests. ``seq_lens`` tracks prompt lengths only.
    seq_lens = []
    seq_group_metadata_list = []
    prefill_metadata_list = []
    decode_metadata_list = []
    block_tables = {0: [1]}
    prefill_batch_size = batch_size // 2
    decode_batch_size = batch_size - prefill_batch_size
    for i in range(prefill_batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(list(range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)
        prefill_metadata_list.append(seq_group_metadata)

    # Add decode requests
    for i in range(prefill_batch_size, batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        prompt_toks = list(range(seq_len))
        seq_data = SequenceData(prompt_toks)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)
        decode_metadata_list.append(seq_group_metadata)

    (input_tokens, input_positions, attn_metadata, _, _, _,
     _) = model_runner.prepare_input_tensors(seq_group_metadata_list)

    prefill_meta_actual = attn_metadata.prefill_metadata
    decode_meta_actual = attn_metadata.decode_metadata

    assert len(attn_metadata.slot_mapping) == len(input_tokens)
    assert len(input_positions) == len(input_tokens)
    assert attn_metadata.kv_cache_dtype == "auto"
    assert attn_metadata.num_prefills == prefill_batch_size
    if enforce_eager:
        assert attn_metadata.num_decode_tokens == decode_batch_size
    else:
        # With CUDA graphs the decode part is padded to a graph batch size.
        assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
            decode_batch_size)
    assert attn_metadata.num_prefill_tokens == sum(seq_lens)

    # Verify attn metadata is consistent. We don't need to test individual
    # values here because they are tested above. Fields are compared by
    # name and value; iterating a ``vars()`` dict alone yields only names.
    prefill_meta = model_runner._prepare_prompt(
        prefill_metadata_list).attn_metadata
    decode_meta = model_runner._prepare_decode(
        decode_metadata_list).attn_metadata

    _assert_attn_metadata_equal(prefill_meta, prefill_meta_actual)
    _assert_attn_metadata_equal(decode_meta, decode_meta_actual)