cache_engine.py 3.02 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
from typing import Dict, List, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

import torch

# A (key_blocks, value_blocks) pair of tensors for a single attention head.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Own the key/value block caches of one worker, on a GPU and in pinned CPU memory.

    Both caches are nested lists: one entry per layer, and within each layer
    one ``(key_blocks, value_blocks)`` pair per head. Every tensor has shape
    ``(num_blocks, block_size * head_size)``, so ``tensor[i]`` is block ``i``.

    Block transfers (``copy``, ``swap_in``, ``swap_out``) are enqueued on a
    dedicated CUDA stream (``cache_stream``). After the transfers for layer
    ``i`` are enqueued, ``self.events[i]`` is recorded on that stream, so a
    consumer can wait on that single event before touching layer ``i``.
    """

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
    ) -> None:
        """Allocate both caches, the cache stream, and the per-layer events.

        Args:
            worker_id: Stored as ``self.worker_id``; not otherwise used here.
            gpu_id: CUDA device index for the GPU cache and the cache stream.
            num_layers: Number of layers (outer list length of each cache).
            num_heads: Number of heads per layer (inner list length).
            head_size: Per-head width; each block row holds
                ``block_size * head_size`` elements.
            block_size: Multiplier of ``head_size`` giving the block row width.
            num_gpu_blocks: Number of blocks in each GPU cache tensor.
            num_cpu_blocks: Number of blocks in each CPU cache tensor.
            dtype: Element dtype of every cache tensor.
        """
        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.dtype = dtype

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations. It must differ from
        # the current stream so that transfers can overlap with other work.
        self.cache_stream = torch.cuda.Stream(device=gpu_id)
        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
        # Initialize the events for stream synchronization (one per layer).
        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

    def _allocate_cache(self, num_blocks: int, **tensor_kwargs) -> List[List[KVCache]]:
        """Allocate uninitialized key/value block tensors for every layer and head.

        ``tensor_kwargs`` is forwarded to ``torch.empty`` (e.g. ``device=`` or
        ``pin_memory=``) and controls where the tensors live.
        """
        block_shape = (num_blocks, self.block_size * self.head_size)
        cache: List[List[KVCache]] = []
        for _ in range(self.num_layers):
            layer_cache: List[KVCache] = []
            for _ in range(self.num_heads):
                key_blocks = torch.empty(block_shape, dtype=self.dtype, **tensor_kwargs)
                value_blocks = torch.empty(block_shape, dtype=self.dtype, **tensor_kwargs)
                layer_cache.append((key_blocks, value_blocks))
            cache.append(layer_cache)
        return cache

    def allocate_gpu_cache(self) -> List[List[KVCache]]:
        """Return a freshly allocated cache with ``num_gpu_blocks`` blocks on ``gpu_id``."""
        return self._allocate_cache(self.num_gpu_blocks, device=self.gpu_id)

    def allocate_cpu_cache(self) -> List[List[KVCache]]:
        """Return a freshly allocated cache with ``num_cpu_blocks`` blocks in pinned CPU memory."""
        # Pinned memory lets the non_blocking copies in swap_in/swap_out run
        # asynchronously with respect to the host.
        return self._allocate_cache(self.num_cpu_blocks, pin_memory=True)

    @staticmethod
    def _copy_layer_blocks(
        src_layer: List[KVCache],
        dst_layer: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        """Copy blocks ``src -> dst`` for every head of one layer.

        Mappings are applied in ``src_to_dst`` iteration order. If a
        destination block is also used as a source, later copies see the
        already-overwritten data.
        """
        for (src_key, src_value), (dst_key, dst_value) in zip(src_layer, dst_layer):
            for src_block, dst_block in src_to_dst.items():
                dst_key[dst_block].copy_(src_key[src_block], non_blocking=True)
                dst_value[dst_block].copy_(src_value[src_block], non_blocking=True)

    def _transfer(
        self,
        src_cache: List[List[KVCache]],
        dst_cache: List[List[KVCache]],
        src_to_dst: Dict[int, int],
    ) -> None:
        """Enqueue per-layer block copies on ``cache_stream`` and record each layer's event."""
        # NOTE(review): the cache stream does not wait on work already queued
        # on the current stream; confirm callers order reads/writes of the
        # source blocks accordingly.
        with torch.cuda.stream(self.cache_stream):
            for src_layer, dst_layer, event in zip(src_cache, dst_cache, self.events):
                self._copy_layer_blocks(src_layer, dst_layer, src_to_dst)
                # Signals that this layer's copies have been enqueued; waiting
                # on it guarantees they have completed.
                event.record(stream=self.cache_stream)

    def copy(self, src_to_dst: Dict[int, int]) -> None:
        """Copy GPU blocks to other GPU blocks (keys and values index the GPU cache)."""
        self._transfer(self.gpu_cache, self.gpu_cache, src_to_dst)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Copy CPU blocks (mapping keys) into GPU blocks (mapping values)."""
        self._transfer(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Copy GPU blocks (mapping keys) into CPU blocks (mapping values).

        The copies are asynchronous; synchronize on ``self.events[i]`` before
        reading layer ``i`` of the CPU cache on the host.
        """
        self._transfer(self.gpu_cache, self.cpu_cache, src_to_dst)