# cache_engine.py (author: Woosuk Kwon)
from typing import Dict, List, Tuple

import torch

# One layer's cache: a (key_blocks, value_blocks) pair of tensors, as built by
# CacheEngine.allocate_gpu_cache / CacheEngine.allocate_cpu_cache.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Holds the per-layer key/value block caches for one worker.

    For every layer the engine allocates a ``(key_blocks, value_blocks)`` pair
    on GPU ``gpu_id`` (``num_gpu_blocks`` blocks) and another pair in pinned
    CPU memory (``num_cpu_blocks`` blocks). It also creates a dedicated CUDA
    stream for caching operations and one CUDA event per layer for stream
    synchronization.
    """

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
    ) -> None:
        """Validate the configuration, allocate both caches and set up the stream.

        Args:
            worker_id: Identifier of the owning worker (stored only).
            gpu_id: CUDA device index used for the GPU cache and the stream.
            num_layers: Number of layers; one KV pair and one event per layer.
            num_heads: Number of attention heads per block.
            head_size: Elements per head; must be a multiple of 16.
            block_size: Number of token slots per block.
            num_gpu_blocks: Blocks to allocate per layer on the GPU.
            num_cpu_blocks: Blocks to allocate per layer in pinned CPU memory.
            dtype: Element dtype of every cache tensor.

        Raises:
            ValueError: If ``head_size`` is not a multiple of 16.
        """
        # get_key_block_shape splits head_size into groups of
        # 16 // element_size elements; a multiple of 16 keeps that split exact.
        if head_size % 16 != 0:
            raise ValueError(f'head_size ({head_size}) must be a multiple of 16.')

        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.dtype = dtype

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations. It must be distinct
        # from the current (compute) stream, otherwise nothing can overlap.
        self.cache_stream = torch.cuda.Stream(device=gpu_id)
        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
        # Initialize the events for stream synchronization (one per layer).
        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        """Return the shape of one key block.

        The shape is ``(num_heads, head_size // x, block_size, x)`` with
        ``x = 16 // element_size``, so the innermost dimension spans 16 bytes
        for dtypes whose element size divides 16.
        """
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        """Return the shape of one value block: ``(num_heads, block_size, head_size)``."""
        return (
            self.num_heads,
            self.block_size,
            self.head_size,
        )

    def _allocate_cache(self, num_blocks: int, **tensor_kwargs) -> 'List[KVCache]':
        """Allocate one uninitialized ``(key_blocks, value_blocks)`` pair per layer.

        Args:
            num_blocks: Size of the leading (block) dimension of each tensor.
            **tensor_kwargs: Extra keyword arguments forwarded to ``torch.empty``
                (e.g. ``device`` or ``pin_memory``).

        Returns:
            A list of ``num_layers`` (key_blocks, value_blocks) tuples.
        """
        # The block shapes are the same for every layer. Compute them once
        # instead of re-deriving them (and allocating a probe tensor) per layer.
        key_shape = (num_blocks, *self.get_key_block_shape())
        value_shape = (num_blocks, *self.get_value_block_shape())
        return [
            (
                torch.empty(size=key_shape, dtype=self.dtype, **tensor_kwargs),
                torch.empty(size=value_shape, dtype=self.dtype, **tensor_kwargs),
            )
            for _ in range(self.num_layers)
        ]

    def allocate_gpu_cache(self) -> 'List[KVCache]':
        """Allocate ``num_gpu_blocks`` key/value blocks per layer on ``gpu_id``."""
        return self._allocate_cache(self.num_gpu_blocks, device=self.gpu_id)

    def allocate_cpu_cache(self) -> 'List[KVCache]':
        """Allocate ``num_cpu_blocks`` key/value blocks per layer in pinned host memory."""
        # Pinned memory allows asynchronous host<->device copies.
        return self._allocate_cache(self.num_cpu_blocks, pin_memory=True)

    def copy(self, src_to_dst: Dict[int, int]) -> None:
        """Placeholder for copying cache blocks; currently does nothing.

        NOTE(review): ``src_to_dst`` presumably maps source block numbers to
        destination block numbers. Confirm this once the method is implemented.
        """
        # TODO: not implemented yet; the loop only walks the per-layer events.
        for event in self.events:
            pass

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Placeholder for swapping blocks into the GPU cache; currently does nothing.

        NOTE(review): ``src_to_dst`` presumably maps CPU block numbers to GPU
        block numbers. Confirm this once the method is implemented.
        """
        # TODO: not implemented yet; the loop only walks the per-layer events.
        for event in self.events:
            pass

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Placeholder for swapping blocks out of the GPU cache; currently does nothing.

        NOTE(review): ``src_to_dst`` presumably maps GPU block numbers to CPU
        block numbers. Confirm this once the method is implemented.
        """
        # TODO: not implemented yet; the loop only walks the per-layer events.
        for event in self.events:
            pass