# interfaces.py: abstract interfaces for blocks and block allocators.
from abc import ABC, abstractmethod
from typing import Dict, FrozenSet, List, Optional, Protocol
3
4
5

from vllm.utils import Device

6
7
# Alias used in annotations for the integer identifier of a block.
BlockId = int

8
9
10
11
12
13
14

class Block(ABC):
    """Abstract interface for a block that holds token ids.

    Every member decorated with ``@abstractmethod`` must be overridden before
    a subclass can be instantiated. Apart from the ``block_id`` setter and
    ``content_hash`` (which carry small default bodies), this class only
    declares signatures; the actual semantics live in the subclasses.
    """

    @abstractmethod
    def append_token_ids(self, token_ids: List[int]) -> None:
        """Abstract hook taking a list of token ids.

        By its name this is expected to append ``token_ids`` to the block;
        the precise behavior is defined by subclasses.
        """
        pass

    @property
    @abstractmethod
    def block_id(self) -> Optional[int]:
        """Identifier of this block; annotated as possibly ``None``."""
        pass

    @block_id.setter
    @abstractmethod
    def block_id(self, value: Optional[int]) -> None:
        """NOTE: Do not use this API outside Block."""
        # Default body stores the id on ``_block_id``; it is only reachable
        # when a subclass explicitly delegates to this base setter.
        self._block_id = value

    @property
    @abstractmethod
    def token_ids(self) -> List[int]:
        """Token ids associated with this block (subclass-defined)."""
        pass

    @property
    @abstractmethod
    def num_empty_slots(self) -> int:
        """Integer count of empty slots, as reported by the subclass."""
        pass

    @property
    @abstractmethod
    def is_full(self) -> bool:
        """Boolean flag reported by the subclass; presumably true when no
        empty slots remain (confirm against implementations)."""
        pass

    @property
    @abstractmethod
    def prev_block(self) -> Optional["Block"]:
        """Preceding block, or ``None`` per the annotation."""
        pass

    @property
    @abstractmethod
    def computed(self) -> bool:
        """Boolean "computed" flag of the block (subclass-defined)."""
        raise NotImplementedError

    @computed.setter
    @abstractmethod
    def computed(self, value: bool) -> None:
        """Should be only used by PrefixCachingAllocator"""
        raise NotImplementedError

    @property
    @abstractmethod
    def last_accessed(self) -> float:
        """Last-accessed value as a float; presumably a timestamp (the
        setter's parameter is named ``last_accessed_ts``)."""
        raise NotImplementedError

    @last_accessed.setter
    @abstractmethod
    def last_accessed(self, last_accessed_ts: float) -> None:
        """Set the last-accessed value (subclass-defined)."""
        raise NotImplementedError

    class Factory(Protocol):
        """Structural type for callables that build a ``Block``."""

        @abstractmethod
        def __call__(
            self,
            prev_block: Optional["Block"],
            token_ids: List[int],
            block_size: int,
            allocator: "BlockAllocator",
            block_id: Optional[int] = None,
        ) -> "Block":
            """Create and return a ``Block`` from the given arguments."""
            pass

    @property
    @abstractmethod
    def content_hash(self) -> Optional[int]:
        """Return the content-based hash of the current block, or None if it is
        not yet defined or not supported.

        For the content-based hash to be defined, the current block must be
        full.
        """
        return None

91
92
93
94

class BlockAllocator(ABC):
    """Abstract interface for an allocator that hands out ``Block`` objects.

    Only signatures are declared here; allocation policy and bookkeeping are
    left entirely to concrete subclasses.
    """

    @abstractmethod
    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
        """Return a block given an optional predecessor ``prev_block``.

        Presumably the returned block can still accept tokens; confirm
        against implementations.
        """
        pass

    @abstractmethod
    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int]) -> Block:
        """Return a block given an optional predecessor and ``token_ids``.

        Presumably the returned block already holds ``token_ids``; confirm
        against implementations.
        """
        pass

    @abstractmethod
    def free(self, block: Block) -> None:
        """Release ``block`` (subclass-defined)."""
        pass

    @abstractmethod
    def fork(self, last_block: Block) -> List[Block]:
        """Return a list of blocks derived from ``last_block``
        (subclass-defined)."""
        pass

    @abstractmethod
    def get_num_total_blocks(self) -> int:
        """Return the total number of blocks as an integer."""
        pass

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        """Return the number of free blocks as an integer."""
        pass

    @property
    @abstractmethod
    def all_block_ids(self) -> FrozenSet[int]:
        """Frozen set of integer block ids known to this allocator."""
        pass

    @abstractmethod
    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Return a mapping from an integer to a list of integers.

        NOTE(review): presumably source block id -> destination block ids of
        pending copies, cleared on return; confirm against implementations.
        """
        pass

    @abstractmethod
    def mark_blocks_as_accessed(self, block_ids: List[int],
                                now: float) -> None:
        """Mark the blocks in ``block_ids`` as accessed, given ``now``."""
        pass

    @abstractmethod
    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        """Mark the blocks in ``block_ids`` as computed."""
        pass

    @abstractmethod
    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Return a list of block ids derived from the per-sequence lists in
        ``seq_block_ids`` (subclass-defined)."""
        pass

    @abstractmethod
    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """NOTE: This should not be used besides Block"""
        pass

    @abstractmethod
    def promote_to_immutable_block(self, block: Block) -> BlockId:
        """NOTE: This should not be used besides Block"""
        pass

    class NoFreeBlocksError(ValueError):
        """``ValueError`` subclass; by its name, signals that no free blocks
        are available."""
        pass


156
class DeviceAwareBlockAllocator(ABC):
    """Abstract allocator interface whose allocation and counting methods
    take an explicit ``Device`` argument.

    Only signatures are declared here; behavior is defined by subclasses.
    """

    @abstractmethod
    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Return a block for ``device`` given an optional predecessor.

        Presumably the returned block can still accept tokens; confirm
        against implementations.
        """
        pass

    @abstractmethod
    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int], device: Device) -> Block:
        """Return a block for ``device`` given an optional predecessor and
        ``token_ids``.

        Presumably the returned block already holds ``token_ids``; confirm
        against implementations.
        """
        pass

    @abstractmethod
    def get_num_free_blocks(self, device: Device) -> int:
        """Return the number of free blocks for ``device``."""
        pass

    @abstractmethod
    def get_num_total_blocks(self, device: Device) -> int:
        """Return the total number of blocks for ``device``."""
        pass

    @abstractmethod
    def free(self, block: Block) -> None:
        """Release ``block`` (subclass-defined)."""
        pass

    @abstractmethod
    def fork(self, last_block: Block) -> List[Block]:
        """Return a list of blocks derived from ``last_block``
        (subclass-defined)."""
        pass

    @property
    @abstractmethod
    def all_block_ids(self) -> FrozenSet[int]:
        """Frozen set of integer block ids known to this allocator."""
        pass

    @abstractmethod
    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Return a mapping from an integer to a list of integers.

        NOTE(review): presumably source block id -> destination block ids of
        pending copies, cleared on return; confirm against implementations.
        """
        pass

    @abstractmethod
    def mark_blocks_as_accessed(self, block_ids: List[int],
                                now: float) -> None:
        """Mark the blocks in ``block_ids`` as accessed, given ``now``."""
        pass

    @abstractmethod
    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        """Mark the blocks in ``block_ids`` as computed."""
        pass

    @abstractmethod
    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Return a list of block ids derived from the per-sequence lists in
        ``seq_block_ids`` (subclass-defined)."""
        pass