utils.py 8.49 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from tests.v1.kv_connector.unit.utils import MockKVConfig
from vllm.config import (
    CacheConfig,
    ECTransferConfig,
    KVTransferConfig,
    ModelConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
)
from vllm.multimodal.inputs import (
    MultiModalFeatureSpec,
    MultiModalKwargsItem,
    PlaceholderRange,
)
from vllm.sampling_params import SamplingParams
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
)
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

# Token id passed as ``eos_token_id`` to every Request built by create_requests.
EOS_TOKEN_ID: int = 50256


37
38
39
40
def mock_kv(matched_tokens: int, is_async: bool):
    """Build a ``MockKVConfig`` suitable for ``create_scheduler(use_kv_connector=...)``.

    Args:
      matched_tokens: forwarded to ``MockKVConfig`` as ``matched_tokens``.
      is_async: forwarded to ``MockKVConfig`` as ``is_async``.

    Returns:
      The freshly constructed ``MockKVConfig``.
    """
    config = MockKVConfig(is_async=is_async, matched_tokens=matched_tokens)
    return config


41
42
43
44
def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_chunked_prefill: bool = True,
    enable_prefix_caching: bool = False,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: None | bool | MockKVConfig = None,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: int | None = None,
    num_speculative_tokens: int | None = None,
    skip_tokenizer_init: bool = False,
    async_scheduling: bool = False,
    use_ec_connector: bool = False,
    ec_role: str | None = None,
) -> Scheduler | AsyncScheduler:
    """Create a scheduler under test.

    Args:
      model: model under test (passed to ``ModelConfig``).
      max_num_seqs: max sequences to schedule.
      max_num_batched_tokens: max num tokens to batch; also used as
        ``max_model_len`` when that argument is None.
      enable_chunked_prefill: forwarded to ``SchedulerConfig``.
      enable_prefix_caching: force APC on/off in ``CacheConfig``
        (default False).
      long_prefill_token_threshold: forwarded to ``SchedulerConfig``.
      disable_chunked_mm_input: forwarded to ``SchedulerConfig``.
      use_kv_connector: a ``MockKVConfig`` configures ``MockKVConnector`` with
        its ``matched_tokens``/``is_async``; any other truthy value configures
        ``ExampleConnector`` with ``shared_storage_path="local_storage"``;
        a falsy value leaves KV transfer unconfigured.
      num_blocks: number of KV cache blocks; also assigned to
        ``cache_config.num_gpu_blocks``.
      block_size: KV cache block size, shared by the cache config, the
        KV cache spec and the scheduler.
      max_model_len: defaults to ``max_num_batched_tokens`` when None.
      num_speculative_tokens: when not None, configure ``ngram`` speculative
        decoding with this many speculative tokens.
      skip_tokenizer_init: forwarded to ``ModelConfig``.
      async_scheduling: forwarded to ``SchedulerConfig``; also selects
        ``AsyncScheduler`` instead of ``Scheduler``.
      use_ec_connector: when True, configure ``ECExampleConnector``.
      ec_role: EC connector role; only used when ``use_ec_connector`` is True.

    Returns:
      {class}`AsyncScheduler` if ``async_scheduling`` is True, otherwise
      {class}`Scheduler`.
    """
    model_config = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
        skip_tokenizer_init=skip_tokenizer_init,
    )
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=enable_chunked_prefill,
        async_scheduling=async_scheduling,
        is_encoder_decoder=model_config.is_encoder_decoder,
    )
    # Cache config, optionally force APC
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=enable_prefix_caching,
    )

    # The MockKVConfig check must come first: a MockKVConfig instance is also
    # truthy and would otherwise fall into the ExampleConnector branch.
    kv_transfer_config: KVTransferConfig | None = None
    if isinstance(use_kv_connector, MockKVConfig):
        kv_transfer_config = KVTransferConfig(
            kv_connector="MockKVConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "matched_tokens": use_kv_connector.matched_tokens,
                "is_async": use_kv_connector.is_async,
            },
        )
    elif use_kv_connector:
        kv_transfer_config = KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        )

    speculative_config: SpeculativeConfig | None = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens
        )

    ec_transfer_config = (
        ECTransferConfig(
            ec_connector="ECExampleConnector",
            ec_role=ec_role,
            ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
        )
        if use_ec_connector
        else None
    )

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
        ec_transfer_config=ec_transfer_config,
    )
    # A single full-attention group over one dummy layer; the tensor list is
    # empty because no real memory is allocated in these tests.
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # Defaults to a large number to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer"],
                FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                ),
            )
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
    return scheduler_cls(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        block_size=block_size,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )


166
167
168
# Module-level guard: create_requests flips this to True right after its first
# call to init_none_hash(sha256), so that call happens only once per import.
_none_hash_initialized: bool = False


169
170
171
def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_hashes_list: list[list[str]] | None = None,
    mm_positions: list[list[PlaceholderRange]] | None = None,
    max_tokens: int = 16,
    stop_token_ids: list[int] | None = None,
    prompt_logprobs: int | None = None,
    same_prompt: bool = False,
    block_size: int = 16,
    req_ids: list[str] | None = None,
) -> list[Request]:
    """Create ``num_requests`` simple requests for scheduler tests.

    Args:
      num_requests: number of requests to create.
      num_tokens: prompt length of every request.
      mm_hashes_list: optional per-request list of multimodal identifiers.
        Requires ``mm_positions`` with the same per-request lengths. Items
        sharing an identifier must have the same placeholder length.
      mm_positions: optional per-request list of placeholder ranges. Each range
        becomes one ``"image"`` ``MultiModalFeatureSpec``. Without
        ``mm_hashes_list``, the identifiers are ``f"hash{i}_{j}"``.
      max_tokens: ``SamplingParams.max_tokens`` for every request.
      stop_token_ids: ``SamplingParams.stop_token_ids`` for every request.
      prompt_logprobs: ``SamplingParams.prompt_logprobs`` for every request.
      same_prompt: if True, every prompt is ``num_tokens`` copies of ``0``.
        Otherwise request ``i`` gets ``num_tokens`` copies of ``i``.
      block_size: block size for the request block hasher.
      req_ids: optional explicit request ids (``num_requests`` entries).
        When empty or None, ids default to ``"0"``, ``"1"``, ...

    Returns:
      The created {class}`Request` objects, in index order.
    """
    global _none_hash_initialized
    if not _none_hash_initialized:
        # Done once per module import; the flag guards repeated calls.
        init_none_hash(sha256)
        _none_hash_initialized = True

    block_hasher = get_request_block_hasher(block_size, sha256)
    # A single SamplingParams instance is shared by every request built here.
    sampling_params = SamplingParams(
        ignore_eos=False,
        max_tokens=max_tokens,
        stop_token_ids=stop_token_ids,
        prompt_logprobs=prompt_logprobs,
    )
    requests: list[Request] = []

    # Placeholder length first seen for each caller-supplied mm identifier.
    # Initialized unconditionally so it is always bound; it is only consulted
    # when mm_hashes_list is provided.
    seen_hashes: dict[str, int] = {}

    if mm_hashes_list is not None:
        # NOTE: allow manual input; some mm items can have the same identifier.
        # Each request must have as many mm_hashes as mm_positions.
        assert mm_positions is not None, (
            "mm_positions must be provided when mm_hashes_list is provided"
        )
        assert len(mm_hashes_list) == len(mm_positions) == num_requests
        assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions]
        # Identical identifiers imply identical encoder output, so items that
        # share an identifier must also share mm_position.length (checked below).

    if req_ids:
        assert len(req_ids) == num_requests
    else:
        req_ids = [f"{i}" for i in range(num_requests)]

    for i in range(num_requests):
        mm_features = []
        for j, position in enumerate(
            mm_positions[i] if mm_positions is not None else []
        ):
            if mm_hashes_list is not None:
                identifier = mm_hashes_list[i][j]

                # Verify that repeated identifiers have identical lengths.
                position_length = position.length
                if identifier in seen_hashes:
                    assert seen_hashes[identifier] == position_length, (
                        f"mm_hash '{identifier}' has inconsistent position lengths: "
                        f"previously {seen_hashes[identifier]}, now {position_length} "
                        f"at request {i}, position {j}"
                    )
                else:
                    seen_hashes[identifier] = position_length
            else:
                # Unique dummy hash for each mm item
                identifier = f"hash{i}_{j}"
            mm_feature = MultiModalFeatureSpec(
                data=MultiModalKwargsItem.dummy("dummy_m"),
                mm_position=position,
                identifier=identifier,
                modality="image",
            )
            mm_features.append(mm_feature)

        prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
        request = Request(
            request_id=req_ids[i],
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            pooling_params=None,
            mm_features=mm_features if mm_features else None,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=block_hasher,
        )
        requests.append(request)
    return requests