test_model_runner.py 15 KB
Newer Older
1
from array import array
2
3
from typing import List

youkaichao's avatar
youkaichao committed
4
import pytest
5
6
import torch

7
8
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
9
from vllm.engine.arg_utils import EngineArgs
10
from vllm.model_executor.sampling_metadata import SamplingMetadata
11
12
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
                           SequenceData, SequenceGroupMetadata)
13
from vllm.utils import get_open_port
youkaichao's avatar
youkaichao committed
14
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
15
16


17
18
19
20
21
22
23
24
25
26
27
def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
    """Build a driver-worker ``ModelRunner`` for ``model``.

    ``model`` plus any extra positional/keyword arguments are forwarded
    unchanged to ``EngineArgs``; the engine config derived from them supplies
    every sub-config handed to the ``ModelRunner`` constructor.
    """
    config = EngineArgs(model, *args, **kwargs).create_engine_config()
    return ModelRunner(
        model_config=config.model_config,
        parallel_config=config.parallel_config,
        scheduler_config=config.scheduler_config,
        device_config=config.device_config,
        cache_config=config.cache_config,
        load_config=config.load_config,
        lora_config=config.lora_config,
        prompt_adapter_config=config.prompt_adapter_config,
        observability_config=config.observability_config,
        is_driver_worker=True,
    )


youkaichao's avatar
youkaichao committed
35
36
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
    """Check the model inputs built for a batch made only of prompts.

    Builds ``batch_size`` prompt sequence groups, feeds them to
    ``ModelRunner._prepare_model_input_tensors`` and verifies the returned
    tokens, positions and attention metadata, then checks the token indices
    selected by ``SamplingMetadata.prepare`` for the same batch.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                      range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Cumulative start offset of every prompt in the flattened batch, with
    # the total token count as the final entry.
    start_loc = [0]
    for seq_len in seq_lens:
        start_loc.append(start_loc[-1] + seq_len)
    # The token selected for each prompt is its last one.
    expected_selected_token_indices = [end - 1 for end in start_loc[1:]]

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)

    # Verify input metadata is correct for prompts.
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    torch.testing.assert_close(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0

    # Test query and seq start locs. For normal prefill the two are
    # equivalent, so both are compared against the same expected offsets.
    expected_start_loc = torch.tensor(start_loc,
                                      dtype=torch.int32,
                                      device=device)
    torch.testing.assert_close(attn_metadata.query_start_loc,
                               expected_start_loc)
    torch.testing.assert_close(attn_metadata.seq_start_loc,
                               expected_start_loc)

    torch.testing.assert_close(
        attn_metadata.context_lens_tensor,
        torch.zeros(attn_metadata.context_lens_tensor.shape[0],
                    dtype=torch.int,
                    device=device))

    # One empty row per sequence group is expected for the block tables.
    expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
                            dtype=torch.int32,
                            device=model_runner.device)
    torch.testing.assert_close(attn_metadata.block_tables, expected)
    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    # Every prompt is range(seq_len), so token ids coincide with positions.
    torch.testing.assert_close(input_tokens, input_positions)

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


youkaichao's avatar
youkaichao committed
150
151
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
    """Check the model inputs built for a decode-only batch with CUDA graphs.

    Builds ``batch_size`` sequence groups whose prompt is marked as fully
    computed and that carry one appended token, feeds them to
    ``ModelRunner._prepare_model_input_tensors`` (``enforce_eager=False``) and
    verifies the attention metadata, the padding of the batch up to
    ``_get_graph_batch_size(batch_size)`` and the sampling indices.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    context_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    # Assume each seq group finishes prefill.
    for i in range(batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        context_lens.append(context_len)
        seq_data = SequenceData(
            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
        seq_data.update_num_computed_tokens(context_len)
        # Append one token ID since prefill is finished.
        seq_data.append_token_id(1, 0)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    slot_mapping = attn_metadata.slot_mapping
    assert len(slot_mapping) == len(input_tokens)

    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
    # Verify input metadata is correct for decodes.
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_prefill_tokens == 0
    seq_lens = [context_len + 1 for context_len in context_lens]
    # seq_lens are padded to expected_bs
    seq_lens.extend([1] * (expected_bs - len(seq_lens)))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.num_decode_tokens == len(seq_lens)

    # decode has only 1 token for query, so the query start locs of the
    # (unpadded) requests are simply 0, 1, ..., batch_size.
    start_loc = list(range(len(context_lens) + 1))
    torch.testing.assert_close(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Seq start locs accumulate the padded sequence lengths.
    seq_start_loc = [0]
    for seq_len in seq_lens:
        seq_start_loc.append(seq_start_loc[-1] + seq_len)
    torch.testing.assert_close(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

    torch.testing.assert_close(
        attn_metadata.context_lens_tensor,
        torch.tensor(context_lens, dtype=torch.int, device=device))
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    torch.testing.assert_close(
        attn_metadata.seq_lens_tensor[:len(seq_lens)],
        torch.tensor(seq_lens, dtype=torch.int, device=device))

    # block table's first index corresponds to each batch, meaning in
    # decoding it is each token.
    assert attn_metadata.block_tables.shape[0] == len(input_tokens)
    # Block table's second dim corresponds to each token's block number.
    # It is expected to equal model_runner.get_max_block_per_batch().
    assert attn_metadata.block_tables.shape[1] == (
        model_runner.get_max_block_per_batch())
    assert attn_metadata.use_cuda_graph is True

    assert len(input_tokens) == expected_bs
    assert len(input_positions) == expected_bs
    # NOTE(review): a bare `torch.allclose(input_tokens, input_positions)`
    # used to sit here. Its result was discarded, so it verified nothing and
    # has been dropped. Presumably tokens and positions differ for decode
    # (every request appends token id 1 while context lengths vary) - confirm
    # before reinstating it as a real assertion.

    # Verify Sampling
    expected_selected_token_indices = list(range(len(context_lens)))
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        # query lens is all 1 for decode.
        query_lens=[1] * len(context_lens),
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
258
259
260
261


def test_empty_seq_group():
    """Verify prepare prompt and decode returns empty output."""
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
    )
    no_requests: List[SequenceGroupMetadata] = []

    # First preparation of an empty batch: no tensors and no metadata.
    first = model_runner._prepare_model_input_tensors(no_requests)
    assert first.input_tokens is None
    assert first.input_positions is None
    assert first.attn_metadata is None

    # Preparing the same empty batch again additionally yields no seq_lens.
    second = model_runner._prepare_model_input_tensors(no_requests)
    assert second.input_tokens is None
    assert second.input_positions is None
    assert second.attn_metadata is None
    assert second.seq_lens is None
292
293


294
295
296
297
298
299
300
@pytest.fixture
def distributed_init():
    """Set up a single-process distributed environment for a test.

    Initializes the distributed environment with world size 1 / rank 0 over a
    loopback TCP endpoint on a port obtained from ``get_open_port()``, then
    calls ``ensure_model_parallel_initialized(1, 1)``.
    """
    init_method = f"tcp://127.0.0.1:{get_open_port()}"
    init_distributed_environment(world_size=1,
                                 rank=0,
                                 distributed_init_method=init_method,
                                 local_rank=0)
    ensure_model_parallel_initialized(1, 1)
302
303


304
305
306
def _assert_same_attrs(expected_obj, actual_obj) -> None:
    """Assert two objects expose the same attribute names and equal values.

    Attributes are read with ``vars()``. Tensor values are compared with
    ``torch.testing.assert_close``; everything else is compared with ``==``.
    """
    expected_attrs = vars(expected_obj)
    actual_attrs = vars(actual_obj)
    assert expected_attrs.keys() == actual_attrs.keys()
    for name, expected in expected_attrs.items():
        actual = actual_attrs[name]
        if isinstance(expected, torch.Tensor):
            # `==` on tensors is element-wise and cannot be used in a bare
            # assert, hence the dedicated comparison.
            torch.testing.assert_close(
                actual,
                expected,
                msg=lambda detail, name=name: f"{name}: {detail}")
        else:
            assert actual == expected, name


@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
    """Check model inputs for a batch mixing prefill and decode requests.

    The first ``batch_size // 2`` requests are prompts and the remainder are
    decodes. The counts reported by the attention metadata returned from
    ``ModelRunner.prepare_model_input`` are verified, and its prefill/decode
    metadata must match what ``_prepare_model_input_tensors`` produces for
    the same batch.
    """
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=enforce_eager,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=True,
    )

    # Add prefill requests.
    seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    prefill_batch_size = batch_size // 2
    decode_batch_size = batch_size - prefill_batch_size
    for i in range(prefill_batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                      range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Add decode requests
    for i in range(prefill_batch_size, batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
        seq_data = SequenceData(prompt_toks)
        seq_data.append_token_id(1, 0)
        seq_data.update_num_computed_tokens(context_len)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata

    prefill_meta_actual = attn_metadata.prefill_metadata
    decode_meta_actual = attn_metadata.decode_metadata

    assert len(attn_metadata.slot_mapping) == len(input_tokens)
    assert len(input_positions) == len(input_tokens)
    assert attn_metadata.num_prefills == prefill_batch_size
    assert attn_metadata.num_decode_tokens == decode_batch_size
    assert attn_metadata.num_prefill_tokens == sum(seq_lens)

    # Verify attn metadata is consistent. We don't need to test individual
    # values here because they are tested above.
    attn_metadata = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list).attn_metadata

    # Compare attribute *values*. Zipping the two vars() dicts directly, as
    # was done before, iterates attribute names only and checks no value.
    _assert_same_attrs(attn_metadata.prefill_metadata, prefill_meta_actual)
    _assert_same_attrs(attn_metadata.decode_metadata, decode_meta_actual)