# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections import defaultdict
5
from collections.abc import Callable
6
7
from dataclasses import dataclass
from itertools import chain, count
8
from typing import Any
Robert Shaw's avatar
Robert Shaw committed
9
10
11
12

import torch

from vllm import SamplingParams
13
14
15
16
17
18
19
20
21
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    KVTransferConfig,
    ModelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
22
23
24
25
26
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorMetadata,
    KVConnectorRole,
)
27
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
28
29
    SharedStorageConnector,
)
30
from vllm.utils.hashing import sha256
31
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
32
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
33
from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
34
35
36
37
38
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
)
39
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
Robert Shaw's avatar
Robert Shaw committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

# EOS token id given to every request built by create_request; also the token
# sampled by create_model_runner_output when use_eos=True.
EOS_TOKEN_ID = 50_256


def assert_scheduler_empty(scheduler: Scheduler):
    """Confirm the scheduler is "empty" - i.e. no leaks.

    Asserts that no request bookkeeping remains in the scheduler, the encoder
    cache manager, or any KV cache group's manager, that all but one KV cache
    block is back in the free queue, and that every block's ref count is 0.

    Raises:
        AssertionError: if any leftover state is found.
    """
    # Scheduler Metadata.
    assert len(scheduler.requests) == 0
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.finished_req_ids) == 0
    assert len(scheduler.finished_recving_kv_req_ids) == 0

    # EncoderCacheManager.
    assert len(scheduler.encoder_cache_manager.freed) == 0
    assert len(scheduler.encoder_cache_manager.cached) == 0

    # KVCache Manager: check every KV cache group's manager, not only the
    # first one, so leaks in any group are caught.
    kv_cache_manager = scheduler.kv_cache_manager
    for manager in kv_cache_manager.coordinator.single_type_managers:
        assert len(manager.req_to_blocks) == 0
        assert len(manager.num_cached_block) == 0

    block_pool = kv_cache_manager.block_pool
    # All blocks but one are expected in the free queue.
    # NOTE(review): the excluded block is presumably one the pool keeps
    # reserved for itself — confirm against BlockPool.
    num_free_blocks = block_pool.free_block_queue.num_free_blocks
    assert num_free_blocks == (block_pool.num_gpu_blocks - 1)

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache.
    for block in block_pool.blocks:
        assert block.ref_cnt == 0


def create_vllm_config(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 64,
    block_size: int = 16,
    max_model_len: int = 10000,
    enable_chunked_prefill: bool = True,
    enable_permute_local_kv: bool = False,
    kv_connector_extra_config: dict[str, Any] | None = None,
    dtype: str = "float16",
    cache_dtype: str = "auto",
    hf_overrides: dict[str, Any] | None = None,
    kv_connector: str = "NixlConnector",
) -> VllmConfig:
    """Initialize VllmConfig For Testing.

    Builds a CPU-device config with prefix caching always enabled and a
    ``kv_both`` KV transfer config.

    Args:
        model: Model name passed to ``ModelConfig`` (loaded with
            ``trust_remote_code=True`` and ``seed=42``).
        max_num_seqs: Scheduler limit on concurrent sequences.
        max_num_batched_tokens: Scheduler token budget per step.
        block_size: KV cache block size.
        max_model_len: Maximum model length given to the scheduler config.
        enable_chunked_prefill: Forwarded to ``SchedulerConfig``.
        enable_permute_local_kv: Forwarded to ``KVTransferConfig``.
        kv_connector_extra_config: Extra connector settings (``{}`` if None).
        dtype: Model dtype.
        cache_dtype: KV cache dtype.
        hf_overrides: HF config overrides (``{}`` if None).
        kv_connector: Name of the KV connector to use. Defaults to
            ``"NixlConnector"``; the connectors registered in this module
            (e.g. ``"MockKVConnector"``) can be selected here too.

    Returns:
        The assembled ``VllmConfig``.
    """
    model_config = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype=dtype,
        seed=42,
        hf_overrides=hf_overrides or {},
    )
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        enable_chunked_prefill=enable_chunked_prefill,
        # Follow the model's own encoder/decoder classification.
        is_encoder_decoder=model_config.is_encoder_decoder,
    )
    # Cache config; prefix caching (APC) is always enabled for these tests.
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype=cache_dtype,
        enable_prefix_caching=True,
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector=kv_connector,
        kv_role="kv_both",
        enable_permute_local_kv=enable_permute_local_kv,
        kv_connector_extra_config=kv_connector_extra_config or {},
    )
    return VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        device_config=DeviceConfig("cpu"),
    )
Robert Shaw's avatar
Robert Shaw committed
134
135
136
137
138
139
140
141
142
143


def create_scheduler(
    vllm_config: VllmConfig,
    num_blocks: int = 10000,
) -> Scheduler:
    """Build a Scheduler for tests over one full-attention KV cache group.

    Side effect: ``vllm_config.cache_config.num_gpu_blocks`` is overwritten
    with ``num_blocks`` so it agrees with the KV cache config built here.

    Args:
        vllm_config: Config the scheduler is built from.
        num_blocks: Number of KV cache blocks; the default is large enough to
            hold all requests in typical tests.
    """
    cache_config = vllm_config.cache_config
    block_size = cache_config.block_size

    # A single group holding one dummy layer named "layer".
    attention_spec = FullAttentionSpec(block_size, 1, 1, torch.float32, False)
    group = KVCacheGroupSpec(["layer"], attention_spec)
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[group],
    )

    cache_config.num_gpu_blocks = num_blocks

    structured_output_manager = StructuredOutputManager(vllm_config)
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=structured_output_manager,
        block_size=block_size,
    )


161
_request_count = count(1)
162
163
164
_none_hash_initialized = False


165
def create_request(
    request_id: int | None = None,
    num_tokens: int = 10,
    common_prefix_len=0,
    max_tokens: int = 16,
    do_remote_decode: bool = False,
    do_remote_prefill: bool = False,
    num_remote_blocks: int = 3,
    block_size: int = 16,
    hash_fn: Callable = sha256,
) -> Request:
    """Build a dummy ``Request`` for testing.

    The prompt has ``num_tokens`` tokens: ``common_prefix_len`` copies of
    token ``1`` followed by ``i * request_id`` for each remaining position.

    Args:
        request_id: Numeric id; the next value of a module-level counter
            when None. The request's string id is ``f"id-{request_id}"``.
        num_tokens: Total prompt length.
        common_prefix_len: Leading tokens shared across requests (all ``1``).
        max_tokens: Output-token limit; forced to 1 when ``do_remote_decode``.
        do_remote_decode: Attach remote-decode KV transfer params.
        do_remote_prefill: Attach remote-prefill KV transfer params
            (mutually exclusive with ``do_remote_decode``).
        num_remote_blocks: Size of ``remote_block_ids`` for remote prefill.
        block_size: Block size for the request's block hasher.
        hash_fn: Hash function for the block hasher. The none-hash is only
            initialized on the first call in the process, with that call's
            ``hash_fn``.
    """
    global _none_hash_initialized

    assert num_tokens >= common_prefix_len >= 0

    if request_id is None:
        request_id = next(_request_count)

    if not _none_hash_initialized:
        init_none_hash(hash_fn)
        _none_hash_initialized = True

    kv_transfer_params: dict[str, Any] | None
    if do_remote_decode:
        assert not do_remote_prefill
        kv_transfer_params = {"do_remote_prefill": False, "do_remote_decode": True}
    elif do_remote_prefill:
        kv_transfer_params = {
            "do_remote_prefill": True,
            "do_remote_decode": False,
            "remote_engine_id": "my-engine-id",
            "remote_request_id": f"prefill-{request_id}",
            "remote_block_ids": list(range(num_remote_blocks)),
            "remote_host": "my-host",
            "remote_port": 1234,
        }
    else:
        kv_transfer_params = None

    # Requests flagged for remote decode only generate a single token here.
    sampling_params = SamplingParams(max_tokens=1 if do_remote_decode else max_tokens)

    tail_len = num_tokens - common_prefix_len
    prompt_token_ids = [1] * common_prefix_len + [
        position * request_id for position in range(tail_len)
    ]

    block_hasher = get_request_block_hasher(block_size, hash_fn)
    request = Request(
        request_id=f"id-{request_id}",
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        pooling_params=None,
        mm_features=None,
        eos_token_id=EOS_TOKEN_ID,
        block_hasher=block_hasher,
    )
    request.kv_transfer_params = kv_transfer_params
    return request


def create_model_runner_output(
    reqs: list[Request],
    finished_sending: set[str] | None = None,
    finished_recving: set[str] | None = None,
    invalid_block_ids: set[int] | None = None,
    use_eos: bool = False,
    token_id: int = 0,
) -> ModelRunnerOutput:
    """Build a dummy ``ModelRunnerOutput`` that samples one token per request.

    Every request receives the same single token: ``EOS_TOKEN_ID`` when
    ``use_eos`` is set, otherwise ``token_id``. A ``KVConnectorOutput`` is
    attached only if at least one of ``finished_sending``,
    ``finished_recving`` or ``invalid_block_ids`` is provided.
    """
    req_ids = [r.request_id for r in reqs]
    req_id_to_index = dict(zip(req_ids, range(len(req_ids))))

    next_token = EOS_TOKEN_ID if use_eos else token_id
    # One fresh single-element list per request (no shared inner lists).
    sampled_token_ids = [[next_token] for _ in req_ids]

    kv_connector_output = None
    connector_fields = (finished_sending, finished_recving, invalid_block_ids)
    if any(field is not None for field in connector_fields):
        kv_connector_output = KVConnectorOutput(
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            invalid_block_ids=invalid_block_ids or set(),
        )

    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=None,
        kv_connector_output=kv_connector_output,
    )
265
266
267


class TestSharedStorageConnector(SharedStorageConnector):
    """SharedStorageConnector test double that records every method call.

    Attribute lookups are forwarded to a wrapped ``SharedStorageConnector``.
    Callable attributes come back wrapped: each invocation bumps
    ``call_record[name]`` and appends one line (method name plus selected
    arguments) to a per-connector event file that the main test process can
    read back.
    """

    def __init__(self, config: VllmConfig, role, kv_cache_config):
        extra_config = config.kv_transfer_config.kv_connector_extra_config
        self.name = extra_config["name"]
        self._connector = SharedStorageConnector(config, role)
        self.call_record: dict[str, int] = defaultdict(int)
        # One event file per connector name and role. ``self.role`` resolves
        # through __getattribute__ to the wrapped connector's role.
        self._event_file = (
            f"{tempfile.gettempdir()}/connector_{self.name}-{self.role.name}_events.log"
        )
        # Create/truncate the file so each connector starts with an empty log.
        with open(self._event_file, "w"):
            pass

    def __getattribute__(self, name):
        # These names must be resolved on this object itself; forwarding them
        # to the wrapped connector would recurse back into this method.
        if name in {
            "_connector",
            "call_record",
            "name",
            "_event_file",
            "__class__",
            "__dict__",
            "__getattribute__",
            "__init__",
        }:
            return object.__getattribute__(self, name)

        inner = self._connector
        if not hasattr(inner, name):
            return object.__getattribute__(self, name)

        target = getattr(inner, name)
        if not callable(target):
            return target

        def recorded_call(*args, **kwargs):
            self.call_record[name] += 1

            # Only int and KVCacheBlocks positional args are written out.
            fields = [name]
            for value in args:
                if isinstance(value, int):
                    fields.append(str(value))
                elif isinstance(value, KVCacheBlocks):
                    fields.append(f"num_blocks={[len(b) for b in value.blocks]}")

            # Logging is best-effort: a failed write is reported, not raised.
            try:
                with open(self._event_file, "a") as log_file:
                    log_file.write(" ".join(fields) + "\n")
            except Exception as e:
                print(f"[ERROR] Could not log event {name} for {self.name}: {e}")

            return target(*args, **kwargs)

        return recorded_call


324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
@dataclass(frozen=True)
class MockKVConfig:
    """Immutable settings that drive MockKVConnector's scheduler answers."""

    # Token count returned by get_num_new_matched_tokens.
    matched_tokens: int = 0
    # Async flag returned alongside matched_tokens.
    is_async: bool = False


class MockKVConnectorMetadata(KVConnectorMetadata):
    """Metadata produced by MockKVConnector.build_connector_meta."""

    def __init__(self):
        # Scheduler tests inspect this list; each entry is a {"req_id": ...}
        # dict appended by MockKVConnector.build_connector_meta.
        self.requests: list = list()


class MockKVConnector(KVConnectorBase_V1):
    """Mock KV connector for scheduler tests, supporting both sync and async mode.

    Behavior comes from ``kv_connector_extra_config``: ``matched_tokens`` and
    ``is_async`` are returned unchanged by every
    ``get_num_new_matched_tokens`` call. Keys that are absent fall back to the
    ``MockKVConfig`` field defaults. All worker-side hooks are no-ops.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)
        extra_config = self._kv_transfer_config.kv_connector_extra_config
        # Use the dataclass defaults for missing keys instead of raising
        # KeyError, so MockKVConfig's declared defaults actually apply.
        defaults = MockKVConfig()
        self.config = MockKVConfig(
            matched_tokens=extra_config.get("matched_tokens", defaults.matched_tokens),
            is_async=extra_config.get("is_async", defaults.is_async),
        )

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """Return the configured (matched_tokens, is_async) for any request."""
        return (self.config.matched_tokens, self.config.is_async)

    def update_state_after_alloc(
        self,
        request: Request,
        blocks: KVCacheBlocks,
        num_external_tokens: int,
    ):
        pass

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Record one entry per newly scheduled or resumed cached request."""
        metadata = MockKVConnectorMetadata()
        cached_reqs = scheduler_output.scheduled_cached_reqs
        for req_id in chain(
            (req.req_id for req in scheduler_output.scheduled_new_reqs),
            # Only cached requests that are being resumed are reported.
            (
                req_id
                for req_id in cached_reqs.req_ids
                if req_id in cached_reqs.resumed_req_ids
            ),
        ):
            metadata.requests.append({"req_id": req_id})
        return metadata

    def start_load_kv(self, kv_caches, finished_req_ids):
        pass

    def wait_for_layer_load(self, layer_name):
        pass

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
        pass

    def wait_for_save(self):
        pass


396
397
398
# Register this module's test connectors with KVConnectorFactory under their
# class names (module path + class name), in this order.
for _registered_name, _registered_cls in (
    ("TestSharedStorageConnector", TestSharedStorageConnector),
    ("MockKVConnector", MockKVConnector),
):
    KVConnectorFactory.register_connector(
        _registered_name, __name__, _registered_cls.__name__
    )
# Keep the module namespace free of loop temporaries.
del _registered_name, _registered_cls