block.py 3.86 KB
Newer Older
1
"""Token blocks."""
2
3
4
import weakref
from collections import defaultdict
from typing import Dict, List
Woosuk Kwon's avatar
Woosuk Kwon committed
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
from vllm.utils import Device
Woosuk Kwon's avatar
Woosuk Kwon committed
7

8
# Filler value for every slot of a newly created token list.
_BLANK_TOKEN_ID = -1

# Initial value assigned to PhysicalTokenBlock.last_accessed.
DEFAULT_LAST_ACCESSED_TIME = -1

# Backing storage of one logical block: a fixed-length list of token ids.
TokensBlock = List[int]


class BlockPool:
    """A free-list of ``token_ids`` lists, keyed by block size.

    When requests come, we create a lot of logical blocks;
    when requests are done, we destroy a lot of logical blocks.
    Creating and destroying logical blocks can be expensive, especially
    allocating the ``token_ids`` field, which is a list of integers.
    To avoid this overhead, the pool keeps the ``token_ids`` lists released
    by destroyed logical blocks and hands them out again, so the storage of
    an old request can be reused to feed a new request.

    Note that a recycled list still holds whatever token ids its previous
    owner wrote; only freshly created lists are filled with
    ``_BLANK_TOKEN_ID``. Callers must track how many entries are valid.
    """

    def __init__(self) -> None:
        # block size -> LIFO stack of free token lists of exactly that length
        self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)

    def alloc_block(self, block_size: int) -> TokensBlock:
        """Return a token list of length ``block_size``.

        Reuses a previously released list of that length when one is
        available; otherwise creates a new list filled with
        ``_BLANK_TOKEN_ID``.
        """
        # ``dict.get`` does not trigger defaultdict's auto-insertion, so
        # probing an unseen size leaves the pool unchanged, and it needs a
        # single lookup instead of a membership test plus indexing.
        free_blocks = self.pool.get(block_size)
        if free_blocks:
            return free_blocks.pop()
        return [_BLANK_TOKEN_ID] * block_size

    def del_block(self, block: TokensBlock) -> None:
        """Release ``block`` back to the pool, keyed by its length.

        The list is stored as-is (its contents are not cleared).
        """
        self.pool[len(block)].append(block)


# Module-wide pool that LogicalTokenBlock draws its token storage from.
_BLOCK_POOL = BlockPool()


class LogicalTokenBlock:
    """A fixed-capacity block holding a left-to-right run of token ids.

    Logical blocks are used to represent the states of the corresponding
    physical blocks in the KV cache. The ``token_ids`` storage is taken from
    the module-level ``_BLOCK_POOL`` and handed back to it when this object
    is finalized.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        self.block_number = block_number
        self.block_size = block_size

        storage = _BLOCK_POOL.alloc_block(block_size)
        self.token_ids = storage
        # Give the storage list back to the pool once this object goes away.
        # NOTE: ``__del__`` is deliberately not used: finalization order is not
        # guaranteed, so ``self.token_ids`` might already be torn down when it
        # runs and the list would never make it back to the pool. The
        # finalizer holds its own reference to the list instead.
        self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
                                           storage)

        # Number of leading slots of ``token_ids`` that hold real tokens.
        self.num_tokens = 0

    def is_empty(self) -> bool:
        """True when no token has been appended yet."""
        return not self.num_tokens

    def get_num_empty_slots(self) -> int:
        """Remaining capacity: block size minus tokens already written."""
        return self.block_size - self.num_tokens

    def is_full(self) -> bool:
        """True when every slot of the block is occupied."""
        return self.block_size == self.num_tokens

    def append_tokens(self, token_ids: List[int]) -> None:
        """Write ``token_ids`` into the next free slots of the block.

        The caller must not pass more ids than ``get_num_empty_slots()``.
        """
        count = len(token_ids)
        assert count <= self.get_num_empty_slots()
        start = self.num_tokens
        end = start + count
        self.token_ids[start:end] = token_ids
        self.num_tokens = end

    def get_token_ids(self) -> List[int]:
        """Return a new list with the tokens written so far."""
        return self.token_ids[0:self.num_tokens]

    def get_last_token_id(self) -> int:
        """Return the most recently written token id (block must be non-empty)."""
        assert self.num_tokens > 0
        last_idx = self.num_tokens - 1
        return self.token_ids[last_idx]


class PhysicalTokenBlock:
    """Represents the state of a block in the KV cache."""

    def __init__(
        self,
        device: Device,
        block_number: int,
        block_size: int,
        block_hash: int,
        num_hashed_tokens: int,
    ) -> None:
        self.device = device
        self.block_number = block_number
        self.block_size = block_size
        # NOTE(review): presumably a content hash used for prefix caching,
        # covering the first ``num_hashed_tokens`` tokens — confirm at caller.
        self.block_hash = block_hash
        self.num_hashed_tokens = num_hashed_tokens

        # Starts at 0; presumably incremented/decremented by the block
        # allocator as sequences share this block (not visible here).
        self.ref_count = 0
        # Starts at DEFAULT_LAST_ACCESSED_TIME until updated elsewhere.
        self.last_accessed = DEFAULT_LAST_ACCESSED_TIME

        # Starts False; presumably set once the block's KV has been
        # computed — TODO confirm against the block manager.
        self.computed = False

    def __repr__(self) -> str:
        # Include every constructor-supplied field so blocks with the same
        # number but different sizes/hashes are distinguishable in logs.
        return (f'PhysicalTokenBlock(device={self.device}, '
                f'block_number={self.block_number}, '
                f'block_size={self.block_size}, '
                f'block_hash={self.block_hash}, '
                f'num_hashed_tokens={self.num_hashed_tokens}, '
                f'ref_count={self.ref_count}, '
                f'last_accessed={self.last_accessed}, '
                f'computed={self.computed})')


# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]