"tests/vscode:/vscode.git/clone" did not exist on "3d6032c6dbe3a4949e7d512e52d5a84abcc6c44b"
cache_engine.py 5.38 KB
Newer Older
1
"""CacheEngine class for managing the KV cache."""
from typing import Dict, List, Tuple

import torch

from cacheflow import cache_ops
from cacheflow.config import CacheConfig, ModelConfig, ParallelConfig

# Each layer's KV cache is a pair of tensors: (key_blocks, value_blocks).
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_heads(parallel_config)
        self.dtype = model_config.dtype

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations.
        self.cache_stream = torch.cuda.Stream()
        assert self.cache_stream != torch.cuda.current_stream()
        # Initialize the events for stream synchronization.
        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        # The key cache splits head_size by x so that the innermost dimension
        # holds exactly 16 bytes, matching the vectorized loads used by the
        # attention kernel.
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
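        # Worked example (illustrative numbers, not from this file): with
        # fp16, element_size == 2 and x == 8, so num_heads=12, head_size=64,
        # block_size=16 gives a key block of shape (12, 8, 16, 8).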
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
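        # Values do not need the 16-byte x-split used for keys; the value
        # cache keeps the plain (num_heads, head_size, block_size) layout.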
        return (
            self.num_heads,
            self.head_size,
            self.block_size,
        )

    def allocate_gpu_cache(self) -> List[KVCache]:
        gpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_gpu_blocks, *key_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            value_blocks = torch.empty(
                size=(self.num_gpu_blocks, *value_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            gpu_cache.append((key_blocks, value_blocks))
        return gpu_cache

    def allocate_cpu_cache(self) -> List[KVCache]:
        cpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        for _ in range(self.num_layers):
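            # Pinned (page-locked) host memory lets the cache stream issue
            # asynchronous CPU<->GPU copies during swaps.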
            key_blocks = torch.empty(
                size=(self.num_cpu_blocks, *key_block_shape),
                dtype=self.dtype,
                pin_memory=True,
            )
            value_blocks = torch.empty(
                size=(self.num_cpu_blocks, *value_block_shape),
                dtype=self.dtype,
                pin_memory=True,
            )
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
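        # Issue the copies on the dedicated cache stream so they can overlap
        # with model execution on the default stream; the per-layer events
        # recorded below let callers wait for each layer's swap to finish.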
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.swap_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                cache_ops.swap_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
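        # Copies happen entirely within the GPU cache; a single source block
        # may fan out to several destination blocks (e.g. for copy-on-write).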
        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
        value_caches = [value_cache for _, value_cache in self.gpu_cache]
        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = _get_dtype_size(model_config.dtype)
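        # Worked example (illustrative numbers): block_size=16, 32 heads,
        # head_size=128, 32 layers, fp16 (2 bytes) ->
        # 32 * (65536 + 65536) * 2 bytes = 8 MiB per cache block.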
        return dtype_size * total


def _get_dtype_size(dtype: torch.dtype) -> int:
    return torch.tensor([], dtype=dtype).element_size()
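

# A hedged sketch of how get_cache_block_size is typically consumed (the
# variable names below are assumptions, not part of this module): a worker
# profiles free GPU memory and derives the block counts from it, e.g.
#
#     block_bytes = CacheEngine.get_cache_block_size(
#         block_size, model_config, parallel_config)
#     num_gpu_blocks = int(free_gpu_memory // block_bytes)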