cache_engine.py
from typing import List, Tuple

import torch

# A KVCache is the (key_blocks, value_blocks) tensor pair for a single
# attention head of one layer.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
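    """Manages the per-worker KV cache blocks.

    The cache lives in two pools: GPU blocks used during model execution and
    pinned CPU blocks used as swap space. Block copies and swaps are issued on
    dedicated CUDA streams so they can overlap with model execution.
    """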

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        block_size: int,
        dtype: torch.dtype = torch.float16,
    ) -> None:
        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.block_size = block_size
        self.dtype = dtype

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the streams.
        self.copy_stream = torch.cuda.Stream(device=gpu_id)
        self.swap_stream = torch.cuda.Stream(device=gpu_id)
        assert self.copy_stream != self.swap_stream
        current_stream = torch.cuda.current_stream(device=gpu_id)
        assert self.copy_stream != current_stream
        assert self.swap_stream != current_stream

        # Initialize the events for synchronization.
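        # NOTE (sketch): one event per layer is an assumption, not taken from
        # this file; the intent is to let the default stream wait on cache ops
        # issued on the copy/swap streams.
        self.events = [torch.cuda.Event() for _ in range(num_layers)]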

    def allocate_gpu_cache(self) -> List[List[KVCache]]:
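        # Each head owns a (num_gpu_blocks, block_size * head_size) key tensor
        # and a matching value tensor, allocated directly on the target GPU.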
        gpu_cache: List[List[KVCache]] = []
        for _ in range(self.num_layers):
            layer_cache: List[KVCache] = []
            for _ in range(self.num_heads):
                key_blocks = torch.empty(
                    (self.num_gpu_blocks, self.block_size * self.head_size),
                    dtype=self.dtype,
                    device=self.gpu_id,
                )
                value_blocks = torch.empty(
                    (self.num_gpu_blocks, self.block_size * self.head_size),
                    dtype=self.dtype,
                    device=self.gpu_id,
                )
                layer_cache.append((key_blocks, value_blocks))
            gpu_cache.append(layer_cache)
        return gpu_cache

    def allocate_cpu_cache(self) -> List[List[KVCache]]:
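        # Mirror of the GPU layout, but allocated in pinned host memory so that
        # swaps can use asynchronous (non_blocking) copies.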
        cpu_cache: List[List[KVCache]] = []
        for _ in range(self.num_layers):
            layer_cache: List[KVCache] = []
            for _ in range(self.num_heads):
                key_blocks = torch.empty(
                    (self.num_cpu_blocks, self.block_size * self.head_size),
                    dtype=self.dtype,
                    pin_memory=True,
                )
                value_blocks = torch.empty(
                    (self.num_cpu_blocks, self.block_size * self.head_size),
                    dtype=self.dtype,
                    pin_memory=True,
                )
                layer_cache.append((key_blocks, value_blocks))
            cpu_cache.append(layer_cache)
        return cpu_cache

    def copy(
        self,
        src_block_numbers: List[int],
        dst_block_numbers: List[int],
    ) -> None:
        # TODO: Call the fused COPY op. The loop below is a plain per-block
        # copy sketch issued on the copy stream; it assumes the src/dst lists
        # are index-aligned.
        with torch.cuda.stream(self.copy_stream):
            for layer in range(self.num_layers):
                for key_blocks, value_blocks in self.gpu_cache[layer]:
                    for src, dst in zip(src_block_numbers, dst_block_numbers):
                        key_blocks[dst].copy_(key_blocks[src])
                        value_blocks[dst].copy_(value_blocks[src])

    def swap_out(
        self,
        gpu_block_numbers: List[int],
        cpu_block_numbers: List[int],
    ) -> None:
        # TODO: Call the fused SWAP_OUT op. As a sketch, copy each GPU block
        # into the pinned CPU cache on the swap stream.
        with torch.cuda.stream(self.swap_stream):
            for layer in range(self.num_layers):
                for head in range(self.num_heads):
                    gpu_key, gpu_value = self.gpu_cache[layer][head]
                    cpu_key, cpu_value = self.cpu_cache[layer][head]
                    for src, dst in zip(gpu_block_numbers, cpu_block_numbers):
                        cpu_key[dst].copy_(gpu_key[src], non_blocking=True)
                        cpu_value[dst].copy_(gpu_value[src], non_blocking=True)

    def swap_in(
        self,
        gpu_block_numbers: List[int],
        cpu_block_numbers: List[int],
    ) -> None:
        # TODO: Call the fused SWAP_IN op. As a sketch, copy each pinned CPU
        # block back into the GPU cache on the swap stream.
        with torch.cuda.stream(self.swap_stream):
            for layer in range(self.num_layers):
                for head in range(self.num_heads):
                    gpu_key, gpu_value = self.gpu_cache[layer][head]
                    cpu_key, cpu_value = self.cpu_cache[layer][head]
                    for src, dst in zip(cpu_block_numbers, gpu_block_numbers):
                        gpu_key[dst].copy_(cpu_key[src], non_blocking=True)
                        gpu_value[dst].copy_(cpu_value[src], non_blocking=True)
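

# Illustrative usage sketch, not part of the original module: the sizes below
# (2 layers, 4 heads, head size 64, tiny block counts) are made-up toy values,
# and a CUDA device is assumed to be available.
if __name__ == "__main__":
    engine = CacheEngine(
        worker_id=0,
        gpu_id=0,
        num_layers=2,
        num_heads=4,
        head_size=64,
        num_gpu_blocks=16,
        num_cpu_blocks=16,
        block_size=8,
        dtype=torch.float16,
    )
    # Swap two blocks out to the pinned CPU cache, then bring them back in.
    engine.swap_out(gpu_block_numbers=[0, 1], cpu_block_numbers=[0, 1])
    engine.swap_in(gpu_block_numbers=[0, 1], cpu_block_numbers=[0, 1])
    torch.cuda.synchronize()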