"""CacheEngine class for managing the KV cache."""
from typing import Dict, List, Tuple

import torch

from vllm import cache_ops
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import in_wsl

logger = init_logger(__name__)

KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_heads(parallel_config)
        self.dtype = model_config.dtype

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations.
        self.cache_stream = torch.cuda.Stream()
        assert self.cache_stream != torch.cuda.current_stream()
        # Initialize the events for stream synchronization.
        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
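        # The key cache stores each block as
        # [num_heads, head_size // x, block_size, x], where x is the number
        # of elements that fit in 16 bytes. Splitting the head dimension this
        # way lets the attention kernel load keys in 16-byte vectorized chunks.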
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
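        # The value cache needs no such packing and keeps the natural
        # [num_heads, head_size, block_size] layout.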
        return (
            self.num_heads,
            self.head_size,
            self.block_size,
        )

    def allocate_gpu_cache(self) -> List[KVCache]:
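        # Allocate one (key_blocks, value_blocks) tensor pair per layer on
        # the GPU.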
        gpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_gpu_blocks, *key_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            value_blocks = torch.empty(
                size=(self.num_gpu_blocks, *value_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            gpu_cache.append((key_blocks, value_blocks))
        return gpu_cache

    def allocate_cpu_cache(self) -> List[KVCache]:
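        # The CPU cache backs swapped-out blocks. Pinned (page-locked) host
        # memory is used where available so swaps can run as asynchronous
        # DMA copies.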
        cpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        pin_memory = not in_wsl()
        if not pin_memory:
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down performance.")
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_cpu_blocks, *key_block_shape),
                dtype=self.dtype,
                pin_memory=pin_memory,
            )
            value_blocks = torch.empty(
                size=(self.num_cpu_blocks, *value_block_shape),
                dtype=self.dtype,
                pin_memory=pin_memory,
            )
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
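        # Issue the copies on the dedicated cache stream so they can overlap
        # with model execution; one event per layer is recorded so callers
        # can wait for individual layers to finish swapping.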
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.swap_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                cache_ops.swap_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
        value_caches = [value_cache for _, value_cache in self.gpu_cache]
        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
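        # Bytes needed per cache block across all layers: each layer holds
        # block_size * num_heads * head_size elements for keys and the same
        # again for values.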
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = _get_dtype_size(model_config.dtype)
        return dtype_size * total


def _get_dtype_size(dtype: torch.dtype) -> int:
    return torch.tensor([], dtype=dtype).element_size()