"""Token blocks."""
from typing import TYPE_CHECKING, Iterator, List, Optional

from vllm.utils import Device

# Initial value given to ``PhysicalTokenBlock.last_accessed`` when a block is
# constructed. NOTE(review): presumably a "never accessed" sentinel (-1 cannot
# be a real timestamp) — confirm against the code that updates last_accessed.
DEFAULT_LAST_ACCESSED_TIME: float = -1


class PhysicalTokenBlock:
    """Represents the state of a block in the KV cache.

    Args:
        device: The device the block belongs to (stored as-is).
        block_number: Number identifying this block; ``BlockTable`` records
            it as the block's id.
        block_size: Size of the block (stored as-is; presumably its capacity
            in tokens — confirm against the allocator).
        block_hash: Hash value associated with the block (stored as-is;
            presumably derived from its contents — confirm).
        num_hashed_tokens: Count of tokens covered by ``block_hash``
            (stored as-is; presumably — confirm).

    Besides the arguments above, every instance starts with
    ``ref_count = 0``, ``last_accessed = DEFAULT_LAST_ACCESSED_TIME`` and
    ``computed = False``. This class never modifies those fields itself;
    callers are responsible for updating them.
    """

    def __init__(
        self,
        device: Device,
        block_number: int,
        block_size: int,
        block_hash: int,
        num_hashed_tokens: int,
    ) -> None:
        self.device = device
        self.block_number = block_number
        self.block_size = block_size
        self.block_hash = block_hash
        self.num_hashed_tokens = num_hashed_tokens

        # Number of current holders of this block; maintained externally.
        self.ref_count = 0
        # Time of last access; starts at the module-level default.
        self.last_accessed = DEFAULT_LAST_ACCESSED_TIME

        # Flag maintained externally; starts as False.
        self.computed = False

    def __repr__(self) -> str:
        """Return a debug string with the block's identifying/state fields.

        ``block_size`` and ``block_hash`` are intentionally not included.
        """
        return (f'PhysicalTokenBlock(device={self.device}, '
                f'block_number={self.block_number}, '
                f'num_hashed_tokens={self.num_hashed_tokens}, '
                f'ref_count={self.ref_count}, '
                f'last_accessed={self.last_accessed}, '
                f'computed={self.computed})')


class BlockTable:
    """Holds a list of blocks with caching of their associated block_ids.

    Two parallel lists are kept: ``_blocks`` holds the block objects and
    ``_block_ids`` holds each block's ``block_number`` at the same index, so
    ``ids()`` can return the id list without rebuilding it on every call.
    Every mutator below updates both lists together.
    """

    def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None):
        """Create a table, optionally pre-filled with ``blocks`` in order."""
        self._blocks: List[PhysicalTokenBlock] = []
        self._block_ids: List[int] = []

        if blocks is not None:
            for block in blocks:
                self.append(block)

    def append(self, block: PhysicalTokenBlock) -> None:
        """Add ``block`` at the end of the table and cache its id."""
        # Read the id before mutating so a bad object cannot leave
        # _blocks and _block_ids with different lengths.
        block_id = block.block_number
        self._blocks.append(block)
        self._block_ids.append(block_id)

    def __len__(self) -> int:
        return len(self._blocks)

    def __getitem__(self, key):
        """Index or slice into the blocks (a slice yields a plain list)."""
        return self._blocks[key]

    if TYPE_CHECKING:
        # Declared for static type checkers only. At runtime ``iter(table)``
        # falls back to the sequence protocol via ``__getitem__``, which
        # stops when the underlying list raises IndexError.
        def __iter__(self) -> Iterator[PhysicalTokenBlock]:
            raise RuntimeError("Method should be automatically generated")

    def __setitem__(self, key, value):
        """Replace one block (integer key) or a run of blocks (slice key).

        The cached ids are updated alongside the blocks. All ids are
        computed before either list is touched, so a failure leaves the
        table unchanged.
        """
        if isinstance(key, slice):
            # Materialize first: a one-shot iterator would otherwise be
            # consumed by the first assignment, leaving _block_ids short.
            blocks = list(value)
            block_ids = [b.block_number for b in blocks]
            self._blocks[key] = blocks
            self._block_ids[key] = block_ids
        else:
            block = value
            block_id = block.block_number
            self._blocks[key] = block
            self._block_ids[key] = block_id

    def reset(self):
        """Empty the table.

        New lists are bound, so lists previously returned by ``list()`` or
        ``ids()`` keep their old contents.
        """
        self._blocks = []
        self._block_ids = []

    def copy(self) -> "BlockTable":
        """Return a new table with its own lists holding the same blocks."""
        return BlockTable(self._blocks)

    def list(self) -> List[PhysicalTokenBlock]:
        """Return the internal block list itself (not a copy).

        NOTE: mutating the returned list directly bypasses the id cache.
        """
        return self._blocks

    def ids(self) -> List[int]:
        """Return the internal list of cached block ids (not a copy)."""
        return self._block_ids