cache_engine.py 4.08 KB
Newer Older
1
"""CacheEngine class for managing the KV cache."""
2
from typing import List
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4

import torch
5

6
from vllm.attention import get_attn_backend
Woosuk Kwon's avatar
Woosuk Kwon committed
7
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
8
from vllm.logger import init_logger
9
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
10
11

# Module-level logger, named after this module, created via vllm's
# ``init_logger`` helper.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
14


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
        """Read cache geometry from the configs and allocate both caches.

        Args:
            cache_config: Supplies ``block_size``, ``num_gpu_blocks``,
                ``num_cpu_blocks`` and ``cache_dtype``.
            model_config: Supplies head size, layer/head counts, sliding
                window and the model dtype.
            parallel_config: Passed to ``model_config`` so per-rank layer
                and head counts are used.
        """
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        # Element dtype of the cache tensors. This is shared with
        # get_cache_block_size so allocation and size accounting agree.
        self.dtype = self._get_cache_dtype(cache_config, model_config)

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            model_config.get_num_attention_heads(parallel_config),
            self.head_size,
            self.num_kv_heads,
            model_config.get_sliding_window(),
            model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
        )

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

    @staticmethod
    def _get_cache_dtype(
        cache_config: CacheConfig,
        model_config: ModelConfig,
    ) -> torch.dtype:
        """Resolve the torch dtype used for KV cache tensors.

        A ``cache_dtype`` of ``"auto"`` means "use the model dtype". Any other
        value is looked up in ``STR_DTYPE_TO_TORCH_DTYPE``, and an unknown
        name raises ``KeyError``.
        """
        if cache_config.cache_dtype == "auto":
            return model_config.dtype
        return STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device.

        One uninitialized tensor is allocated per layer. Each tensor has the
        shape the attention backend reports for ``num_blocks`` blocks.

        Args:
            num_blocks: Number of cache blocks to allocate per layer.
            device: Target device string, e.g. ``"cuda"`` or ``"cpu"``.

        Returns:
            A list with ``self.num_layers`` tensors of dtype ``self.dtype``.
        """
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        # Only CPU tensors are pinned, and only when the platform supports it.
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        return [
            torch.empty(kv_cache_shape,
                        dtype=self.dtype,
                        pin_memory=pin_memory,
                        device=device) for _ in range(self.num_layers)
        ]

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks from the CPU cache into the GPU cache, layer by layer.

        Args:
            src_to_dst: Block mapping forwarded to the attention backend's
                ``swap_blocks``. NOTE(review): presumably (CPU block, GPU
                block) index pairs; confirm against the backend.
        """
        for cpu_layer, gpu_layer in zip(self.cpu_cache, self.gpu_cache):
            self.attn_backend.swap_blocks(cpu_layer, gpu_layer, src_to_dst)

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks from the GPU cache out to the CPU cache, layer by layer.

        Args:
            src_to_dst: Block mapping forwarded to the attention backend's
                ``swap_blocks``. NOTE(review): presumably (GPU block, CPU
                block) index pairs; confirm against the backend.
        """
        for gpu_layer, cpu_layer in zip(self.gpu_cache, self.cpu_cache):
            self.attn_backend.swap_blocks(gpu_layer, cpu_layer, src_to_dst)

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy blocks within the GPU cache via the backend's ``copy_blocks``.

        Args:
            src_to_dsts: Block mapping forwarded unchanged to
                ``attn_backend.copy_blocks`` together with every GPU layer.
        """
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one cache block summed over all layers.

        The result counts both the key and the value cache of every layer:
        ``num_layers * 2 * block_size * num_kv_heads * head_size * itemsize``.

        NOTE(review): this assumes the backend's KV layout is exactly that
        product. ``_allocate_kv_cache`` uses the backend-provided shape
        instead, so the two could diverge for a backend with a different
        layout.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype = CacheEngine._get_cache_dtype(cache_config, model_config)
        return _get_dtype_size(dtype) * total


def _get_dtype_size(dtype: torch.dtype) -> int:
    return torch.tensor([], dtype=dtype).element_size()