test_model_runner.py 7.82 KB
Newer Older
youkaichao's avatar
youkaichao committed
1
import pytest
2
3
import torch

4
from vllm.config import ModelConfig
5
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
youkaichao's avatar
youkaichao committed
6
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
7
8


youkaichao's avatar
youkaichao committed
9
10
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
    """Check the inputs and metadata produced by ``_prepare_prompt``.

    Builds ``batch_size`` prompt-phase sequence groups, runs
    ``ModelRunner._prepare_prompt`` on them and verifies the returned prompt
    lengths, the attention metadata, the input token/position tensors, and
    the token indices selected by ``ModelRunner._prepare_sample``.

    Args:
        batch_size: Number of sequence groups placed in the batch.
    """
    model_runner = ModelRunner(None, None, None, None, None)
    model_runner.set_block_size(16)

    prompt_lens = []
    seq_group_metadata_list = []
    block_tables = {0: [1]}
    for i in range(batch_size):
        # make sure all tokens fit into one block
        prompt_len = i % (model_runner.block_size - 1) + 1
        prompt_lens.append(prompt_len)
        # Token ids are 0..prompt_len-1, i.e. identical to their positions.
        seq_data = SequenceData(list(range(prompt_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Cumulative prompt lengths: entry k is the offset at which prompt k
    # starts in the flattened batch; the last entry is the total token count.
    start_loc = [0]
    for prompt_len in prompt_lens:
        start_loc.append(start_loc[-1] + prompt_len)

    # For normal prefill the seq start locs are expected to be the same as
    # the subquery start locs, but keep a separate expected value so each
    # assertion below checks against its own reference.
    seq_start_loc = list(start_loc)

    # The test expects sampling to pick the last token of every prompt.
    expected_selected_token_indices = [end - 1 for end in start_loc[1:]]

    (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
     _, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
    assert return_prompt_lens == prompt_lens

    # Verify input metadata is correct for prompts.
    device = model_runner.device
    assert attn_metadata.is_prompt is True
    assert torch.allclose(attn_metadata.prompt_lens_tensor,
                          torch.tensor(prompt_lens, device=device))
    assert attn_metadata.prompt_lens == prompt_lens
    assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
    assert attn_metadata.num_generation_tokens == 0
    assert attn_metadata.max_prompt_len == max(prompt_lens)

    # Test subquery start locs.
    assert torch.allclose(
        attn_metadata.subquery_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Test seq start locs. Note that for normal prefill it is
    # equivalent to subquery_start_loc.
    assert torch.allclose(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
    assert attn_metadata.max_context_len is None
    assert torch.allclose(
        attn_metadata.context_lens,
        torch.zeros(attn_metadata.context_lens.shape[0],
                    dtype=torch.int,
                    device=device))

    # The expected block table has one empty row per sequence group.
    expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
                            dtype=torch.int32,
                            device=model_runner.device)
    assert torch.allclose(attn_metadata.block_tables, expected)
    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False
    assert attn_metadata.kv_cache_dtype == "auto"

    # One flattened entry per prompt token; ids equal positions because the
    # prompts were built from range(prompt_len).
    assert input_tokens.shape == (sum(prompt_lens), )
    assert input_positions.shape == (sum(prompt_lens), )
    torch.testing.assert_close(input_tokens, input_positions)

    # Verify sampling.
    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


youkaichao's avatar
youkaichao committed
111
112
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
    """Check the inputs and metadata produced by ``_prepare_decode``.

    Builds ``batch_size`` decode-phase sequence groups with a model config
    that has ``enforce_eager=False``, runs ``ModelRunner._prepare_decode`` on
    them and verifies the attention metadata, the input token/position
    tensors (whose length is expected to be the value returned by
    ``_get_graph_batch_size``), and the token indices selected by
    ``ModelRunner._prepare_sample``.

    Args:
        batch_size: Number of sequence groups placed in the batch.
    """
    model_config = ModelConfig(
        "facebook/opt-125m",
        "facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        download_dir=None,
        load_format="dummy",
        seed=0,
        dtype="float16",
        revision=None,
        enforce_eager=False,
    )
    model_runner = ModelRunner(model_config, None, None, None, None)
    model_runner.set_block_size(16)

    prompt_lens = []
    seq_group_metadata_list = []
    for i in range(batch_size):
        # make sure all tokens fit into one block
        prompt_len = i % (model_runner.block_size - 1) + 1
        prompt_lens.append(prompt_len)
        # Token ids are 0..prompt_len-1, i.e. identical to their positions.
        seq_data = SequenceData(list(range(prompt_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    input_tokens, input_positions, attn_metadata, _, _, _ = (
        model_runner._prepare_decode(seq_group_metadata_list))

    # The decode tensors are expected to be sized to the graph batch size
    # rather than to the raw number of sequence groups.
    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))

    # Verify input metadata is correct for decoding.
    device = model_runner.device
    assert attn_metadata.is_prompt is False
    assert attn_metadata.prompt_lens is None
    assert attn_metadata.num_prompt_tokens == 0
    assert attn_metadata.num_generation_tokens == expected_bs
    assert attn_metadata.max_prompt_len is None
    assert attn_metadata.subquery_start_loc is None
    assert attn_metadata.seq_start_loc is None
    assert attn_metadata.max_context_len == max(prompt_lens)
    # Only the first len(prompt_lens) entries are compared; the remaining
    # entries are presumably padding up to expected_bs -- TODO confirm.
    assert torch.allclose(
        attn_metadata.context_lens[:len(prompt_lens)],
        torch.tensor(prompt_lens, dtype=torch.int, device=device))

    # block table's first index corresponds to each batch, meaning in
    # decoding it is each token.
    assert attn_metadata.block_tables.shape[0] == len(input_tokens)
    # Block table's second dim corresponds to each token's block number.
    # It is padded up to the value of get_max_block_per_batch().
    assert attn_metadata.block_tables.shape[1] == (
        model_runner.get_max_block_per_batch())
    # Cuda graph should be used for decoding here (enforce_eager=False).
    assert attn_metadata.use_cuda_graph is True
    assert attn_metadata.kv_cache_dtype == "auto"

    assert input_tokens.shape == (expected_bs, )
    assert input_positions.shape == (expected_bs, )
    torch.testing.assert_close(input_tokens, input_positions)

    # Verify Sampling: the test expects one selected token per sequence
    # group, at consecutive indices 0..len(prompt_lens)-1.
    expected_selected_token_indices = list(range(len(prompt_lens)))
    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)