test_model_runner.py 14.2 KB
Newer Older
youkaichao's avatar
youkaichao committed
1
import pytest
2
3
import torch

4
from vllm.distributed.parallel_state import init_distributed_environment
5
from vllm.engine.arg_utils import EngineArgs
6
from vllm.model_executor.sampling_metadata import SamplingMetadata
7
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
8
from vllm.utils import get_open_port
youkaichao's avatar
youkaichao committed
9
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
10
11


12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
    """Build a driver-worker ``ModelRunner`` for ``model``.

    ``model`` and any extra positional/keyword arguments are forwarded to
    ``EngineArgs``. The engine config created from those arguments supplies
    every sub-config passed to the ``ModelRunner`` constructor.
    """
    config = EngineArgs(model, *args, **kwargs).create_engine_config()
    return ModelRunner(
        model_config=config.model_config,
        parallel_config=config.parallel_config,
        scheduler_config=config.scheduler_config,
        device_config=config.device_config,
        cache_config=config.cache_config,
        load_config=config.load_config,
        lora_config=config.lora_config,
        is_driver_worker=True,
    )


youkaichao's avatar
youkaichao committed
28
29
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
    """Check the model input built for a prefill-only batch.

    Builds ``batch_size`` prompt sequence groups, runs
    ``ModelRunner._prepare_model_input`` on them and verifies the returned
    tokens, positions, slot mapping, attention metadata, and the token
    indices that ``SamplingMetadata.prepare`` selects for sampling.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    seq_lens = []
    seq_group_metadata_list = []
    block_tables = {0: [1]}
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(list(range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # The index expected to be sampled for each prompt is the index of its
    # last token within the concatenated batch.
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for seq_len in seq_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               seq_len - 1)
        selected_token_start_idx += seq_len

    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = model_input.slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)

    # Verify input metadata is correct for prompts.
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    assert torch.allclose(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0

    # Test subquery start locs.
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        start_loc.append(start_idx)
    assert torch.allclose(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Test seq start locs. Note that for normal prefill it is
    # equivalent to query_start_loc.
    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)

    # Compare against the list built for this check (it holds the same values
    # as start_loc here, but keeps the two checks independent).
    assert torch.allclose(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
    assert torch.allclose(
        attn_metadata.context_lens_tensor,
        torch.zeros(attn_metadata.context_lens_tensor.shape[0],
                    dtype=torch.int,
                    device=device))

    # One empty row per sequence group is expected for the block tables.
    expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
                            dtype=torch.int32,
                            device=model_runner.device)
    assert torch.allclose(attn_metadata.block_tables, expected)
    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    # Prompt token ids were built as range(seq_len), so the test expects them
    # to be identical to the positions.
    torch.testing.assert_close(input_tokens, input_positions)

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


youkaichao's avatar
youkaichao committed
141
142
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
    """Check the model input built for a decode-only batch with CUDA graphs.

    The runner is created with ``enforce_eager=False``. Every sequence group
    has its prompt marked as computed and one extra token appended, so each
    contributes exactly one query token. The test expects the batch to be
    padded to ``_get_graph_batch_size(batch_size)``.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    context_lens = []
    seq_group_metadata_list = []
    # Assume each seq group finishes prefill.
    for i in range(batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        context_lens.append(context_len)
        seq_data = SequenceData(list(range(context_len)))
        seq_data.update_num_computed_tokens(context_len)
        # Append one token ID since prefill is finished.
        seq_data.append_token_id(1, 0)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    slot_mapping = model_input.slot_mapping
    assert len(slot_mapping) == len(input_tokens)

    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
    # Verify input metadata is correct for decodes.
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_prefill_tokens == 0
    seq_lens = [context_len + 1 for context_len in context_lens]
    # seq_lens are padded to expected_bs
    seq_lens.extend([1] * (expected_bs - len(seq_lens)))
    assert attn_metadata.seq_lens == seq_lens

    # decode has only 1 token for query, so the start locations are simply
    # 0, 1, ..., number of (unpadded) sequences.
    start_loc = list(range(len(context_lens) + 1))
    assert torch.allclose(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Sequence start locations are the running sum of the padded seq_lens.
    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)
    assert torch.allclose(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

    assert torch.allclose(
        attn_metadata.context_lens_tensor,
        torch.tensor(context_lens, dtype=torch.int, device=device))
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    assert torch.allclose(
        attn_metadata.seq_lens_tensor[:len(seq_lens)],
        torch.tensor(seq_lens, dtype=torch.int, device=device))

    # block table's first index corresponds to each batch, meaning in
    # decoding it is each token.
    assert attn_metadata.block_tables.shape[0] == len(input_tokens)
    # Block table's second dim corresponds to each token's block number.
    # It is padded up to the runner's maximum number of blocks per batch.
    assert attn_metadata.block_tables.shape[1] == (
        model_runner.get_max_block_per_batch())
    assert attn_metadata.use_cuda_graph is True

    assert len(input_tokens) == expected_bs
    assert len(input_positions) == expected_bs
    # NOTE(review): an un-asserted torch.allclose(input_tokens,
    # input_positions) call used to sit here; its result was discarded, so it
    # verified nothing and has been dropped. Confirm whether any
    # token/position relation should actually be asserted for decode.

    # Verify Sampling
    # Every (unpadded) sequence contributes one query token, so the selected
    # token indices are 0, 1, ..., number of sequences - 1.
    expected_selected_token_indices = list(range(len(context_lens)))
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        # query lens is all 1 for decode.
        query_lens=[1] * len(context_lens),
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
249
250
251
252


def test_empty_seq_group():
    """Preparing model input for an empty batch must yield empty outputs.

    ``_prepare_model_input`` is invoked twice on the same empty list; the
    second invocation additionally checks the returned ``seq_lens``.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
    )
    seq_group_metadata_list = []

    # First pass: tokens, positions and slot mapping are empty and no
    # attention metadata is produced.
    first = model_runner._prepare_model_input(seq_group_metadata_list)
    tokens = first.input_tokens
    positions = first.input_positions
    metadata = first.attn_metadata
    slots = first.slot_mapping
    assert len(tokens) == 0
    assert len(positions) == 0
    assert metadata is None
    assert len(slots) == 0

    # Second pass: same expectations, plus an empty seq_lens.
    second = model_runner._prepare_model_input(seq_group_metadata_list)
    tokens = second.input_tokens
    positions = second.input_positions
    metadata = second.attn_metadata
    slots = second.slot_mapping
    returned_seq_lens = second.seq_lens
    assert len(tokens) == 0
    assert len(positions) == 0
    assert metadata is None
    assert len(slots) == 0
    assert len(returned_seq_lens) == 0
286
287


288
289
290
291
292
293
294
@pytest.fixture
def distributed_init():
    """Fixture that initializes the distributed environment for one process.

    Calls ``init_distributed_environment`` with world size 1, rank 0, local
    rank 0, and a TCP init method on 127.0.0.1 using the port returned by
    ``get_open_port()``.
    """
    init_method = f"tcp://127.0.0.1:{get_open_port()}"
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=init_method,
        local_rank=0,
    )
295
296


297
298
299
def _assert_same_attrs(expected, actual):
    """Assert that two metadata objects carry equal instance attributes.

    Both objects must expose the same attribute names via ``vars()``.
    Tensor-valued attributes are compared with ``torch.equal``; everything
    else is compared with ``==``.
    """
    expected_attrs = vars(expected)
    actual_attrs = vars(actual)
    assert expected_attrs.keys() == actual_attrs.keys()
    for name, expected_value in expected_attrs.items():
        actual_value = actual_attrs[name]
        if isinstance(expected_value, torch.Tensor):
            assert isinstance(actual_value, torch.Tensor), name
            assert torch.equal(expected_value, actual_value), name
        else:
            # NOTE(review): assumes non-tensor attributes (ints, bools, lists,
            # None, ...) support a plain boolean ``==`` -- confirm for the
            # attention backend in use.
            assert expected_value == actual_value, name


@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
    """Check input preparation for a batch mixing prefill and decode requests.

    The first ``batch_size // 2`` sequence groups are prompts and the rest are
    decodes. The test verifies the prefill/decode counts reported by
    ``prepare_input_tensors`` and then checks that the prefill and decode
    metadata it produced match those from ``_prepare_model_input`` on the same
    batch.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=enforce_eager,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=True,
    )

    # Add prefill requests.
    seq_lens = []
    seq_group_metadata_list = []
    block_tables = {0: [1]}
    prefill_batch_size = batch_size // 2
    decode_batch_size = batch_size - prefill_batch_size
    for i in range(prefill_batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(list(range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Add decode requests
    for i in range(prefill_batch_size, batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        prompt_toks = list(range(context_len))
        seq_data = SequenceData(prompt_toks)
        seq_data.append_token_id(1, 0)
        seq_data.update_num_computed_tokens(context_len)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    # Only the first three returned values are inspected here.
    (input_tokens, input_positions, attn_metadata, _, _, _,
     _) = model_runner.prepare_input_tensors(seq_group_metadata_list)

    prefill_meta_actual = attn_metadata.prefill_metadata
    decode_meta_actual = attn_metadata.decode_metadata

    assert len(attn_metadata.slot_mapping) == len(input_tokens)
    assert len(input_positions) == len(input_tokens)
    assert attn_metadata.num_prefills == prefill_batch_size
    assert attn_metadata.num_decode_tokens == decode_batch_size
    assert attn_metadata.num_prefill_tokens == sum(seq_lens)

    # Verify attn metadata is consistent. We don't need to test individual
    # values here because they are tested above.
    attn_metadata = model_runner._prepare_model_input(
        seq_group_metadata_list).attn_metadata

    # Compare attribute *values*. Iterating vars() directly yields only the
    # attribute names, which would make this check vacuous.
    _assert_same_attrs(attn_metadata.prefill_metadata, prefill_meta_actual)
    _assert_same_attrs(attn_metadata.decode_metadata, decode_meta_actual)