from typing import Dict, List, Tuple

import torch

from cacheflow import ops

# A per-layer KV cache: a pair of (key_blocks, value_blocks) tensors.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Manages the GPU and CPU KV caches for a worker.

    Allocates the per-layer key and value block tensors on the GPU and in
    pinned CPU memory, and performs block copies and CPU<->GPU swaps on a
    dedicated CUDA stream.
    """

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
    ) -> None:
        if head_size % 16 != 0:
            raise ValueError(
                f'head_size ({head_size}) must be a multiple of 16.')

        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.dtype = dtype

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations.
        self.cache_stream = torch.cuda.Stream(device=gpu_id)
        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
        # Initialize the events for stream synchronization.
        self.events = [torch.cuda.Event() for _ in range(num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        # Pack x key elements (16 bytes) along the last dimension so the
        # attention kernel can fetch keys with vectorized 16-byte loads.
        x = 16 // element_size
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        return (
            self.num_heads,
            self.block_size,
            self.head_size,
        )

    def allocate_gpu_cache(self) -> List[KVCache]:
        gpu_cache: List[KVCache] = []
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_gpu_blocks, *self.get_key_block_shape()),
                dtype=self.dtype,
                device=self.gpu_id,
            )
            value_blocks = torch.empty(
                size=(self.num_gpu_blocks, *self.get_value_block_shape()),
                dtype=self.dtype,
                device=self.gpu_id,
            )
            gpu_cache.append((key_blocks, value_blocks))
        return gpu_cache

    def allocate_cpu_cache(self) -> List[KVCache]:
        cpu_cache: List[KVCache] = []
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_cpu_blocks, *self.get_key_block_shape()),
                dtype=self.dtype,
                pin_memory=True,
            )
            value_blocks = torch.empty(
                size=(self.num_cpu_blocks, *self.get_value_block_shape()),
                dtype=self.dtype,
                pin_memory=True,
            )
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def _copy_blocks(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                ops.copy_cache_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                ops.copy_cache_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
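                # Record a per-layer event so that consumers can wait on each
                # layer's copies without synchronizing the entire stream.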
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def copy(self, src_to_dst: Dict[int, int]) -> None:
        # Copy blocks within the GPU cache.
        self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        # Move the given blocks from the CPU cache into the GPU cache.
        self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        # Move the given blocks from the GPU cache out to the CPU cache.
        self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
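

if __name__ == '__main__':
    # Minimal usage sketch: assumes a CUDA device and the compiled cacheflow
    # ops extension are available; the sizes below are illustrative, not tuned.
    engine = CacheEngine(
        worker_id=0,
        gpu_id=0,
        num_layers=2,
        num_heads=8,
        head_size=64,
        block_size=16,
        num_gpu_blocks=16,
        num_cpu_blocks=16,
        dtype=torch.float16,
    )
    # Swap two blocks from the CPU cache into the GPU cache
    # (mapping: CPU block index -> GPU block index).
    engine.swap_in({0: 0, 1: 1})
    # Wait for the copies issued on the cache stream to finish.
    torch.cuda.synchronize()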