"""CacheEngine class for managing the KV cache."""
from typing import Dict, List, Tuple

import torch

from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE

logger = init_logger(__name__)

# A per-layer KV cache: one key tensor and one value tensor.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        # Skip initializing the CUDA stream, caches, and events for the
        # Neuron backend.
        if is_neuron():
            return

        if cache_config.cache_dtype == "auto":
            self.dtype = model_config.dtype
        else:
            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations.
        self.cache_stream = torch.cuda.Stream()
        assert self.cache_stream != torch.cuda.current_stream()
        # Initialize the events for stream synchronization.
        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
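        # The key cache splits the head dimension into chunks of `x`
        # elements so that each chunk spans 16 bytes; this keeps key
        # reads aligned for the attention kernel's vectorized loads.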
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
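        # The value cache uses the plain (num_heads, head_size, block_size)
        # layout; unlike the key cache, no 16-byte repacking is applied.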
        return (
            self.num_heads,
            self.head_size,
            self.block_size,
        )

    def allocate_gpu_cache(self) -> List[KVCache]:
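        # Allocate one (key, value) tensor pair per model layer on the GPU.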
        gpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_gpu_blocks, *key_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            value_blocks = torch.empty(
                size=(self.num_gpu_blocks, *value_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            gpu_cache.append((key_blocks, value_blocks))
        return gpu_cache

    def allocate_cpu_cache(self) -> List[KVCache]:
        cpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        pin_memory = not in_wsl()
        if not pin_memory:
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down performance.")
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_cpu_blocks, *key_block_shape),
                dtype=self.dtype,
                pin_memory=pin_memory,
                device="cpu",
            )
            value_blocks = torch.empty(
                size=(self.num_cpu_blocks, *value_block_shape),
                dtype=self.dtype,
                pin_memory=pin_memory,
                device="cpu",
            )
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
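        # Issue the copies on the dedicated cache stream so they can overlap
        # with compute on the default stream; a per-layer event is recorded
        # so callers can synchronize on each layer individually.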
        from vllm._C import cache_ops

        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                cache_ops.swap_blocks(src_value_cache, dst_value_cache,
                                      src_to_dst)
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
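        # src_to_dsts maps a source GPU block to one or more destination
        # blocks, e.g. when forked sequences trigger a copy-on-write.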
        from vllm._C import cache_ops

        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
        value_caches = [value_cache for _, value_cache in self.gpu_cache]
        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: str,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        if cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        dtype_size = _get_dtype_size(dtype)
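        # Worked example (hypothetical Llama-like dims): block_size=16,
        # num_heads=32, head_size=128, num_layers=32, fp16 (2 bytes):
        # 16 * 32 * 128 * 2 (K and V) * 32 * 2 = 8,388,608 bytes = 8 MiB.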
        return dtype_size * total


def _get_dtype_size(dtype: torch.dtype) -> int:
    return torch.tensor([], dtype=dtype).element_size()