cache_engine.py 5.87 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""CacheEngine class for managing the KV cache."""
3
from typing import List
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

import torch
6

7
from vllm.attention import get_attn_backend
8
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
9
from vllm.logger import init_logger
10
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
11
                        get_dtype_size, is_pin_memory_available)
12
13

# Module-level logger for this file (created via vLLM's init_logger).
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
14
15
16


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.

    Each of ``gpu_cache`` and ``cpu_cache`` is a list holding one
    zero-initialized tensor per attention layer.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
    ) -> None:
        """Resolve cache geometry and dtype, then allocate both caches.

        Args:
            cache_config: Supplies the block size, the GPU/CPU block counts
                and the cache dtype string.
            model_config: Supplies head size, KV-head count, layer types,
                model dtype and the attention-free / MLA flags.
            parallel_config: Passed to the per-rank head/layer queries; its
                ``pipeline_parallel_size`` divides the block counts.
            device_config: ``device_type`` selects the device used for the
                non-CPU cache.
        """
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.device_config = device_config

        self.head_size = model_config.get_head_size()
        # Models like Jamba have mixed layer types (e.g. Mamba); only the
        # attention layers get a KV cache.
        self.num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Floor-divide the configured block counts by the pipeline-parallel
        # size. A falsy count (None or 0) is left as-is.
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        if self.num_gpu_blocks:
            self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
        self.num_cpu_blocks = cache_config.num_cpu_blocks
        if self.num_cpu_blocks:
            self.num_cpu_blocks //= parallel_config.pipeline_parallel_size

        self.dtype = self._get_cache_dtype(cache_config, model_config)

        # Get attention backend.
        self.attn_backend = get_attn_backend(self.head_size,
                                             model_config.dtype,
                                             cache_config.cache_dtype,
                                             self.block_size,
                                             model_config.is_attention_free,
                                             use_mla=model_config.use_mla)

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(
            self.num_gpu_blocks, self.device_config.device_type)
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

    @staticmethod
    def _get_cache_dtype(cache_config: CacheConfig,
                         model_config: ModelConfig) -> torch.dtype:
        """Return the torch dtype used to store the KV cache.

        ``cache_dtype == "auto"`` reuses the model dtype; any other value is
        looked up in ``STR_DTYPE_TO_TORCH_DTYPE`` (an unknown key raises
        ``KeyError``).
        """
        if cache_config.cache_dtype == "auto":
            return model_config.dtype
        return STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device.

        Args:
            num_blocks: Number of cache blocks to allocate per layer.
            device: Torch device to allocate on. For ``"cpu"`` the memory is
                pinned when ``is_pin_memory_available()`` returns True.

        Returns:
            One zero-filled tensor per attention layer. Each tensor's logical
            shape is the backend's ``get_kv_cache_shape(...)`` result, while
            its memory follows the backend's stride order (so it may be
            non-contiguous).

        Raises:
            ValueError: If the backend's stride order is not a permutation of
                the KV cache shape's dimensions.
        """
        kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        num_dims = len(kv_cache_generic_shape)
        pin_memory = is_pin_memory_available() if device == "cpu" else False

        try:
            kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
            )
        except (AttributeError, NotImplementedError):
            # The backend defines no custom layout: keep the generic order.
            kv_cache_stride_order = tuple(range(num_dims))
        if sorted(kv_cache_stride_order) != list(range(num_dims)):
            raise ValueError(
                f"KV cache stride order {tuple(kv_cache_stride_order)} is not "
                f"a permutation of the {num_dims} dimensions of KV cache "
                f"shape {tuple(kv_cache_generic_shape)}")

        # The allocation respects the backend-defined stride order to ensure
        # the semantic remains consistent for each backend. We first obtain the
        # generic kv cache shape and then permute it according to the stride
        # order which could result in a non-contiguous tensor.
        kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
                                          for i in kv_cache_stride_order)

        # Tensor.permute(*dims) makes output dim k equal to input dim dims[k].
        # Allocation dim j holds generic dim kv_cache_stride_order[j], so to
        # make output dim k the generic dim k we must permute by the INVERSE
        # of the stride order. Permuting by the stride order itself is only
        # correct when that order is its own inverse (e.g. a single swap).
        inverse_stride_order = [0] * num_dims
        for alloc_dim, generic_dim in enumerate(kv_cache_stride_order):
            inverse_stride_order[generic_dim] = alloc_dim

        kv_cache: List[torch.Tensor] = []
        for _ in range(self.num_attention_layers):
            # null block in CpuGpuBlockAllocator requires at least that
            # block to be zeroed-out.
            # We zero-out everything for simplicity.
            layer_kv_cache = torch.zeros(
                kv_cache_allocation_shape,
                dtype=self.dtype,
                pin_memory=pin_memory,
                device=device).permute(*inverse_stride_order)
            kv_cache.append(layer_kv_cache)
        return kv_cache

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks in: CPU cache -> GPU cache, one call per layer.

        Args:
            src_to_dst: Block mapping forwarded unchanged to the backend's
                ``swap_blocks`` (presumably source/destination block-id
                pairs; NOTE(review): confirm against the backend).
        """
        for cpu_layer, gpu_layer in zip(self.cpu_cache, self.gpu_cache):
            self.attn_backend.swap_blocks(cpu_layer, gpu_layer, src_to_dst)

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks out: GPU cache -> CPU cache, one call per layer.

        Args:
            src_to_dst: Block mapping forwarded unchanged to the backend's
                ``swap_blocks``.
        """
        for gpu_layer, cpu_layer in zip(self.gpu_cache, self.cpu_cache):
            self.attn_backend.swap_blocks(gpu_layer, cpu_layer, src_to_dst)

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy blocks within the GPU cache via the backend's ``copy_blocks``.

        Args:
            src_to_dsts: Block mapping forwarded unchanged to the backend.
        """
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one cache block over all attn layers.

        Computed as ``dtype_size * num_attention_layers * block_size *
        (key_entry + value_entry)`` with ``key_entry = num_kv_heads *
        head_size``. ``value_entry`` equals ``key_entry``, except under MLA
        where it is 0 (there is no separate value cache).
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)

        dtype = CacheEngine._get_cache_dtype(cache_config, model_config)

        key_cache_entry = num_heads * head_size

        # For MLA there is no value cache, since the latent vector
        # is joint keys and values.
        value_cache_entry = key_cache_entry if not model_config.use_mla else 0
        total = (num_attention_layers * cache_config.block_size *
                 (key_cache_entry + value_cache_entry))

        dtype_size = get_dtype_size(dtype)
        return dtype_size * total