# common.py
# SPDX-License-Identifier: Apache-2.0

3
from collections import deque
4
from dataclasses import dataclass
5
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
6
7
8
9
10
11
12

from vllm.core.block.interfaces import Block, BlockAllocator

# Index of a physical block.
BlockId = int
# Number of references currently held on a block.
RefCount = int
class RefCounterProtocol(Protocol):
    """Structural interface for reference counters keyed by block id.

    Every method here only raises ``NotImplementedError``; concrete counters
    must override all three.
    """

    def incr(self, block_id: BlockId) -> RefCount:
        """Increments the count for ``block_id`` and returns the new count."""
        raise NotImplementedError()

    def decr(self, block_id: BlockId) -> RefCount:
        """Decrements the count for ``block_id`` and returns the new count."""
        raise NotImplementedError()

    def get(self, block_id: BlockId) -> RefCount:
        """Returns the current count for ``block_id``."""
        raise NotImplementedError()


class RefCounter(RefCounterProtocol):
    """A class for managing reference counts for a set of block indices.

    The RefCounter class maintains a dictionary that maps block indices to their
    corresponding reference counts. It provides methods to increment, decrement,
    and retrieve the reference count for a given block index.

    All argument checks are ``assert`` statements, so they are skipped when
    Python runs with ``-O``.

    Args:
        all_block_indices (Iterable[BlockId]): An iterable of block indices
            to initialize the reference counter with. Duplicate indices are
            collapsed; every index starts with a reference count of 0.
    """

    def __init__(self, all_block_indices: Iterable[BlockId]):
        # dict.fromkeys de-duplicates the indices and zero-initializes them.
        self._refcounts: Dict[BlockId, RefCount] = dict.fromkeys(
            all_block_indices, 0)

    def incr(self, block_id: BlockId) -> RefCount:
        """Increments the reference count of ``block_id``.

        Args:
            block_id (BlockId): An index passed to the constructor.

        Returns:
            RefCount: The reference count after incrementing.
        """
        assert block_id in self._refcounts
        pre_incr_refcount = self._refcounts[block_id]

        assert pre_incr_refcount >= 0

        post_incr_refcount = pre_incr_refcount + 1
        self._refcounts[block_id] = post_incr_refcount
        return post_incr_refcount

    def decr(self, block_id: BlockId) -> RefCount:
        """Decrements the reference count of ``block_id``.

        Args:
            block_id (BlockId): An index passed to the constructor whose
                count is currently positive.

        Returns:
            RefCount: The reference count after decrementing.
        """
        assert block_id in self._refcounts
        refcount = self._refcounts[block_id]

        # Decrementing a zero count would indicate a double release.
        assert refcount > 0
        refcount -= 1

        self._refcounts[block_id] = refcount

        return refcount

    def get(self, block_id: BlockId) -> RefCount:
        """Returns the current reference count of ``block_id``."""
        assert block_id in self._refcounts
        return self._refcounts[block_id]

    def as_readonly(self) -> "ReadOnlyRefCounter":
        """Returns a view of this counter that permits ``get`` only."""
        return ReadOnlyRefCounter(self)
class ReadOnlyRefCounter(RefCounterProtocol):
    """A read-only view of the RefCounter class.

    The ReadOnlyRefCounter class provides a read-only interface to access the
    reference counts maintained by a RefCounter instance. It does not allow
    modifications to the reference counts: ``incr`` and ``decr`` raise
    ``ValueError``, while ``get`` delegates to the wrapped counter, so reads
    always reflect that counter's current state.

    Args:
        refcounter (RefCounter): The RefCounter instance to create a read-only
            view for.
    """

    def __init__(self, refcounter: RefCounter):
        self._refcounter = refcounter

    def incr(self, block_id: BlockId) -> RefCount:
        """Always raises ValueError; this view cannot modify counts."""
        raise ValueError("Incr not allowed")

    def decr(self, block_id: BlockId) -> RefCount:
        """Always raises ValueError; this view cannot modify counts."""
        raise ValueError("Decr not allowed")

    def get(self, block_id: BlockId) -> RefCount:
        """Returns the wrapped counter's current count for ``block_id``."""
        return self._refcounter.get(block_id)

class CopyOnWriteTracker:
    """A class for tracking and managing copy-on-write operations for blocks.

    The CopyOnWriteTracker class records (source, destination) block id pairs
    for pending copy-on-write operations. It works in conjunction with a
    reference counter, which it consults to decide whether a block is shared.

    Args:
        refcounter (RefCounterProtocol): The reference counter used to track
            block reference counts.
    """

    def __init__(self, refcounter: RefCounterProtocol):
        # Pending (src_block_id, trg_block_id) pairs, in recording order.
        self._copy_on_writes: List[Tuple[BlockId, BlockId]] = []
        self._refcounter = refcounter

    def is_appendable(self, block: Block) -> bool:
        """Checks if the block is shared or not. If shared, then it cannot
        be appended and needs to be duplicated via copy-on-write.

        Args:
            block (Block): The block to check.

        Returns:
            bool: True if the block has no block id yet or its reference
                count is at most 1; False otherwise.
        """
        block_id = block.block_id
        if block_id is None:
            # Blocks without an id are treated as unshared.
            return True

        refcount = self._refcounter.get(block_id)
        return refcount <= 1

    def record_cow(self, src_block_id: Optional[BlockId],
                   trg_block_id: Optional[BlockId]) -> None:
        """Records a copy-on-write operation from source to target block id.

        Args:
            src_block_id (BlockId): The source block id from which to copy
                the data. Must not be None.
            trg_block_id (BlockId): The target block id to which the data
                is copied. Must not be None.
        """
        assert src_block_id is not None
        assert trg_block_id is not None
        self._copy_on_writes.append((src_block_id, trg_block_id))

    def clear_cows(self) -> List[Tuple[BlockId, BlockId]]:
        """Clears the copy-on-write tracking information and returns the
        current state.

        This method returns a list of (source block id, destination block id)
        pairs for the recorded copy-on-write operations, in the order they were
        recorded. It then resets the internal list so later calls start empty.

        Returns:
            List[Tuple[BlockId, BlockId]]: The recorded (source, destination)
                block id pairs.
        """
        cows = self._copy_on_writes
        # Rebind rather than clear(), so the list handed to the caller is
        # not emptied underneath them.
        self._copy_on_writes = []
        return cows
class BlockPool:
    """Used to pre-allocate block objects, in order to avoid excessive python
    object allocations/deallocations.
    The pool starts from "pool_size" objects and will increase to more objects
    if necessary.

    Note that multiple block objects may point to the same physical block id,
    which is why this pool is needed, so that it will be easier to support
    prefix caching and more complicated sharing of physical blocks.

    Args:
        block_size: Block size passed to ``create_block`` for pre-allocated
            pool objects.
        create_block: Factory used to construct each pooled block object.
        allocator: Allocator passed to ``create_block`` for pooled objects.
        pool_size: Initial number of pre-allocated block objects (>= 0).
    """

    def __init__(self, block_size: int, create_block: Block.Factory,
                 allocator: BlockAllocator, pool_size: int):
        self._block_size = block_size
        self._create_block = create_block
        self._allocator = allocator
        self._pool_size = pool_size
        assert self._pool_size >= 0

        # Indices into self._pool that are currently available for reuse.
        self._free_ids: Deque[int] = deque(range(self._pool_size))
        self._pool: List[Block] = [
            self._new_pool_block() for _ in range(self._pool_size)
        ]

    def _new_pool_block(self) -> Block:
        """Constructs one empty block object (no tokens, no block id)."""
        return self._create_block(prev_block=None,
                                  token_ids=[],
                                  block_size=self._block_size,
                                  allocator=self._allocator,
                                  block_id=None,
                                  extra_hash=None)

    def increase_pool(self):
        """Doubles the internal pool size.

        An empty pool grows to a single object, so every call adds at least
        one free slot.
        """
        cur_pool_size = self._pool_size
        # Doubling 0 would leave the pool empty and make init_block's
        # post-growth assertion fail, so always grow by at least one.
        new_pool_size = max(cur_pool_size * 2, 1)
        self._pool_size = new_pool_size

        self._free_ids.extend(range(cur_pool_size, new_pool_size))
        self._pool.extend(
            self._new_pool_block()
            for _ in range(cur_pool_size, new_pool_size))

    def init_block(self,
                   prev_block: Optional[Block],
                   token_ids: List[int],
                   block_size: int,
                   physical_block_id: Optional[int],
                   extra_hash: Optional[int] = None) -> Block:
        """Takes a free pooled block object and re-initializes it in place.

        Grows the pool first when no free object is available.

        Args:
            prev_block: Block passed as ``prev_block`` to the block's
                constructor.
            token_ids: Token ids passed to the block's constructor.
            block_size: Block size passed to the block's constructor.
            physical_block_id: Passed to the block's constructor as
                ``block_id``.
            extra_hash: Passed through to the block's constructor.

        Returns:
            Block: The re-initialized block, tagged with ``pool_id`` so it
                can later be returned with ``free_block``.
        """
        if not self._free_ids:
            self.increase_pool()
            assert self._free_ids

        pool_id = self._free_ids.popleft()

        block = self._pool[pool_id]
        # Re-run the constructor on the existing object instead of allocating
        # a new one; the block keeps the allocator it was created with.
        block.__init__(  # type: ignore[misc]
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            allocator=block._allocator,  # type: ignore[attr-defined]
            block_id=physical_block_id,
            extra_hash=extra_hash)
        block.pool_id = pool_id  # type: ignore[attr-defined]
        return block

    def free_block(self, block: Block) -> None:
        """Returns the block's pool slot to the free list.

        The slot is pushed to the front, so it is the next one handed out by
        ``init_block``.
        """
        self._free_ids.appendleft(block.pool_id)  # type: ignore[attr-defined]


class BlockList:
    """Holds a list of blocks alongside a cached list of their block ids.

    Keeping the ids in a parallel list means ``ids()`` does not have to
    rebuild the id list from the blocks on every block-manager iteration.
    Every mutation made through this class refreshes both lists together.
    """

    def __init__(self, blocks: List[Block]):
        self._blocks: List[Block] = []
        self._block_ids: List[int] = []

        self.update(blocks)

    def _add_block_id(self, block_id: Optional[BlockId]) -> None:
        # Every tracked block is required to already carry an id.
        assert block_id is not None
        self._block_ids.append(block_id)

    def _update_block_id(self, block_index: int,
                         new_block_id: Optional[BlockId]) -> None:
        assert new_block_id is not None
        self._block_ids[block_index] = new_block_id

    def update(self, blocks: List[Block]):
        """Tracks ``blocks`` (held by reference, not copied) and rebuilds
        the id cache from them."""
        self._blocks = blocks

        self._block_ids = []
        for blk in blocks:
            self._add_block_id(blk.block_id)

    def append_token_ids(self, block_index: int, token_ids: List[int]) -> None:
        """Appends tokens to the block at ``block_index``, refreshing its
        cached id if the append changed it."""
        target = self._blocks[block_index]
        id_before = target.block_id

        target.append_token_ids(token_ids)

        # Appending may swap the block's id (e.g. copy-on-write or promotion).
        if target.block_id != id_before:
            self._update_block_id(block_index, target.block_id)

    def append(self, new_block: Block):
        """Adds ``new_block`` at the end and caches its id."""
        self._blocks.append(new_block)
        self._add_block_id(new_block.block_id)

    def __len__(self) -> int:
        return len(self._blocks)

    def __getitem__(self, block_index: int) -> Block:
        return self._blocks[block_index]

    def __setitem__(self, block_index: int, new_block: Block) -> None:
        self._blocks[block_index] = new_block
        self._update_block_id(block_index, new_block.block_id)

    def reset(self):
        """Drops all tracked blocks and cached ids."""
        self._blocks, self._block_ids = [], []

    def list(self) -> List[Block]:
        """Returns the internal block list itself (not a copy)."""
        return self._blocks

    def ids(self) -> List[int]:
        """Returns the internal cached id list itself (not a copy)."""
        return self._block_ids
@dataclass
class CacheMetricData:
    """Running cache hit-rate tracker that stays numerically bounded.

    Rather than accumulating raw hit/query totals (to avoid overflow), queries
    are grouped into blocks of ``block_size`` queries. Each completed block is
    folded into a running average hit rate; only the current, partially
    filled block is kept as raw counts.

    Notation:
        BS = queries per block (``block_size``).
        nB = number of completed blocks (``num_completed_blocks``).
        HR = average hit rate of the nB completed blocks
             (``completed_block_cache_hit_rate``).
        Q  = queries in the current block (< BS).
        H  = hits in the current block (<= Q).

    The overall hit rate is then
        ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)
    """
    # Number of blocks of `block_size` queries already folded into the average.
    num_completed_blocks: int = 0
    # Average hit rate across the completed blocks.
    completed_block_cache_hit_rate: float = 0.0
    # Queries recorded in the current, not yet completed block.
    num_incompleted_block_queries: int = 0
    # Hits recorded in the current, not yet completed block.
    num_incompleted_block_hit: int = 0
    # Number of queries that make up one block.
    block_size: int = 1000

    def query(self, hit: bool):
        """Records one cache lookup; ``hit`` tells whether it was a hit.

        When this query fills the current block, that block's hit rate is
        folded into the running average and the per-block counters restart.
        """
        self.num_incompleted_block_queries += 1
        if hit:
            self.num_incompleted_block_hit += 1

        if self.num_incompleted_block_queries != self.block_size:
            return

        block_hit_rate = (self.num_incompleted_block_hit /
                          self.num_incompleted_block_queries)
        done = self.num_completed_blocks
        self.completed_block_cache_hit_rate = (
            self.completed_block_cache_hit_rate * done + block_hit_rate) / (
                done + 1)
        self.num_incompleted_block_queries = 0
        self.num_incompleted_block_hit = 0
        self.num_completed_blocks = done + 1

    def get_hit_rate(self):
        """Returns the overall hit rate, or 0.0 before any query."""
        queries = self.num_incompleted_block_queries
        partial_blocks = queries / self.block_size
        denominator = self.num_completed_blocks + partial_blocks
        if denominator == 0:
            return 0.0

        completed_hits = (self.completed_block_cache_hit_rate *
                          self.num_completed_blocks
                          if self.num_completed_blocks > 0 else 0.0)
        partial_hits = ((self.num_incompleted_block_hit / queries) *
                        partial_blocks if queries > 0 else 0.0)
        return (completed_hits + partial_hits) / denominator
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
    """Retrieves all the blocks in a sequence starting from the last block.

    Follows the ``prev_block`` links backwards from ``last_block`` until a
    block whose ``prev_block`` is None is reached. The blocks are returned in
    forward order: the first block first and ``last_block`` last.

    The walk is iterative, so sequences longer than the interpreter's
    recursion limit are supported. The function name is kept for backward
    compatibility.

    Args:
        last_block (Block): The last block in the sequence.

    Returns:
        List[Block]: A list of all the blocks in the sequence, in the order they
            appear.
    """
    all_blocks: List[Block] = [last_block]
    prev = last_block.prev_block
    while prev is not None:
        all_blocks.append(prev)
        prev = prev.prev_block
    # Collected last-to-first; flip into sequence order.
    all_blocks.reverse()
    return all_blocks