manager.py 7.01 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Literal

from vllm.v1.kv_offload.abstract import (
    LoadStoreSpec,
    OffloadingEvent,
    OffloadingManager,
10
    OffloadKey,
11
    PrepareStoreOutput,
12
    ReqContext,
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
)
from vllm.v1.kv_offload.cpu.policies.abstract import BlockStatus, CachePolicy
from vllm.v1.kv_offload.cpu.policies.arc import ARCCachePolicy
from vllm.v1.kv_offload.cpu.policies.lru import LRUCachePolicy
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec

# Registry mapping each supported ``cache_policy`` name to the CachePolicy
# class that implements it.
_CACHE_POLICIES: dict[str, type[CachePolicy]] = dict(
    lru=LRUCachePolicy,
    arc=ARCCachePolicy,
)


class CPUOffloadingManager(OffloadingManager):
    """
    An OffloadingManager with a pluggable CachePolicy (LRU or ARC).

    The manager owns all shared logic: ref-counting, event emission,
    block pool management, and the prepare_store/complete_store skeletons.
    Policy-specific block organization and eviction decisions are delegated
    to the CachePolicy implementation.

    Block IDs come from ``range(num_blocks)``: never-used IDs are handed out
    first (in increasing order); once those run out, IDs released by eviction
    or by failed stores are recycled from a free list.
    """

    def __init__(
        self,
        num_blocks: int,
        cache_policy: Literal["lru", "arc"] = "lru",
        enable_events: bool = False,
    ):
        """
        Args:
            num_blocks: size of the block pool; also passed to the policy as
                its ``cache_capacity``.
            cache_policy: name of the policy to use (a key of
                ``_CACHE_POLICIES``).
            enable_events: when True, store/eviction events are recorded and
                can be drained with ``take_events()``.

        Raises:
            ValueError: if ``cache_policy`` is not a supported policy name.
        """
        self.medium: str = CPULoadStoreSpec.medium()
        self._num_blocks: int = num_blocks
        # Count of block IDs ever handed out; IDs at or above it are fresh.
        self._num_allocated_blocks: int = 0
        # Previously allocated block IDs that have since been released.
        self._free_list: list[int] = []
        self.events: list[OffloadingEvent] | None = [] if enable_events else None
        policy_cls = _CACHE_POLICIES.get(cache_policy)
        if policy_cls is None:
            raise ValueError(
                f"Unknown cache policy: {cache_policy!r}. "
                f"Supported: {list(_CACHE_POLICIES)}"
            )
        self._policy: CachePolicy = policy_cls(cache_capacity=num_blocks)

    # --- block pool ---

    def _get_num_free_blocks(self) -> int:
        """Return recycled IDs plus never-allocated IDs still available."""
        return len(self._free_list) + self._num_blocks - self._num_allocated_blocks

    def _allocate_blocks(self, keys: list[OffloadKey]) -> list[BlockStatus]:
        """
        Allocate one new BlockStatus per entry of ``keys``.

        Fresh (never-used) block IDs are consumed first; the remainder is
        popped from the free list. The caller must have ensured that enough
        free blocks exist (see ``_get_num_free_blocks``).
        """
        num_fresh = min(len(keys), self._num_blocks - self._num_allocated_blocks)
        num_reused = len(keys) - num_fresh
        assert len(self._free_list) >= num_reused

        # allocate fresh blocks
        blocks: list[BlockStatus] = []
        for _ in range(num_fresh):
            blocks.append(BlockStatus(self._num_allocated_blocks))
            self._num_allocated_blocks += 1

        # allocate reused blocks
        for _ in range(num_reused):
            blocks.append(BlockStatus(self._free_list.pop()))
        return blocks

    def _free_block(self, block: BlockStatus) -> None:
        """Return ``block``'s ID to the free list for later reuse."""
        self._free_list.append(block.block_id)

    def _get_load_store_spec(
        self,
        keys: Iterable[OffloadKey],
        blocks: Iterable[BlockStatus],
    ) -> CPULoadStoreSpec:
        """Build a CPU load/store spec from the block IDs (``keys`` is unused)."""
        return CPULoadStoreSpec([block.block_id for block in blocks])

    # --- OffloadingManager interface ---

    def lookup(
        self,
        keys: Iterable[OffloadKey],
        req_context: ReqContext,
    ) -> int | None:
        """
        Return how many leading ``keys`` are cached and ready.

        Counting stops at the first key that is missing from the policy or
        whose block is not ready yet.
        """
        hit_count = 0
        for key in keys:
            block = self._policy.get(key)
            if block is None or not block.is_ready:
                break
            hit_count += 1
        return hit_count

    def prepare_load(
        self,
        keys: Iterable[OffloadKey],
        req_context: ReqContext,
    ) -> LoadStoreSpec:
        """
        Take a reference on each key's block and return a spec to read them.

        Every key must be cached and ready (checked with asserts). The extra
        references are released by ``complete_load()``.
        """
        # Materialize once: ``keys`` may be a one-shot iterator, and it is
        # needed both for the loop and for building the spec.
        keys_list = list(keys)
        blocks: list[BlockStatus] = []
        for key in keys_list:
            block = self._policy.get(key)
            assert block is not None, f"Block {key!r} not found in cache"
            assert block.is_ready, f"Block {key!r} is not ready for reading"
            block.ref_cnt += 1
            blocks.append(block)
        return self._get_load_store_spec(keys_list, blocks)

    def touch(self, keys: Iterable[OffloadKey]) -> None:
        """Forward an access notification for ``keys`` to the cache policy."""
        self._policy.touch(keys)

    def complete_load(self, keys: Iterable[OffloadKey]) -> None:
        """Release the references taken by ``prepare_load()`` for ``keys``."""
        for key in keys:
            block = self._policy.get(key)
            assert block is not None, f"Block {key!r} not found"
            assert block.ref_cnt > 0, f"Block {key!r} ref_cnt is already 0"
            block.ref_cnt -= 1

    def prepare_store(
        self,
        keys: Iterable[OffloadKey],
        req_context: ReqContext,
    ) -> PrepareStoreOutput | None:
        """
        Reserve blocks for the keys that are not cached yet.

        Keys already known to the policy are skipped. If the pool lacks free
        blocks, the policy is asked to evict the shortfall; none of the keys
        passed to this call are eligible for eviction.

        Returns:
            A PrepareStoreOutput holding the keys that must be written, the
            store spec for their newly allocated blocks, and the keys evicted
            to make room. Returns None if the policy could not evict enough
            blocks; in that case this manager allocates nothing.
        """
        keys_list = list(keys)

        # Filter out blocks that are already stored. dict.fromkeys drops
        # repeated keys (keeping first-seen order). Without it, a duplicated
        # key would get one block per occurrence, but the policy can track at
        # most one block under that key, so the extras would leak from the
        # pool and force unnecessary evictions.
        keys_to_store = [
            k for k in dict.fromkeys(keys_list) if self._policy.get(k) is None
        ]

        if not keys_to_store:
            return PrepareStoreOutput(
                keys_to_store=[],
                store_spec=self._get_load_store_spec([], []),
                evicted_keys=[],
            )

        num_blocks_to_evict = len(keys_to_store) - self._get_num_free_blocks()

        to_evict: list[OffloadKey] = []
        if num_blocks_to_evict > 0:
            # Blocks from the original input are excluded from eviction candidates:
            # a block that was already stored must remain in the cache after this call.
            protected = set(keys_list)
            evicted = self._policy.evict(num_blocks_to_evict, protected)
            if evicted is None:
                return None
            for key, block in evicted:
                self._free_block(block)
                to_evict.append(key)

        if to_evict and self.events is not None:
            self.events.append(
                OffloadingEvent(
                    keys=to_evict,
                    medium=self.medium,
                    removed=True,
                )
            )

        blocks = self._allocate_blocks(keys_to_store)
        assert len(blocks) == len(keys_to_store), (
            "Block pool did not allocate the expected number of blocks"
        )

        for key, block in zip(keys_to_store, blocks):
            self._policy.insert(key, block)

        # build store specs for allocated blocks
        store_spec = self._get_load_store_spec(keys_to_store, blocks)

        return PrepareStoreOutput(
            keys_to_store=keys_to_store,
            store_spec=store_spec,
            evicted_keys=to_evict,
        )

    def complete_store(self, keys: Iterable[OffloadKey], success: bool = True) -> None:
        """
        Finish a store started by ``prepare_store()``.

        On success, each listed block that is still not ready gets
        ``ref_cnt = 0`` (stored, no readers), and a single store event is
        emitted for those keys. On failure, such blocks are removed from the
        policy and their IDs are returned to the free list, with no event
        emitted. Keys that are unknown, or whose blocks are already ready,
        are ignored in both cases.
        """
        stored_keys: list[OffloadKey] = []

        if success:
            for key in keys:
                block = self._policy.get(key)
                if block is not None and not block.is_ready:
                    block.ref_cnt = 0
                    stored_keys.append(key)
        else:
            for key in keys:
                block = self._policy.get(key)
                if block is not None and not block.is_ready:
                    self._policy.remove(key)
                    self._free_block(block)

        if stored_keys and self.events is not None:
            self.events.append(
                OffloadingEvent(
                    keys=stored_keys,
                    medium=self.medium,
                    removed=False,
                )
            )

    def take_events(self) -> Iterable[OffloadingEvent]:
        """
        Yield the recorded events, then clear the buffer.

        This is a generator: the buffer is cleared only after it has been
        fully consumed. When events are disabled, it yields nothing.
        """
        if self.events is not None:
            yield from self.events
            self.events.clear()