test_model_runner.py 14.3 KB
Newer Older
youkaichao's avatar
youkaichao committed
1
import pytest
2
3
import torch

4
5
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
6
from vllm.engine.arg_utils import EngineArgs
7
from vllm.model_executor.sampling_metadata import SamplingMetadata
8
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
9
from vllm.utils import get_open_port
youkaichao's avatar
youkaichao committed
10
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
11
12


13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
    """Build a driver-worker ``ModelRunner`` for ``model``.

    Any extra positional or keyword arguments are forwarded unchanged to
    ``EngineArgs``. The engine config built from them is then unpacked,
    component by component, into the ``ModelRunner`` constructor.
    """
    config = EngineArgs(model, *args, **kwargs).create_engine_config()
    return ModelRunner(
        model_config=config.model_config,
        parallel_config=config.parallel_config,
        scheduler_config=config.scheduler_config,
        device_config=config.device_config,
        cache_config=config.cache_config,
        load_config=config.load_config,
        lora_config=config.lora_config,
        is_driver_worker=True,
    )


youkaichao's avatar
youkaichao committed
29
30
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
    """Check the model input prepared for a batch of pure prefill requests.

    Every request is a prompt short enough to fit in a single KV-cache block.
    Its token ids are ``range(seq_len)``, so the expected tokens, positions,
    attention metadata and sampling indices can all be derived from the
    list of prompt lengths.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    seq_lens = []
    seq_group_metadata_list = []
    block_tables = {0: [1]}
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(list(range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Each prompt is sampled at its last token in the flattened token stream.
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for seq_len in seq_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               seq_len - 1)
        selected_token_start_idx += seq_len

    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = model_input.slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)

    # Verify input metadata is correct for prompts.
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    assert torch.allclose(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0

    # Test subquery start locs.
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        start_loc.append(start_idx)
    assert torch.allclose(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Test seq start locs. Note that for normal prefill it is
    # equivalent to query_start_loc.
    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)
    assert torch.allclose(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

    # Fresh prompts have no previously computed context.
    assert torch.allclose(
        attn_metadata.context_lens_tensor,
        torch.zeros(attn_metadata.context_lens_tensor.shape[0],
                    dtype=torch.int,
                    device=device))

    expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
                            dtype=torch.int32,
                            device=device)
    assert torch.allclose(attn_metadata.block_tables, expected)
    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    # Token ids were chosen as range(seq_len), so they equal their positions.
    torch.testing.assert_close(input_tokens, input_positions)

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


youkaichao's avatar
youkaichao committed
142
143
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
    """Check the model input prepared for a pure-decode batch with CUDA graphs.

    Each request has already finished prefill and contributes exactly one
    query token. Because ``enforce_eager=False``, the batch is expected to be
    padded up to ``_get_graph_batch_size(batch_size)``. Padded slots are
    expected to report a sequence length of 1.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    context_lens = []
    seq_group_metadata_list = []
    # Assume each seq group finishes prefill.
    for i in range(batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        context_lens.append(context_len)
        seq_data = SequenceData(list(range(context_len)))
        seq_data.update_num_computed_tokens(context_len)
        # Append one token ID since prefill is finished.
        seq_data.append_token_id(1, 0)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
    input_tokens, input_positions, attn_metadata, slot_mapping = (
        model_input.input_tokens, model_input.input_positions,
        model_input.attn_metadata, model_input.slot_mapping)
    assert len(slot_mapping) == len(input_tokens)

    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
    # Verify input metadata is correct for decodes.
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_prefill_tokens == 0
    seq_lens = [context_len + 1 for context_len in context_lens]
    # seq_lens are padded to expected_bs
    for _ in range(expected_bs - len(seq_lens)):
        seq_lens.append(1)
    assert attn_metadata.seq_lens == seq_lens
    start_idx = 0
    start_loc = [start_idx]
    for _ in context_lens:
        # decode has only 1 token for query.
        start_idx += 1
        start_loc.append(start_idx)
    assert torch.allclose(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)
    assert torch.allclose(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

    assert torch.allclose(
        attn_metadata.context_lens_tensor,
        torch.tensor(context_lens, dtype=torch.int, device=device))
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    assert torch.allclose(
        attn_metadata.seq_lens_tensor[:len(seq_lens)],
        torch.tensor(seq_lens, dtype=torch.int, device=device))

    # block table's first index corresponds to each batch, meaning in
    # decoding it is each token.
    assert attn_metadata.block_tables.shape[0] == len(input_tokens)
    # Block table's second dim corresponds to each token's block number.
    # It is padded up to model_runner.get_max_block_per_batch() columns.
    assert attn_metadata.block_tables.shape[1] == (
        model_runner.get_max_block_per_batch())
    assert attn_metadata.use_cuda_graph is True

    assert len(input_tokens) == expected_bs
    assert len(input_positions) == expected_bs

    # Verify Sampling
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for _ in context_lens:
        expected_selected_token_indices.append(selected_token_start_idx)
        selected_token_start_idx += 1
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        # query lens is all 1 for decode.
        query_lens=[1 for _ in range(len(context_lens))],
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
250
251
252
253


def test_empty_seq_group():
    """Verify that preparing model input for an empty batch yields empty
    tensors, empty sequence lengths and no attention metadata."""
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
    )
    seq_group_metadata_list = []
    model_input = model_runner._prepare_model_input(seq_group_metadata_list)

    assert len(model_input.input_tokens) == 0
    assert len(model_input.input_positions) == 0
    assert model_input.attn_metadata is None
    assert len(model_input.slot_mapping) == 0
    assert len(model_input.seq_lens) == 0
287
288


289
290
291
292
293
294
295
@pytest.fixture
def distributed_init():
    """Set up a single-process distributed environment for the test.

    Initialises the distributed environment with world size 1 and rank 0 over
    TCP on a free local port. It then calls
    ``ensure_model_parallel_initialized(1, 1)`` so that the model-parallel
    state is set up for a single process.
    """
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
        local_rank=0)
    ensure_model_parallel_initialized(1, 1)
297
298


299
300
301
def _assert_metadata_equal(expected, actual):
    """Assert that two attention-metadata objects have identical fields.

    Fields are read with ``vars()``. Tensor fields are compared with
    ``torch.testing.assert_close``, which is exact for integer dtypes. Every
    other field (ints, bools, lists, None) is compared with ``==``.
    """
    if expected is None or actual is None:
        assert expected is None and actual is None
        return
    expected_fields = vars(expected)
    actual_fields = vars(actual)
    assert expected_fields.keys() == actual_fields.keys()
    for name, expected_value in expected_fields.items():
        actual_value = actual_fields[name]
        if isinstance(expected_value, torch.Tensor):
            torch.testing.assert_close(actual_value,
                                       expected_value,
                                       msg=f"attn metadata mismatch: {name}")
        else:
            assert actual_value == expected_value, (
                f"attn metadata mismatch: {name}")


@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
    """Check a chunked-prefill batch that mixes prefill and decode requests.

    The first ``batch_size // 2`` requests are prompts. The remaining
    requests are single-token decodes. The metadata returned by
    ``prepare_input_tensors`` is compared field by field against a fresh
    ``_prepare_model_input`` call on the same batch.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=enforce_eager,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=True,
    )

    # Add prefill requests.
    seq_lens = []
    seq_group_metadata_list = []
    block_tables = {0: [1]}
    prefill_batch_size = batch_size // 2
    decode_batch_size = batch_size - prefill_batch_size
    for i in range(prefill_batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(list(range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Add decode requests
    for i in range(prefill_batch_size, batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        prompt_toks = list(range(context_len))
        seq_data = SequenceData(prompt_toks)
        seq_data.append_token_id(1, 0)
        seq_data.update_num_computed_tokens(context_len)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    (input_tokens, input_positions, attn_metadata, _, _, _,
     _) = model_runner.prepare_input_tensors(seq_group_metadata_list)

    prefill_meta_actual = attn_metadata.prefill_metadata
    decode_meta_actual = attn_metadata.decode_metadata

    assert len(attn_metadata.slot_mapping) == len(input_tokens)
    assert len(input_positions) == len(input_tokens)
    assert attn_metadata.num_prefills == prefill_batch_size
    assert attn_metadata.num_decode_tokens == decode_batch_size
    assert attn_metadata.num_prefill_tokens == sum(seq_lens)

    # Verify attn metadata is consistent between the two preparation paths.
    # Individual values are covered by the pure-prefill and pure-decode tests.
    expected_attn_metadata = model_runner._prepare_model_input(
        seq_group_metadata_list).attn_metadata

    _assert_metadata_equal(expected_attn_metadata.prefill_metadata,
                           prefill_meta_actual)
    _assert_metadata_equal(expected_attn_metadata.decode_metadata,
                           decode_meta_actual)