cache_engine.py 6.88 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""CacheEngine class for managing the KV cache."""
3
from typing import List
Woosuk Kwon's avatar
Woosuk Kwon committed
4

5
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
6
import torch
7

8
from vllm import envs
9
from vllm.attention import get_attn_backend
10
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
11
from vllm.logger import init_logger
12
from vllm.platforms import current_platform
13
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
14
15
                        align_to_256bytes, get_dtype_size,
                        is_pin_memory_available)
16
17

logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
18
19
20


class CacheEngine:
21
22
23
24
25
26
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
27
28
29

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
    ) -> None:
        """Derive the cache geometry from the configs and allocate both caches.

        Args:
            cache_config: Provides the block size, the GPU/CPU block counts
                and the cache dtype string.
            model_config: Provides the head size, KV-head count, attention
                layer count, model dtype and the MLA / attention-free flags.
            parallel_config: Passed to the per-rank head/layer queries and
                provides the pipeline-parallel size.
            device_config: Provides the device type used for the GPU cache.
        """
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.device_config = device_config

        self.head_size = model_config.get_head_size()
        # Models such as Jamba mix layer types (e.g. Mamba layers), so only
        # the attention layers are counted here.
        self.num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.align_cache = self._align_cache(model_config)

        self.block_size = cache_config.block_size

        # The configured block counts are divided by the pipeline-parallel
        # size. Falsy counts (None or 0) are stored unchanged.
        total_gpu_blocks = cache_config.num_gpu_blocks
        self.num_gpu_blocks = (
            total_gpu_blocks // parallel_config.pipeline_parallel_size
            if total_gpu_blocks else total_gpu_blocks)
        total_cpu_blocks = cache_config.num_cpu_blocks
        self.num_cpu_blocks = (
            total_cpu_blocks // parallel_config.pipeline_parallel_size
            if total_cpu_blocks else total_cpu_blocks)

        # "auto" means the cache stores values in the model's own dtype.
        self.dtype = (model_config.dtype
                      if cache_config.cache_dtype == "auto" else
                      STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype])

        # Select the attention backend; it defines the cache layout and the
        # swap/copy operations used by this class.
        self.attn_backend = get_attn_backend(
            self.head_size,
            model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
            model_config.is_attention_free,
            use_mla=model_config.use_mla,
        )

        # Allocate the device-side cache first, then the CPU-side cache.
        self.gpu_cache = self._allocate_kv_cache(
            self.num_gpu_blocks, self.device_config.device_type)
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
72

73
74
75
76
77
78
79
    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device."""
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
80
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
81
82
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        kv_cache: List[torch.Tensor] = []
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98

        # Align entries so they are 256 byte aligned for better performance
        # Primarily targets MLA as this typically only ends up having entries
        # be 128 byte aligned.
        if self.align_cache:
            # We assume the cache shape is:
            #    (TOTAL_PAGES, PAGE_SIZE, entry_shape...)
            # NOTE this assumption currently only holds for MLA so we only apply
            # this optimization when `use_mla` is true
            entry_shape = kv_cache_shape[2:]
            entry_size = np.prod(entry_shape)
            alloc_entry_size = align_to_256bytes(entry_size, self.dtype)
            alloc_shape = (*kv_cache_shape[:2], alloc_entry_size)
        else:
            alloc_shape = kv_cache_shape

Mor Zusman's avatar
Mor Zusman committed
99
        for _ in range(self.num_attention_layers):
100
101
102
            # null block in CpuGpuBlockAllocator requires at least that
            # block to be zeroed-out.
            # We zero-out everything for simplicity.
103
104
105
106
107
108
109
110
111
112
113
114
115
            layer_kv_cache = torch.zeros(alloc_shape,
                                         dtype=self.dtype,
                                         pin_memory=pin_memory,
                                         device=device)

            # If we allocated with padding for alignment reasons truncate the
            # shape while preserving the aligned stride
            if self.align_cache:
                layer_kv_cache = layer_kv_cache[..., :entry_size]

            # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
            # when entry_shape is higher than 1D
            kv_cache.append(layer_kv_cache.view(kv_cache_shape))
116
        return kv_cache
Woosuk Kwon's avatar
Woosuk Kwon committed
117

118
    def swap_in(self, src_to_dst: torch.Tensor) -> None:
Mor Zusman's avatar
Mor Zusman committed
119
        for i in range(self.num_attention_layers):
120
121
            self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
                                          src_to_dst)
Woosuk Kwon's avatar
Woosuk Kwon committed
122

123
    def swap_out(self, src_to_dst: torch.Tensor) -> None:
Mor Zusman's avatar
Mor Zusman committed
124
        for i in range(self.num_attention_layers):
125
126
            self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                          src_to_dst)
127

128
    def copy(self, src_to_dsts: torch.Tensor) -> None:
129
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
130

131
132
133
134
135
136
137
138
    @staticmethod
    def _align_cache(model_config: ModelConfig):
        # Currently align_cache only applies to MLA models since the other
        # cache kernels haven't been updated yet to support non-continguous
        # tensors
        return model_config.use_mla and current_platform.is_cuda() \
            and envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE

139
140
    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one cache block over all attention layers.

        Args:
            cache_config: Provides the block size and the cache dtype string.
            model_config: Provides the head size, KV-head count, attention
                layer count, model dtype and the MLA flag.
            parallel_config: Passed to the per-rank head/layer queries.

        Returns:
            ``dtype_size * num_attention_layers * block_size *
            (key_entry + value_entry)``, where the entries are counted in
            elements of the cache dtype and the value entry is 0 for MLA.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)

        # "auto" means the cache stores values in the model's own dtype.
        if cache_config.cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        key_cache_entry = num_heads * head_size
        if CacheEngine._align_cache(model_config):
            # Pad using the *cache* dtype, exactly as _allocate_kv_cache does.
            # Padding with the model dtype would under-report the block size
            # whenever the cache dtype is narrower than the model dtype
            # (e.g. an fp8 cache on a bf16 model), because fewer elements fit
            # into 256 bytes for the wider type.
            key_cache_entry = align_to_256bytes(key_cache_entry, dtype)

        # For MLA there is no value cache, since the latent vector
        # is joint keys and values.
        value_cache_entry = key_cache_entry if not model_config.use_mla else 0
        total = num_attention_layers * cache_config.block_size * \
            (key_cache_entry + value_cache_entry)

        dtype_size = get_dtype_size(dtype)
        return dtype_size * total