# cache_engine.py
from typing import Dict, List, Tuple
import torch

# Type alias: one layer's cache, a (key_blocks, value_blocks) tensor pair.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Manages the paged KV cache for one worker's GPU.

    Allocates per-layer (key, value) block tensors on both the GPU and in
    pinned CPU memory, and owns a dedicated CUDA stream plus one event per
    layer for synchronizing cache-movement operations.
    """

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
    ) -> None:
        """Allocate the GPU and CPU caches and set up the cache stream.

        Args:
            worker_id: Identifier of the owning worker.
            gpu_id: CUDA device ordinal the GPU cache lives on.
            num_layers: Number of layers (one KV tensor pair each).
            num_heads: Number of attention heads per layer.
            head_size: Per-head hidden dimension. Must be a multiple of 16
                so the key layout can pack 16 bytes per innermost chunk
                (see get_key_block_shape).
            block_size: Number of token slots per cache block.
            num_gpu_blocks: Block count for the GPU cache.
            num_cpu_blocks: Block count for the pinned CPU swap cache.
            dtype: Element dtype of all cache tensors.

        Raises:
            ValueError: If head_size is not a multiple of 16.
        """
        if head_size % 16 != 0:
            raise ValueError(f'head_size ({head_size}) must be a multiple of 16.')

        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.dtype = dtype

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations. It must be distinct
        # from the compute (current) stream so cache copies can overlap.
        self.cache_stream = torch.cuda.Stream(device=gpu_id)
        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
        # Initialize the events for stream synchronization (one per layer).
        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        """Return the per-block key tensor shape.

        Keys are laid out as (num_heads, head_size // x, block_size, x),
        where x is the number of dtype elements that fit in 16 bytes, so
        each innermost chunk is a 16-byte vector.

        BUGFIX: the annotation previously claimed a 5-tuple; the method
        returns a 4-tuple.
        """
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        """Return the per-block value tensor shape:
        (num_heads, block_size, head_size).

        BUGFIX: the annotation previously claimed a 4-tuple; the method
        returns a 3-tuple.
        """
        return (
            self.num_heads,
            self.block_size,
            self.head_size,
        )

    def allocate_gpu_cache(self) -> List[KVCache]:
        """Allocate one (key_blocks, value_blocks) pair per layer on the GPU."""
        gpu_cache: List[KVCache] = []
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_gpu_blocks, *self.get_key_block_shape()),
                dtype=self.dtype,
                device=self.gpu_id,
            )
            value_blocks = torch.empty(
                size=(self.num_gpu_blocks, *self.get_value_block_shape()),
                dtype=self.dtype,
                device=self.gpu_id,
            )
            gpu_cache.append((key_blocks, value_blocks))
        return gpu_cache

    def allocate_cpu_cache(self) -> List[KVCache]:
        """Allocate one (key_blocks, value_blocks) pair per layer in pinned
        host memory (pinning enables async GPU<->CPU copies)."""
        cpu_cache: List[KVCache] = []
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_cpu_blocks, *self.get_key_block_shape()),
                dtype=self.dtype,
                pin_memory=True,
            )
            value_blocks = torch.empty(
                size=(self.num_cpu_blocks, *self.get_value_block_shape()),
                dtype=self.dtype,
                pin_memory=True,
            )
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def copy(self, src_to_dst: Dict[int, int]) -> None:
        """Copy cache blocks within the GPU cache per src_to_dst.

        NOTE(review): stub — iterates the per-layer events but performs no
        work yet; behavior intentionally left unchanged.
        """
        for event in self.events:
            pass

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Swap cache blocks from the CPU cache into the GPU cache.

        NOTE(review): stub — iterates the per-layer events but performs no
        work yet; behavior intentionally left unchanged.
        """
        for event in self.events:
            pass

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Swap cache blocks from the GPU cache out to the CPU cache.

        NOTE(review): stub — iterates the per-layer events but performs no
        work yet; behavior intentionally left unchanged.
        """
        for event in self.events:
            pass