import math
import torch

from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

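# Paged-attention style KV cache bookkeeping: the cache is split into fixed-size
# blocks of BLOCK_SIZE token slots, and each sequence is assigned a (possibly
# non-contiguous) list of blocks through a block table.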
BLOCK_SIZE: int = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None


class CacheManager:
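    """Allocates and frees the fixed-size KV cache blocks used by paged attention."""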
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        repeat_slots: bool,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks
        self.repeat_slots = repeat_slots

        element_size = torch.tensor([], dtype=dtype).element_size()
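        # Split factor for the key cache: outside of XPU, keys are stored in chunks
        # of `x` elements (BLOCK_SIZE bytes per chunk) to match the packed key
        # layout used by the paged attention kernels; XPU keeps keys unsplit.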
        if IS_XPU_SYSTEM:
            x = 1
        else:
            x = self.block_size // element_size

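        # Per-layer (key, value) cache tensors:
        #   key:   (num_blocks, num_heads, head_size // x, block_size, x)
        #   value: (num_blocks, num_heads, head_size,      block_size)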
        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
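        # 1 = free, 0 = allocated; kept on CPU since it is only used for bookkeeping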
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
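        # slots[i] lists the flat KV cache slot indices covered by block i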
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int64
        ).view(num_blocks, self.block_size)

    def allocate(
        self,
        needed_blocks_slots: List[Tuple[int, int]],
        blocks: int,
        max_blocks: int,
        device: torch.device,
    ):
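        """
        Reserve `blocks` free blocks and distribute their slots over the sequences
        described by `needed_blocks_slots` (one `(needed_blocks, needed_slots)` pair
        per sequence). Returns the per-sequence block tables as lists, the padded
        block table tensor of shape `(len(needed_blocks_slots), max_blocks)`, and
        the flattened slot indices for all sequences.
        """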
        # Get free block indices by finding values in the mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        if blocks > len(free_block_indices):
            raise RuntimeError(
                f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
            )

        # Slice by the number of required blocks
        block_indices = free_block_indices[:blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(needed_blocks_slots), max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            all_slots = self.slots[allocated_blocks].flatten()

            # Repeat slots in the case of context sliding window
            if needed_slots > len(all_slots) and self.repeat_slots:
                repeats = math.ceil(needed_slots / len(all_slots))
                all_slots = all_slots.repeat(repeats)

            allocated_slots = all_slots[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        block_tables_tensor = block_tables_tensor.to(device)
        slots = torch.concat(slots).to(device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

        return block_tables, block_tables_tensor, slots

    def free(self, block_indices: Optional[List[int]]):
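        """Mark the given blocks as free again so they can be re-allocated."""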
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


def set_cache_manager(
    num_blocks: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    repeat_slots: bool,
    dtype: torch.dtype,
    device: torch.device,
) -> CacheManager:
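    """
    (Re)create the global CACHE_MANAGER, dropping any existing one first so its
    cache tensors can be released before the new allocation.
    """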
    global CACHE_MANAGER
    if CACHE_MANAGER is not None:
        del CACHE_MANAGER
        torch.cuda.empty_cache()

    CACHE_MANAGER = CacheManager(
        num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
    )
    return CACHE_MANAGER


def get_cache_manager() -> CacheManager:
    global CACHE_MANAGER
    if CACHE_MANAGER is None:
        raise RuntimeError("cache manager was not initialized")

    return CACHE_MANAGER
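

# Illustrative usage sketch (not part of the server code path; the block count,
# model shapes and device below are made up for the example):
#
#   set_cache_manager(
#       num_blocks=128,
#       num_layers=32,
#       num_heads=32,
#       head_size=128,
#       repeat_slots=False,
#       dtype=torch.float16,
#       device=torch.device("cuda"),
#   )
#   manager = get_cache_manager()
#   # One sequence that needs 2 blocks / 20 slots, padded to max_blocks=4
#   block_tables, block_tables_tensor, slots = manager.allocate(
#       needed_blocks_slots=[(2, 20)],
#       blocks=2,
#       max_blocks=4,
#       device=torch.device("cuda"),
#   )
#   ...  # run paged attention against manager.kv_cache using the block tables
#   manager.free(block_tables[0])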