cache_engine.py 5.36 KB
Newer Older
1
"""CacheEngine class for managing the KV cache."""
2
from typing import List
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4

import torch
5

6
from vllm.attention import get_attn_backend
7
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
8
from vllm.logger import init_logger
9
10
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
                        get_dtype_size, is_pin_memory_available)
11
from vllm.attention.backends.tree_decoding_utils import move_cache
12
13

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
14
15
16


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.

    Each cache (GPU and CPU) is a list holding one zero-initialized tensor per
    attention layer; the per-layer tensor shape is chosen by the attention
    backend.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
    ) -> None:
        """Derive the cache geometry from the configs and allocate both caches.

        Args:
            cache_config: Supplies ``block_size``, ``num_gpu_blocks``,
                ``num_cpu_blocks`` and the ``cache_dtype`` string.
            model_config: Supplies the head size, KV-head count, number of
                attention layers, model ``dtype`` and the ``use_mla`` /
                ``is_attention_free`` flags.
            parallel_config: Passed to the model-config queries (presumably
                to obtain per-rank counts) and used to split the block counts
                across ``pipeline_parallel_size`` stages.
            device_config: ``device_type`` selects the device on which the
                GPU-side cache is allocated.
        """
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.device_config = device_config

        self.head_size = model_config.get_head_size()
        # Models such as Jamba mix layer types (e.g. Mamba layers), so only
        # the attention layers are counted: one KV cache tensor is allocated
        # per attention layer.
        self.num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Block counts are split across pipeline-parallel stages with integer
        # division; falsy values (None / 0) are left untouched.
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        if self.num_gpu_blocks:
            self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
        self.num_cpu_blocks = cache_config.num_cpu_blocks
        if self.num_cpu_blocks:
            self.num_cpu_blocks //= parallel_config.pipeline_parallel_size

        self.dtype = self._resolve_cache_dtype(cache_config, model_config)

        # Get attention backend. Note: the *model* dtype (not self.dtype) is
        # passed alongside the raw cache_dtype string.
        self.attn_backend = get_attn_backend(self.head_size,
                                             model_config.dtype,
                                             cache_config.cache_dtype,
                                             self.block_size,
                                             model_config.is_attention_free,
                                             use_mla=model_config.use_mla)

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(
            self.num_gpu_blocks, self.device_config.device_type)
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

    @staticmethod
    def _resolve_cache_dtype(cache_config: CacheConfig,
                             model_config: ModelConfig) -> torch.dtype:
        """Return the torch dtype used to store KV cache entries.

        ``cache_dtype == "auto"`` means "same as the model dtype"; any other
        string is looked up in ``STR_DTYPE_TO_TORCH_DTYPE``, so an unsupported
        value fails at that lookup. Shared by ``__init__`` and
        ``get_cache_block_size`` so both always agree on the dtype.
        """
        if cache_config.cache_dtype == "auto":
            return model_config.dtype
        return STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device.

        Args:
            num_blocks: Number of cache blocks per layer, forwarded to the
                backend's ``get_kv_cache_shape``.
            device: Target device; ``"cpu"`` additionally enables pinned
                memory when ``is_pin_memory_available()`` reports support.

        Returns:
            One zero-filled tensor per attention layer, all with the same
            backend-defined shape and ``self.dtype``.
        """
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        # Pinning is only requested for host (CPU) allocations.
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        # The null block in CpuGpuBlockAllocator requires at least that
        # block to be zeroed-out. We zero-out everything for simplicity.
        return [
            torch.zeros(kv_cache_shape,
                        dtype=self.dtype,
                        pin_memory=pin_memory,
                        device=device)
            for _ in range(self.num_attention_layers)
        ]

    def _swap_layers(self, src_cache: List[torch.Tensor],
                     dst_cache: List[torch.Tensor],
                     src_to_dst: torch.Tensor) -> None:
        """Run ``attn_backend.swap_blocks`` layer by layer.

        Iterates exactly ``num_attention_layers`` layers, passing the source
        layer tensor, the destination layer tensor and the unchanged
        ``src_to_dst`` mapping to the backend.
        """
        for layer_idx in range(self.num_attention_layers):
            self.attn_backend.swap_blocks(src_cache[layer_idx],
                                          dst_cache[layer_idx], src_to_dst)

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks from the CPU cache into the GPU cache.

        ``src_to_dst`` is forwarded unchanged to the backend; presumably it
        maps CPU block indices to GPU block indices (confirm against the
        backend's ``swap_blocks``).
        """
        self._swap_layers(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks from the GPU cache out to the CPU cache.

        ``src_to_dst`` is forwarded unchanged to the backend; presumably it
        maps GPU block indices to CPU block indices (confirm against the
        backend's ``swap_blocks``).
        """
        self._swap_layers(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy blocks within the GPU cache via ``attn_backend.copy_blocks``.

        The whole per-layer GPU cache list is handed to the backend together
        with ``src_to_dsts``.
        """
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    def move_caches(self, kv_caches: List[torch.Tensor],
                    src_to_dsts: torch.Tensor) -> None:
        """Delegate to ``tree_decoding_utils.move_cache`` for ``kv_caches``.

        Unlike ``copy``, this operates on the caller-supplied ``kv_caches``
        rather than ``self.gpu_cache``. The backend, cache dtype string,
        KV-head count and head size are forwarded; the exact move semantics
        are defined by ``move_cache`` (not visible here).
        """
        move_cache(self.attn_backend,
                   kv_caches,
                   src_to_dsts,
                   self.cache_config.cache_dtype,
                   self.num_kv_heads,
                   self.head_size)

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size of one cache block summed over attention layers.

        Computed as ``dtype_size * num_attention_layers * (K + V)`` where
        ``K = block_size * num_kv_heads * head_size`` elements and ``V`` is
        ``K`` for regular attention or ``0`` for MLA. ``dtype_size`` comes
        from ``get_dtype_size`` on the resolved cache dtype, so the result is
        in bytes assuming that helper reports bytes per element.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)

        key_cache_block = cache_config.block_size * num_heads * head_size
        # For MLA there is no value cache, since the latent vector
        # is joint keys and values.
        value_cache_block = 0 if model_config.use_mla else key_cache_block
        total = num_attention_layers * (key_cache_block + value_cache_block)

        dtype = CacheEngine._resolve_cache_dtype(cache_config, model_config)
        return get_dtype_size(dtype) * total