"""Key/value cache storage: the ``KVCache`` alias and ``CacheEngine``."""
from typing import Dict, List, Tuple

import torch

# One layer's cache: a (key_blocks, value_blocks) pair of tensors.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Allocates and holds per-layer key/value cache tensors.

    On construction the engine allocates ``num_layers`` (key, value) tensor
    pairs on the device ``gpu_id`` and another ``num_layers`` pairs in pinned
    memory, then creates a dedicated CUDA stream and one CUDA event per layer.

    The block-movement methods (``copy``, ``swap_in``, ``swap_out``) are
    placeholders and currently perform no work.
    """

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
    ) -> None:
        """Store the configuration and allocate caches, stream and events.

        Args:
            worker_id: Identifier stored on the instance; not otherwise used
                by this class.
            gpu_id: CUDA device passed as ``device=`` to the GPU tensor
                allocations and to the cache stream.
            num_layers: Number of (key, value) pairs to allocate per cache
                and number of CUDA events to create.
            num_heads: First dimension of every key/value block.
            head_size: Head dimension used to build the block shapes.
            block_size: Block dimension used to build the block shapes.
            num_gpu_blocks: Leading dimension of each GPU cache tensor.
            num_cpu_blocks: Leading dimension of each pinned cache tensor.
            dtype: Element type of all cache tensors.
        """
        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.dtype = dtype

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations.
        self.cache_stream = torch.cuda.Stream(device=gpu_id)
        # Sanity check: the dedicated stream must not be the device's
        # current stream.
        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
        # Initialize the events for stream synchronization (one per layer).
        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        """Return the shape of one key block, excluding the block index.

        The head dimension is split into chunks of ``x`` elements, where
        ``x`` is the number of ``self.dtype`` elements that fit in 16 bytes.

        Returns:
            ``(num_heads, head_size // x, block_size, x)``.
        """
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        # Number of elements per 16-byte chunk.
        x = 16 // element_size
        # NOTE(review): ``head_size // x`` floors; if head_size is not a
        # multiple of x the key block holds fewer elements than the value
        # block -- confirm callers guarantee divisibility.
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        """Return the shape of one value block, excluding the block index.

        Returns:
            ``(num_heads, block_size, head_size)``.
        """
        return (
            self.num_heads,
            self.block_size,
            self.head_size,
        )

    def allocate_gpu_cache(self) -> List[KVCache]:
        """Allocate uninitialized key/value tensors on ``self.gpu_id``.

        Returns:
            A list with one ``(key_blocks, value_blocks)`` pair per layer;
            each tensor has ``num_gpu_blocks`` as its leading dimension.
        """
        # The shapes do not change between layers, so compute them once.
        key_shape = (self.num_gpu_blocks, *self.get_key_block_shape())
        value_shape = (self.num_gpu_blocks, *self.get_value_block_shape())

        gpu_cache: List[KVCache] = []
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=key_shape,
                dtype=self.dtype,
                device=self.gpu_id,
            )
            value_blocks = torch.empty(
                size=value_shape,
                dtype=self.dtype,
                device=self.gpu_id,
            )
            gpu_cache.append((key_blocks, value_blocks))
        return gpu_cache

    def allocate_cpu_cache(self) -> List[KVCache]:
        """Allocate uninitialized key/value tensors in pinned memory.

        No ``device`` is passed, so the tensors live on torch's default
        device; ``pin_memory=True`` requests page-locked memory.

        Returns:
            A list with one ``(key_blocks, value_blocks)`` pair per layer;
            each tensor has ``num_cpu_blocks`` as its leading dimension.
        """
        # The shapes do not change between layers, so compute them once.
        key_shape = (self.num_cpu_blocks, *self.get_key_block_shape())
        value_shape = (self.num_cpu_blocks, *self.get_value_block_shape())

        cpu_cache: List[KVCache] = []
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=key_shape,
                dtype=self.dtype,
                pin_memory=True,
            )
            value_blocks = torch.empty(
                size=value_shape,
                dtype=self.dtype,
                pin_memory=True,
            )
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def copy(self, src_to_dst: Dict[int, int]) -> None:
        """Placeholder for copying cache blocks; currently a no-op.

        Args:
            src_to_dst: Unused so far. Presumably maps source block numbers
                to destination block numbers -- TODO confirm.
        """
        # TODO: implement. For now this only walks the per-layer events.
        for _ in self.events:
            pass

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Placeholder for swapping blocks in; currently a no-op.

        Args:
            src_to_dst: Unused so far. Presumably maps source block numbers
                to destination block numbers -- TODO confirm.
        """
        # NOTE(review): the name suggests a CPU -> GPU transfer; nothing is
        # implemented yet, so the direction cannot be confirmed from here.
        for _ in self.events:
            pass

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Placeholder for swapping blocks out; currently a no-op.

        Args:
            src_to_dst: Unused so far. Presumably maps source block numbers
                to destination block numbers -- TODO confirm.
        """
        # NOTE(review): the name suggests a GPU -> CPU transfer; nothing is
        # implemented yet, so the direction cannot be confirmed from here.
        for _ in self.events:
            pass