test_model_runner.py 14.6 KB
Newer Older
1
2
from typing import List

youkaichao's avatar
youkaichao committed
3
import pytest
4
5
import torch

6
7
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
8
from vllm.engine.arg_utils import EngineArgs
9
from vllm.model_executor.sampling_metadata import SamplingMetadata
10
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
11
from vllm.utils import get_open_port
youkaichao's avatar
youkaichao committed
12
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
13
14


15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
    """Build a driver-worker ``ModelRunner`` for ``model``.

    ``model`` and any extra positional/keyword arguments are forwarded to
    ``EngineArgs``; the engine config created from those arguments supplies
    every config object handed to the runner.
    """
    config = EngineArgs(model, *args, **kwargs).create_engine_config()
    return ModelRunner(
        model_config=config.model_config,
        parallel_config=config.parallel_config,
        scheduler_config=config.scheduler_config,
        device_config=config.device_config,
        cache_config=config.cache_config,
        load_config=config.load_config,
        lora_config=config.lora_config,
        is_driver_worker=True,
    )


youkaichao's avatar
youkaichao committed
31
32
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
    """Check the model inputs and attention metadata built for a prompt batch.

    Builds ``batch_size`` prompt sequence groups whose lengths cycle through
    ``1 .. block_size - 1``, runs ``_prepare_model_input_tensors`` on them and
    verifies the returned tensors, the attention metadata and the token
    indices selected by ``SamplingMetadata.prepare``.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(list(range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Cumulative token offsets: start_loc[k] is where sequence k begins in the
    # flattened batch and start_loc[-1] is the total number of tokens.
    start_loc = [0]
    for seq_len in seq_lens:
        start_loc.append(start_loc[-1] + seq_len)
    # The expected sampled position is the last token of each prompt.
    expected_selected_token_indices = [end - 1 for end in start_loc[1:]]

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)

    # Verify input metadata is correct for prompts.
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    assert torch.allclose(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0

    # Test subquery start locs.
    expected_start_loc = torch.tensor(start_loc,
                                      dtype=torch.int32,
                                      device=device)
    assert torch.allclose(attn_metadata.query_start_loc, expected_start_loc)

    # Test seq start locs. Note that for normal prefill it is
    # equivalent to query_start_loc, so the same expectation is reused.
    assert torch.allclose(attn_metadata.seq_start_loc, expected_start_loc)
    assert torch.allclose(
        attn_metadata.context_lens_tensor,
        torch.zeros(attn_metadata.context_lens_tensor.shape[0],
                    dtype=torch.int,
                    device=device))

    # One empty row per sequence group is expected for the block tables.
    expected = torch.tensor([[] for _ in seq_group_metadata_list],
                            dtype=torch.int32,
                            device=model_runner.device)
    assert torch.allclose(attn_metadata.block_tables, expected)
    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    # Every prompt is list(range(seq_len)), so token ids equal positions.
    torch.testing.assert_close(input_tokens, input_positions)

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


youkaichao's avatar
youkaichao committed
145
146
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
    """Check the model inputs and attention metadata built for a decode batch.

    Builds ``batch_size`` sequence groups that have finished prefill and hold
    one extra generated token, runs ``_prepare_model_input_tensors`` with
    ``enforce_eager=False`` and verifies that the batch is padded to the size
    returned by ``_get_graph_batch_size`` and flagged as using a CUDA graph.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    context_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    # Assume each seq group finishes prefill.
    for i in range(batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        context_lens.append(context_len)
        seq_data = SequenceData(list(range(context_len)))
        seq_data.update_num_computed_tokens(context_len)
        # Append one token ID since prefill is finished.
        seq_data.append_token_id(1, 0)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    slot_mapping = attn_metadata.slot_mapping
    assert len(slot_mapping) == len(input_tokens)

    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
    # Verify input metadata is correct for decodes.
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_prefill_tokens == 0
    seq_lens = [context_len + 1 for context_len in context_lens]
    # seq_lens are padded to expected_bs
    seq_lens.extend([1] * (expected_bs - len(seq_lens)))
    assert attn_metadata.seq_lens == seq_lens

    # decode has only 1 token for query, so the start locs are 0..batch_size.
    start_loc = list(range(len(context_lens) + 1))
    assert torch.allclose(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Cumulative offsets over the (padded) sequence lengths.
    seq_start_loc = [0]
    for seq_len in seq_lens:
        seq_start_loc.append(seq_start_loc[-1] + seq_len)
    assert torch.allclose(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

    assert torch.allclose(
        attn_metadata.context_lens_tensor,
        torch.tensor(context_lens, dtype=torch.int, device=device))
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    assert torch.allclose(
        attn_metadata.seq_lens_tensor[:len(seq_lens)],
        torch.tensor(seq_lens, dtype=torch.int, device=device))

    # block table's first index corresponds to each batch, meaning in
    # decoding it is each token.
    assert attn_metadata.block_tables.shape[0] == len(input_tokens)
    # Block table's second dim corresponds to each token's block number.
    # It is padded up to the value of model_runner.get_max_block_per_batch().
    assert attn_metadata.block_tables.shape[1] == (
        model_runner.get_max_block_per_batch())
    assert attn_metadata.use_cuda_graph is True

    assert len(input_tokens) == expected_bs
    assert len(input_positions) == expected_bs
    # NOTE(review): the previous version called
    # torch.allclose(input_tokens, input_positions) here and discarded the
    # result, so it verified nothing. It is dropped rather than asserted:
    # presumably decode feeds the appended token id while positions follow the
    # context length, so the two are not expected to match - confirm.

    # Verify Sampling: one query token per sequence, so indices are 0..n-1.
    expected_selected_token_indices = list(range(len(context_lens)))
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        # query lens is all 1 for decode.
        query_lens=[1] * len(context_lens),
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
253
254
255
256


def test_empty_seq_group():
    """Verify that preparing an empty batch returns empty (None) inputs."""
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
    )
    seq_group_metadata_list: List[SequenceGroupMetadata] = []

    # The previous version issued this identical call twice; a single call
    # asserting every field covers the same checks.
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)

    assert model_input.input_tokens is None
    assert model_input.input_positions is None
    assert model_input.attn_metadata is None
    assert model_input.seq_lens is None
287
288


289
290
291
292
293
294
295
@pytest.fixture
def distributed_init():
    """Set up a single-rank distributed environment for a test.

    Initializes the distributed environment with world size 1 and rank 0 over
    a local TCP endpoint whose port comes from ``get_open_port()``, then makes
    sure model parallelism is initialized with the arguments ``(1, 1)``.
    """
    init_method = f"tcp://127.0.0.1:{get_open_port()}"
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=init_method,
        local_rank=0,
    )
    ensure_model_parallel_initialized(1, 1)
297
298


299
300
301
def _assert_same_attrs(expected_obj, actual_obj):
    """Assert that two objects carry the same attributes with equal values.

    Tensors are compared with ``torch.equal`` because ``==`` on tensors yields
    a tensor, which cannot be used directly in ``assert``. Every other value
    is compared with ``==``.
    """
    expected_attrs = vars(expected_obj)
    actual_attrs = vars(actual_obj)
    assert expected_attrs.keys() == actual_attrs.keys()
    for name, expected_value in expected_attrs.items():
        actual_value = actual_attrs[name]
        if isinstance(expected_value, torch.Tensor):
            assert isinstance(actual_value, torch.Tensor), name
            assert torch.equal(expected_value, actual_value), name
        else:
            # NOTE(review): assumes non-tensor attributes are plain values
            # (ints, bools, lists, None) with value-based equality - confirm.
            assert expected_value == actual_value, name


@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
    """Check model input preparation for a mixed prefill/decode batch.

    The first ``batch_size // 2`` sequence groups are prompts and the rest are
    decodes. The test verifies the prefill/decode counts reported by
    ``prepare_model_input`` and that its prefill and decode metadata match the
    ones produced by ``_prepare_model_input_tensors`` for the same batch.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=enforce_eager,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=True,
    )

    # Add prefill requests.
    seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    prefill_batch_size = batch_size // 2
    decode_batch_size = batch_size - prefill_batch_size
    for i in range(prefill_batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(list(range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Add decode requests
    for i in range(prefill_batch_size, batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        prompt_toks = list(range(context_len))
        seq_data = SequenceData(prompt_toks)
        seq_data.append_token_id(1, 0)
        seq_data.update_num_computed_tokens(context_len)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata

    prefill_meta_actual = attn_metadata.prefill_metadata
    decode_meta_actual = attn_metadata.decode_metadata

    assert len(attn_metadata.slot_mapping) == len(input_tokens)
    assert len(input_positions) == len(input_tokens)
    assert attn_metadata.num_prefills == prefill_batch_size
    assert attn_metadata.num_decode_tokens == decode_batch_size
    assert attn_metadata.num_prefill_tokens == sum(seq_lens)

    # Verify attn metadata is consistent. We don't need to test individual
    # values here because they are tested above.
    attn_metadata = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list).attn_metadata

    # The previous version zipped vars(...) directly, which iterates attribute
    # *names*, so it only compared the second character of each name. Compare
    # the attribute values instead.
    _assert_same_attrs(attn_metadata.prefill_metadata, prefill_meta_actual)
    _assert_same_attrs(attn_metadata.decode_metadata, decode_meta_actual)