"vscode:/vscode.git/clone" did not exist on "003f8ee1287f90a7e8aa9b9e7d6246ac74ebefbe"
cache_engine.py 5.07 KB
Newer Older
1
"""CacheEngine class for managing the KV cache."""
2
from typing import List
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4

import torch
5

6
from vllm.attention import get_attn_backend
7
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
8
from vllm.logger import init_logger
9
10
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
                        is_pin_memory_available)
11
from vllm.attention.backends.tree_decoding_utils import move_cache
12
13

logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
14
15
16


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.

    Each cache is a list holding one tensor per attention layer; the tensor
    shape is whatever the attention backend's ``get_kv_cache_shape`` returns.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
    ) -> None:
        """Derive the cache geometry from the configs and allocate both caches.

        Args:
            cache_config: Supplies the block size, the GPU/CPU block counts
                and the cache dtype string (``"auto"`` means "use the model
                dtype").
            model_config: Supplies head size, KV-head count, attention-layer
                count, attention-head count, sliding window and model dtype.
            parallel_config: Passed to the model config's per-rank getters;
                its ``pipeline_parallel_size`` divides the block counts.
            device_config: ``device_type`` is the device the "GPU" cache is
                allocated on.
        """
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.device_config = device_config

        self.head_size = model_config.get_head_size()
        # Models like Jamba have mixed-type layers (e.g. Mamba); only the
        # attention layers get a KV cache tensor.
        self.num_attention_layers = model_config.get_num_attention_layers(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # The configured block counts are floor-divided by the pipeline-
        # parallel size (presumably the configured totals are shared across
        # pipeline stages -- confirm). A falsy count (None or 0) is left as-is.
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        if self.num_gpu_blocks:
            self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
        self.num_cpu_blocks = cache_config.num_cpu_blocks
        if self.num_cpu_blocks:
            self.num_cpu_blocks //= parallel_config.pipeline_parallel_size

        self.dtype = self._resolve_cache_dtype(cache_config, model_config)

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            model_config.get_num_attention_heads(parallel_config),
            self.head_size,
            self.num_kv_heads,
            model_config.get_sliding_window(),
            model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
        )

        # Initialize the cache. This must come last: _allocate_kv_cache reads
        # attn_backend, block_size, num_kv_heads, head_size, dtype and
        # num_attention_layers, all set above.
        self.gpu_cache = self._allocate_kv_cache(
            self.num_gpu_blocks, self.device_config.device_type)
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device.

        Args:
            num_blocks: Number of cache blocks each layer's tensor holds.
            device: Target device string (e.g. ``"cpu"``).

        Returns:
            A list with one zero-initialised tensor per attention layer, each
            shaped by ``attn_backend.get_kv_cache_shape``. CPU tensors are
            allocated pinned when ``is_pin_memory_available()`` is true.
        """
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        # The null block in CpuGpuBlockAllocator requires at least that block
        # to be zeroed out; everything is zeroed out for simplicity.
        return [
            torch.zeros(kv_cache_shape,
                        dtype=self.dtype,
                        pin_memory=pin_memory,
                        device=device)
            for _ in range(self.num_attention_layers)
        ]

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks from the CPU cache into the GPU cache, layer by layer.

        ``src_to_dst`` is forwarded unchanged to ``attn_backend.swap_blocks``;
        presumably it pairs CPU source block ids with GPU destination block
        ids -- confirm against the backend.
        """
        for i in range(self.num_attention_layers):
            self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
                                          src_to_dst)

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks from the GPU cache out to the CPU cache, layer by layer.

        ``src_to_dst`` is forwarded unchanged to ``attn_backend.swap_blocks``;
        presumably it pairs GPU source block ids with CPU destination block
        ids -- confirm against the backend.
        """
        for i in range(self.num_attention_layers):
            self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                          src_to_dst)

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy blocks within the GPU cache via ``attn_backend.copy_blocks``.

        The layout of ``src_to_dsts`` is defined by the backend (presumably
        source/destination block-id pairs -- confirm).
        """
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    def move_caches(self, kv_caches: List[torch.Tensor],
                    src_to_dsts: torch.Tensor) -> None:
        """Rearrange blocks in ``kv_caches`` using the tree-decoding helper.

        Operates on the caller-supplied ``kv_caches`` (not necessarily
        ``self.gpu_cache``). All arguments, together with the cache dtype
        string, KV-head count and head size, are forwarded to ``move_cache``;
        its exact semantics are defined there.
        """
        move_cache(self.attn_backend,
                   kv_caches,
                   src_to_dsts,
                   self.cache_config.cache_dtype,
                   self.num_kv_heads,
                   self.head_size)

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size of one cache block summed over all attention layers.

        Per layer, a block holds ``block_size * num_kv_heads * head_size``
        key elements plus the same number of value elements. The element
        total is multiplied by ``get_dtype_size`` of the cache dtype
        (presumably bytes per element).
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_attention_layers = model_config.get_num_attention_layers(
            parallel_config)

        key_cache_block = cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_attention_layers * (key_cache_block + value_cache_block)
        dtype = CacheEngine._resolve_cache_dtype(cache_config, model_config)
        dtype_size = get_dtype_size(dtype)
        return dtype_size * total

    @staticmethod
    def _resolve_cache_dtype(cache_config: CacheConfig,
                             model_config: ModelConfig) -> torch.dtype:
        """Return the torch dtype used to store KV cache entries.

        ``"auto"`` selects the model's dtype; any other string is looked up
        in ``STR_DTYPE_TO_TORCH_DTYPE`` (an unknown string raises KeyError).
        Shared by ``__init__`` and ``get_cache_block_size`` so the allocated
        dtype and the size estimate can never disagree.
        """
        if cache_config.cache_dtype == "auto":
            return model_config.dtype
        return STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]