common.py 6.08 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from collections import defaultdict
from typing import Dict, Iterable, List, Optional

from vllm.core.block.interfaces import Block, BlockAllocator

BlockId = int
RefCount = int


class RefCounter:
    """Tracks how many references exist to each block index.

    Every block index supplied at construction time starts with a reference
    count of zero. Counts are stored in a plain dictionary keyed by block
    index and are changed one step at a time through ``incr`` and ``decr``.

    Args:
        all_block_indices (Iterable[BlockId]): The block indices to track.
            Duplicate indices are collapsed into a single entry.
    """

    def __init__(self, all_block_indices: Iterable[BlockId]):
        # Going through a set drops duplicate indices before the table is
        # built; every surviving index starts at zero references.
        self._refcounts: Dict[BlockId, RefCount] = dict.fromkeys(
            set(all_block_indices), 0)

    def incr(self, block_id: BlockId) -> RefCount:
        """Adds one reference to ``block_id`` and returns the new count.

        The block index must be one of the indices given at construction.
        """
        assert block_id in self._refcounts
        current = self._refcounts[block_id]

        assert current >= 0

        self._refcounts[block_id] = current + 1
        return current + 1

    def decr(self, block_id: BlockId) -> RefCount:
        """Removes one reference from ``block_id`` and returns the new count.

        The block index must be tracked and must currently have at least
        one reference.
        """
        assert block_id in self._refcounts
        current = self._refcounts[block_id]

        assert current > 0

        remaining = current - 1
        self._refcounts[block_id] = remaining
        return remaining

    def get(self, block_id: BlockId) -> RefCount:
        """Returns the current reference count of a tracked ``block_id``."""
        assert block_id in self._refcounts
        return self._refcounts[block_id]

    def as_readonly(self) -> "ReadOnlyRefCounter":
        """Returns a ``ReadOnlyRefCounter`` wrapping this counter."""
        return ReadOnlyRefCounter(self)


class ReadOnlyRefCounter:
    """A wrapper that exposes only the lookup side of a RefCounter.

    Reads are forwarded to the wrapped counter, so the values returned
    always reflect its current state. Attempts to change a reference count
    through this wrapper are rejected with a ``ValueError``.

    Args:
        refcounter (RefCounter): The counter whose reference counts should
            be readable through this wrapper.
    """

    def __init__(self, refcounter: RefCounter):
        self._refcounter = refcounter

    def incr(self, block_id: BlockId) -> RefCount:
        """Always raises ``ValueError``; incrementing is not permitted."""
        raise ValueError("Incr not allowed")

    def decr(self, block_id: BlockId) -> RefCount:
        """Always raises ``ValueError``; decrementing is not permitted."""
        raise ValueError("Decr not allowed")

    def get(self, block_id: BlockId) -> RefCount:
        """Returns the wrapped counter's reference count for ``block_id``."""
        wrapped = self._refcounter
        return wrapped.get(block_id)


class CopyOnWriteTracker:
    """Records copy-on-write operations performed on shared blocks.

    The tracker keeps a mapping from each source block index to the list of
    destination block indices that were allocated to replace it. It consults
    a RefCounter to decide whether a block is shared and uses a
    BlockAllocator to free the shared block and allocate its replacement.

    Args:
        refcounter (RefCounter): The reference counter queried for block
            reference counts.
        allocator (BlockAllocator): The block allocator used to free and
            allocate blocks.
    """

    def __init__(
        self,
        refcounter: RefCounter,
        allocator: BlockAllocator,
    ):
        # Source block index -> destination block indices created from it.
        self._copy_on_writes: Dict[BlockId,
                                   List[BlockId]] = defaultdict(list)
        self._refcounter = refcounter
        self._allocator = allocator

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Replaces the given block with a fresh one if it is shared.

        A block whose reference count is greater than 1 is treated as
        shared: it is passed to the allocator's ``free``, a new mutable
        block is allocated with the same ``prev_block``, and the
        (source, destination) pair is recorded so it can later be retrieved
        with ``clear_cows``. A block with a reference count of exactly 1 is
        left untouched.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The index of the newly allocated block if a
                copy-on-write was performed, otherwise the block's own
                index (which is None when the block has no index).
        """
        src_block_id = block.block_id
        if src_block_id is None:
            # Nothing to check for a block without an index.
            return None

        refcount = self._refcounter.get(src_block_id)
        assert refcount != 0

        if refcount <= 1:
            # Sole owner: the block can be used as-is.
            return src_block_id

        # Hand the shared block back to the allocator.
        self._allocator.free(block)

        # Obtain a fresh block that follows the same predecessor.
        replacement = self._allocator.allocate_mutable(
            prev_block=block.prev_block)
        dst_block_id = replacement.block_id

        # Remember the src -> dst pair for the pending copy.
        self._copy_on_writes[src_block_id].append(dst_block_id)

        return dst_block_id

    def clear_cows(self) -> Dict[BlockId, List[BlockId]]:
        """Returns the recorded copy-on-write pairs and resets the tracker.

        The recorded mapping is copied into a plain dictionary before the
        internal mapping is emptied, so the returned value is unaffected by
        the reset.

        Returns:
            Dict[BlockId, List[BlockId]]: A dictionary mapping each source
                block index to the list of destination block indices
                recorded since the previous call.
        """
        pending = dict(self._copy_on_writes)
        self._copy_on_writes.clear()
        return pending


def get_all_blocks_recursively(last_block: Block) -> List[Block]:
    """Retrieves all the blocks in a sequence starting from the last block.

    The chain is followed backwards through each block's ``prev_block``
    attribute until a block whose ``prev_block`` is None is reached. The
    blocks are then returned first-to-last.

    The traversal is iterative, so the length of the chain is not limited
    by the interpreter's recursion limit.

    Args:
        last_block (Block): The last block in the sequence.

    Returns:
        List[Block]: A list of all the blocks in the sequence, in the order
            they appear (the block with no predecessor first, ``last_block``
            last).
    """
    # Collect the chain back-to-front while following prev_block links.
    all_blocks: List[Block] = [last_block]
    predecessor = last_block.prev_block
    while predecessor is not None:
        all_blocks.append(predecessor)
        predecessor = predecessor.prev_block

    # Flip in place so the result reads in sequence order.
    all_blocks.reverse()
    return all_blocks