"""CacheEngine class for managing the KV cache."""
from typing import Dict, List, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4

import torch
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
7
from vllm.logger import init_logger
8
from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE
9
10

# Module-level logger, created through vLLM's init_logger helper.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
13
14
15

# One cache entry as stored per layer: a (key_blocks, value_blocks) pair.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.

    On the Neuron backend the constructor returns before any cache tensor is
    allocated, so ``dtype``, ``gpu_cache`` and ``cpu_cache`` are not set there.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
        """Record the cache geometry and allocate the GPU and CPU caches.

        Args:
            cache_config: Supplies ``block_size``, ``num_gpu_blocks``,
                ``num_cpu_blocks`` and ``cache_dtype``.
            model_config: Supplies the head size, layer count, KV head count
                and the dtype used when ``cache_dtype`` is ``"auto"``.
            parallel_config: Passed to the ``model_config`` getters for the
                layer and KV head counts.
        """
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        # Skip initializing KV cache for Neuron backend.
        if is_neuron():
            return

        self.dtype = self._resolve_cache_dtype(cache_config.cache_dtype,
                                               model_config)

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

    @staticmethod
    def _resolve_cache_dtype(
        cache_dtype: str,
        model_config: ModelConfig,
    ) -> torch.dtype:
        """Translate a ``cache_dtype`` string into the cache's torch dtype.

        ``"auto"`` selects ``model_config.dtype``; any other value is looked
        up in ``STR_DTYPE_TO_TORCH_DTYPE`` (``KeyError`` if it is absent).
        """
        if cache_dtype == "auto":
            return model_config.dtype
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        """Return the per-block key shape, without the leading block axis.

        The head dimension is split into groups of ``x`` elements, where ``x``
        is the number of elements that fit in 16 bytes; the group width is the
        innermost axis.
        """
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
        # NOTE: floor division; a head_size that is not a multiple of x
        # loses its remainder here.
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        """Return the per-block value shape, without the leading block axis."""
        return (
            self.num_heads,
            self.head_size,
            self.block_size,
        )

    def _allocate_cache(
        self,
        num_blocks: int,
        device: str,
        pin_memory: bool = False,
    ) -> List[KVCache]:
        """Allocate one uninitialized (key, value) tensor pair per layer.

        Args:
            num_blocks: Size of the leading (block) axis of every tensor.
            device: Device string forwarded to ``torch.empty``.
            pin_memory: Forwarded to ``torch.empty``.
        """
        key_shape = (num_blocks, *self.get_key_block_shape())
        value_shape = (num_blocks, *self.get_value_block_shape())

        cache: List[KVCache] = []
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=key_shape,
                dtype=self.dtype,
                pin_memory=pin_memory,
                device=device,
            )
            value_blocks = torch.empty(
                size=value_shape,
                dtype=self.dtype,
                pin_memory=pin_memory,
                device=device,
            )
            cache.append((key_blocks, value_blocks))
        return cache

    def allocate_gpu_cache(self) -> List[KVCache]:
        """Allocate ``num_gpu_blocks`` key/value blocks per layer on CUDA."""
        return self._allocate_cache(self.num_gpu_blocks, "cuda")

    def allocate_cpu_cache(self) -> List[KVCache]:
        """Allocate ``num_cpu_blocks`` key/value blocks per layer on the CPU.

        The tensors are pinned unless WSL is detected.
        """
        pin_memory = not in_wsl()
        if not pin_memory:
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
        return self._allocate_cache(self.num_cpu_blocks, "cpu", pin_memory)

    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        """Run ``cache_ops.swap_blocks`` from ``src`` to ``dst``, layer by layer.

        ``src_to_dst`` is handed to the kernel unchanged; presumably it maps a
        source block number to a destination block number.
        """
        # Imported lazily so the module can be loaded without the extension.
        from vllm._C import cache_ops

        # Both lists hold one (key, value) pair per layer.
        for (src_key, src_value), (dst_key, dst_value) in zip(src, dst):
            # Copy the key blocks.
            cache_ops.swap_blocks(src_key, dst_key, src_to_dst)
            # Copy the value blocks.
            cache_ops.swap_blocks(src_value, dst_value, src_to_dst)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Swap blocks from the CPU cache into the GPU cache."""
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Swap blocks from the GPU cache out to the CPU cache."""
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        """Run ``cache_ops.copy_blocks`` over every layer of the GPU cache.

        ``src_to_dsts`` is handed to the kernel unchanged; presumably it maps
        a source block number to its destination block numbers.
        """
        from vllm._C import cache_ops

        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
        value_caches = [value_cache for _, value_cache in self.gpu_cache]
        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: str,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one cache block across all layers.

        Counts the key part and the equally sized value part for every layer,
        multiplied by the element size of the resolved cache dtype.

        Args:
            block_size: Block size factor of the per-block element count.
            cache_dtype: ``"auto"`` or a key of ``STR_DTYPE_TO_TORCH_DTYPE``.
            model_config: Supplies head size, KV head count, layer count and
                the dtype used for ``"auto"``.
            parallel_config: Passed to the ``model_config`` getters.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)

        dtype = CacheEngine._resolve_cache_dtype(cache_dtype, model_config)
        dtype_size = _get_dtype_size(dtype)
        return dtype_size * total


def _get_dtype_size(dtype: torch.dtype) -> int:
    return torch.tensor([], dtype=dtype).element_size()