test_model_runner.py 17.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

youkaichao's avatar
youkaichao committed
4
import pytest
5
import torch
6
import os
7

8
9
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
10
from vllm.engine.arg_utils import EngineArgs
11
from vllm.model_executor.sampling_metadata import SamplingMetadata
12
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
13
from vllm.utils import get_open_port
zhuwenwen's avatar
zhuwenwen committed
14

15
from vllm.worker.model_runner import ModelRunner
16
from ..utils import models_path_prefix
17
18


19
20
21
22
def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
    """Build a driver-worker ``ModelRunner`` for ``model``.

    ``model`` plus any extra positional/keyword arguments are forwarded to
    ``EngineArgs``; the engine config it creates is handed to ``ModelRunner``
    as ``vllm_config``.
    """
    vllm_config = EngineArgs(model, *args, **kwargs).create_engine_config()
    return ModelRunner(vllm_config=vllm_config, is_driver_worker=True)


29
30
def test_deepseek_mla_attn_backend_module():
    """The DeepSeek-Coder-V2-Lite runner must pick ``TritonMLABackend``."""
    model_path = os.path.join(models_path_prefix,
                              "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")
    runner = _create_model_runner(
        model_path,
        trust_remote_code=True,
        enable_chunked_prefill=False,
    )
    backend_name = runner.attn_backend.__name__
    assert backend_name == "TritonMLABackend"


38
39
40
41
42
43
44
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
    """Check model inputs and attention metadata for a prefill-only batch.

    Builds ``batch_size`` prompt sequence groups (token ids, or prompt
    embeddings when ``use_prompt_embeds`` is set), runs
    ``_prepare_model_input_tensors`` on them and verifies the returned
    tokens/positions/embeddings, the attention metadata, and the token
    indices selected by ``SamplingMetadata.prepare``.
    """
    if use_prompt_embeds:
        # Prompt Embeddings is only currently supported on V0
        monkeypatch.setenv("VLLM_USE_V1", "0")

    model_runner = _create_model_runner(
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enable_prompt_embeds=True,
    )

    seq_lens: list[int] = []
    seq_group_metadata_list: list[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    expected_input_embeds_len = 0
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        if use_prompt_embeds:
            seq_data = SequenceData.from_seqs(
                prompt_token_ids=[0] * seq_len,
                prompt_embeds=torch.rand(seq_len, 10),
            )
            expected_input_embeds_len += seq_len
        else:
            seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len))

        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Cumulative offsets of every sequence in the flattened batch:
    # [0, len_0, len_0 + len_1, ...].
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        start_loc.append(start_idx)
    # For normal prefill the sequence start locations are equivalent to the
    # query start locations, so the expected values coincide.
    seq_start_loc = list(start_loc)
    # The sampler selects the last token of every prompt.
    expected_selected_token_indices = [end - 1 for end in start_loc[1:]]

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    input_embeds = model_input.inputs_embeds
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)

    # Verify input metadata is correct for prompts.
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    torch.testing.assert_close(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0

    # Test subquery start locs.
    torch.testing.assert_close(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Test seq start locs.
    torch.testing.assert_close(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
    torch.testing.assert_close(
        attn_metadata.context_lens_tensor,
        torch.zeros(attn_metadata.context_lens_tensor.shape[0],
                    dtype=torch.int,
                    device=device))

    # One empty block-table row is expected per sequence group.
    expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
                            dtype=torch.int32,
                            device=model_runner.device)
    torch.testing.assert_close(attn_metadata.block_tables, expected)
    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    if expected_input_embeds_len == 0:
        # Token ids were range(seq_len), so they must equal the positions.
        torch.testing.assert_close(input_tokens, input_positions)
        assert input_embeds is None
    else:
        assert len(input_embeds) == expected_input_embeds_len

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


172
173
174
175
176
177
178
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
    """Check model inputs and attention metadata for a decode-only batch.

    Every sequence group has finished prefill and carries exactly one new
    token. The runner is created with ``enforce_eager=False``, and the test
    expects the prepared batch to be padded to
    ``vllm_config.pad_for_cudagraph(batch_size)`` with CUDA graphs enabled.
    """
    if use_prompt_embeds:
        # Prompt Embeddings is only currently supported on V0
        monkeypatch.setenv("VLLM_USE_V1", "0")

    model_path = os.path.join(models_path_prefix, "facebook/opt-125m")
    model_runner = _create_model_runner(
        model_path,
        seed=0,
        dtype="float16",
        enforce_eager=False,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enable_prompt_embeds=True,
    )

    context_lens: list[int] = []
    seq_group_metadata_list: list[SequenceGroupMetadata] = []
    # Assume each seq group finishes prefill.
    for idx in range(batch_size):
        # make sure all tokens fit into one block
        ctx_len = idx % (model_runner.block_size - 1) + 1
        context_lens.append(ctx_len)

        if not use_prompt_embeds:
            output_embed = None
            seq_data = SequenceData.from_seqs(
                prompt_token_ids=range(ctx_len))
        else:
            seq_data = SequenceData.from_seqs(
                prompt_token_ids=[0] * ctx_len,
                prompt_embeds=torch.rand(ctx_len, 10),
            )
            output_embed = torch.rand(10)

        seq_data.update_num_computed_tokens(ctx_len)
        # Append one token ID since prefill is finished.
        seq_data.append_token_id(1, 0, output_embed)

        group = SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert group.token_chunk_size == 1
        seq_group_metadata_list.append(group)

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    attn_metadata = model_input.attn_metadata
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    input_embeds = model_input.inputs_embeds
    slot_mapping = attn_metadata.slot_mapping

    assert len(slot_mapping) == len(input_tokens)

    num_groups = len(seq_group_metadata_list)
    expected_bs = model_runner.vllm_config.pad_for_cudagraph(num_groups)

    # Verify input metadata is correct for decodes.
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_prefill_tokens == 0
    # Each real sequence is its context plus the one appended token;
    # seq_lens are then padded to expected_bs with length-1 entries.
    seq_lens = [ctx_len + 1 for ctx_len in context_lens]
    seq_lens.extend([1] * (expected_bs - len(seq_lens)))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.num_decode_tokens == len(seq_lens)

    # decode has only 1 token for query, so the query offsets are 0..N.
    start_loc = list(range(len(context_lens) + 1))
    torch.testing.assert_close(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Cumulative offsets over the (padded) sequence lengths.
    running_total = 0
    seq_start_loc = [running_total]
    for padded_len in seq_lens:
        running_total += padded_len
        seq_start_loc.append(running_total)
    torch.testing.assert_close(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

    torch.testing.assert_close(
        attn_metadata.context_lens_tensor,
        torch.tensor(context_lens, dtype=torch.int, device=device))
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    torch.testing.assert_close(
        attn_metadata.seq_lens_tensor[:len(seq_lens)],
        torch.tensor(seq_lens, dtype=torch.int, device=device))

    # block table's first index corresponds to each batch, meaning in
    # decoding it is each token.
    assert attn_metadata.block_tables.shape[0] == len(input_tokens)
    # Block table's second dim corresponds to each token's block number.
    # It is expected to match the runner's max block count per batch.
    max_blocks = model_runner.get_max_block_per_batch()
    assert attn_metadata.block_tables.shape[1] == max_blocks
    assert attn_metadata.use_cuda_graph is True

    assert len(input_tokens) == expected_bs
    assert len(input_positions) == expected_bs
    if not use_prompt_embeds:
        assert input_embeds is None
    else:
        expected_input_embeds_length = start_loc[-1]
        assert len(input_embeds) == expected_input_embeds_length
        assert expected_input_embeds_length <= expected_bs

    # Verify Sampling
    # Exactly one token per decoding sequence is selected, at its own index.
    expected_selected_token_indices = list(range(len(context_lens)))
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        # query lens is all 1 for decode.
        query_lens=[1] * len(context_lens),
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
303
304
305
306


def test_empty_seq_group():
    """Verify preparing model inputs for an empty batch returns empty output.

    With no sequence groups, every field read from the prepared model input
    (tokens, positions, embeddings, attention metadata, sequence lengths) is
    expected to be ``None``.
    """
    model_runner = _create_model_runner(
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        seed=0,
        dtype="float16",
        enforce_eager=False,
    )
    seq_group_metadata_list: list[SequenceGroupMetadata] = []

    # A single call suffices: the same entry point with the same empty list
    # was previously invoked twice with overlapping assertions.
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)

    assert model_input.input_tokens is None
    assert model_input.input_positions is None
    assert model_input.inputs_embeds is None
    assert model_input.attn_metadata is None
    assert model_input.seq_lens is None
339
340


341
342
343
344
345
346
347
@pytest.fixture
def distributed_init():
    """Initialise a single-rank distributed environment for a test.

    Uses a TCP init method on 127.0.0.1 with a port from ``get_open_port()``,
    then calls ``ensure_model_parallel_initialized(1, 1)``.
    """
    init_method = f"tcp://127.0.0.1:{get_open_port()}"
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=init_method,
        local_rank=0,
    )
    ensure_model_parallel_initialized(1, 1)
349
350


351
def _assert_attn_metadata_equal(expected, actual) -> None:
    """Assert two attention-metadata objects carry the same attribute values.

    Tensors are compared with ``torch.equal``; ``None``, scalars, strings and
    flat lists/tuples of those are compared with ``==``. Attributes of any
    other type are skipped because their equality semantics are not known
    here.
    """
    plain_types = (bool, int, float, str, type(None))
    expected_attrs = vars(expected)
    actual_attrs = vars(actual)
    assert expected_attrs.keys() == actual_attrs.keys()
    for name, expected_value in expected_attrs.items():
        actual_value = actual_attrs[name]
        if isinstance(expected_value, torch.Tensor):
            assert isinstance(actual_value, torch.Tensor), name
            assert torch.equal(expected_value, actual_value), name
        elif isinstance(expected_value, plain_types):
            assert expected_value == actual_value, name
        elif (isinstance(expected_value, (list, tuple))
              and all(isinstance(item, plain_types)
                      for item in expected_value)):
            assert expected_value == actual_value, name


@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize('use_prompt_embeds', [True, False])
def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
                        distributed_init, monkeypatch):
    """Check a chunked-prefill batch that mixes prefill and decode requests.

    The first ``batch_size // 2`` sequence groups are prompts and the rest
    are decodes with one appended token. The test verifies the prefill and
    decode counts reported in the attention metadata, the input-embedding
    length, and that ``prepare_model_input`` and
    ``_prepare_model_input_tensors`` produce matching prefill/decode
    metadata.
    """
    if use_prompt_embeds:
        # Prompt Embeddings is only currently supported on V0
        monkeypatch.setenv("VLLM_USE_V1", "0")

    model_runner = _create_model_runner(
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        seed=0,
        dtype="float16",
        enforce_eager=enforce_eager,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=True,
        enable_prompt_embeds=True,
    )

    # Add prefill requests.
    seq_lens: list[int] = []
    seq_group_metadata_list: list[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    prefill_batch_size = batch_size // 2
    decode_batch_size = batch_size - prefill_batch_size
    expected_input_embeds_len = 0
    for i in range(prefill_batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        if use_prompt_embeds:
            seq_data = SequenceData.from_seqs(
                prompt_token_ids=[0] * seq_len,
                prompt_embeds=torch.rand(seq_len, 10),
            )
            expected_input_embeds_len += seq_len
        else:
            seq_data = SequenceData.from_seqs(
                prompt_token_ids=range(seq_len), )
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Add decode requests
    for i in range(prefill_batch_size, batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        if use_prompt_embeds:
            seq_data = SequenceData.from_seqs(
                prompt_token_ids=[0] * context_len,
                prompt_embeds=torch.rand(context_len, 10),
            )
            output_embed = torch.rand(10)
            # This also iterates the expected input_embeds, because the model
            # needs both the input and output embeddings passed into together
            expected_input_embeds_len += 1
        else:
            seq_data = SequenceData.from_seqs(
                prompt_token_ids=range(context_len), )
            output_embed = None
        assert len(seq_data.prompt_token_ids) == context_len
        seq_data.append_token_id(1, 0, output_embed)
        seq_data.update_num_computed_tokens(context_len)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner.prepare_model_input(seq_group_metadata_list)

    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    input_embeds = model_input.inputs_embeds
    attn_metadata = model_input.attn_metadata

    prefill_meta_actual = attn_metadata.prefill_metadata
    decode_meta_actual = attn_metadata.decode_metadata

    assert len(attn_metadata.slot_mapping) == len(input_tokens)
    assert len(input_positions) == len(input_tokens)
    assert attn_metadata.num_prefills == prefill_batch_size
    assert attn_metadata.num_decode_tokens == decode_batch_size
    assert attn_metadata.num_prefill_tokens == sum(seq_lens)
    if expected_input_embeds_len == 0:
        assert input_embeds is None
    else:
        assert len(input_embeds) == expected_input_embeds_len

    # Verify attn metadata is consistent between the two entry points.
    # NOTE: iterating ``zip(vars(a), vars(b))`` yields attribute *names*, so
    # indexing ``[1]`` would only compare the second character of each name;
    # compare the attribute values instead.
    attn_metadata = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list).attn_metadata

    _assert_attn_metadata_equal(attn_metadata.prefill_metadata,
                                prefill_meta_actual)
    _assert_attn_metadata_equal(attn_metadata.decode_metadata,
                                decode_meta_actual)