cache_engine.py 3.25 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
from typing import Dict, List, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

import torch

KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Manage per-layer key/value cache blocks on one GPU and in pinned CPU memory.

    For every layer the engine allocates a (key, value) tensor pair on the GPU
    holding ``num_gpu_blocks`` blocks, and a matching pair in pinned host
    memory holding ``num_cpu_blocks`` blocks. Block copies between and within
    these pools are enqueued on a dedicated CUDA stream (``cache_stream``).
    After a layer's copies are enqueued, that layer's CUDA event in
    ``self.events`` is recorded on the cache stream.
    """

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
    ) -> None:
        """Allocate the GPU/CPU caches, the cache stream and per-layer events.

        Args:
            worker_id: Identifier of the owning worker (stored only).
            gpu_id: CUDA device index used for the GPU cache and the stream.
            num_layers: Number of layers; one (key, value) pair per layer.
            num_heads: Number of attention heads per layer.
            head_size: Per-head dimension; must be a multiple of 16.
            block_size: Number of token slots per cache block.
            num_gpu_blocks: Number of blocks allocated per layer on the GPU.
            num_cpu_blocks: Number of blocks allocated per layer on the host.
            dtype: Element dtype of every cache tensor.

        Raises:
            ValueError: If ``head_size`` is not a multiple of 16.
        """
        if head_size % 16 != 0:
            raise ValueError(
                f'head_size ({head_size}) must be a multiple of 16.')

        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.dtype = dtype

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations. It must differ from
        # the current (compute) stream so copies can run alongside compute.
        self.cache_stream = torch.cuda.Stream(device=gpu_id)
        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
        # One event per layer, recorded on the cache stream once that layer's
        # block copies have been enqueued.
        self.events = [torch.cuda.Event() for _ in range(num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        """Return the shape of one key block.

        The shape is ``(num_heads, head_size // x, block_size, x)``, where
        ``x`` is the number of ``dtype`` elements that fit in 16 bytes.
        NOTE(review): this packing presumably matches 16-byte vectorized
        loads in the attention kernel; confirm against the kernel.
        """
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        """Return the shape of one value block: ``(num_heads, block_size, head_size)``."""
        return (
            self.num_heads,
            self.block_size,
            self.head_size,
        )

    def _allocate_cache(
        self,
        num_blocks: int,
        device=None,
        pin_memory: bool = False,
    ) -> List[KVCache]:
        """Allocate ``num_layers`` uninitialized (key, value) tensor pairs.

        Args:
            num_blocks: Leading (block) dimension of every tensor.
            device: Target device passed to ``torch.empty`` (``None`` means
                the default device, i.e. CPU unless changed globally).
            pin_memory: Whether to allocate in page-locked host memory.

        Returns:
            A list with one ``(key_blocks, value_blocks)`` tuple per layer.
        """
        # Block shapes do not vary per layer; compute them once.
        key_shape = (num_blocks, *self.get_key_block_shape())
        value_shape = (num_blocks, *self.get_value_block_shape())
        cache: List[KVCache] = []
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=key_shape,
                dtype=self.dtype,
                device=device,
                pin_memory=pin_memory,
            )
            value_blocks = torch.empty(
                size=value_shape,
                dtype=self.dtype,
                device=device,
                pin_memory=pin_memory,
            )
            cache.append((key_blocks, value_blocks))
        return cache

    def allocate_gpu_cache(self) -> List[KVCache]:
        """Allocate ``num_gpu_blocks`` key/value blocks per layer on ``gpu_id``."""
        return self._allocate_cache(self.num_gpu_blocks, device=self.gpu_id)

    def allocate_cpu_cache(self) -> List[KVCache]:
        """Allocate ``num_cpu_blocks`` key/value blocks per layer in pinned host memory."""
        return self._allocate_cache(self.num_cpu_blocks, pin_memory=True)

    @staticmethod
    def _copy_blocks(
        src: torch.Tensor,
        dst: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Copy ``src[s]`` into ``dst[d]`` for every ``s -> d`` in *src_to_dst*.

        Copies run in the mapping's iteration order. If a block index appears
        both as a source and as a destination within the same tensor, later
        copies see earlier writes.
        """
        for src_block, dst_block in src_to_dst.items():
            # non_blocking allows pinned-host <-> device copies to be async.
            dst[dst_block].copy_(src[src_block], non_blocking=True)

    def _transfer(
        self,
        src_cache: List[KVCache],
        dst_cache: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        """Enqueue per-layer block copies on the cache stream.

        After each layer's key and value copies are enqueued, that layer's
        event is recorded on ``cache_stream``.
        NOTE(review): the cache stream does not wait for work already queued
        on other streams; callers must ensure source blocks are not still
        being written (and destination blocks not still being read).
        """
        with torch.cuda.stream(self.cache_stream):
            for layer, event in enumerate(self.events):
                src_key, src_value = src_cache[layer]
                dst_key, dst_value = dst_cache[layer]
                self._copy_blocks(src_key, dst_key, src_to_dst)
                self._copy_blocks(src_value, dst_value, src_to_dst)
                event.record(stream=self.cache_stream)

    def copy(self, src_to_dst: Dict[int, int]) -> None:
        """Copy GPU blocks to other GPU blocks (``src block -> dst block``).

        Assumes no block index is both a source and a destination; see
        ``_copy_blocks`` for ordering when that assumption is violated.
        """
        self._transfer(self.gpu_cache, self.gpu_cache, src_to_dst)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Copy CPU blocks into GPU blocks (``cpu block -> gpu block``)."""
        self._transfer(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Copy GPU blocks into CPU blocks (``gpu block -> cpu block``)."""
        self._transfer(self.gpu_cache, self.cpu_cache, src_to_dst)