cache_engine.py 5.93 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""CacheEngine class for managing the KV cache."""
4
from typing import List
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7

8
from vllm.attention import get_attn_backend
9
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
10
from vllm.logger import init_logger
11
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
12
                        get_dtype_size, is_pin_memory_available)
13
14

# Module-level logger, named after this module.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
15
16
17


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.

    ``gpu_cache`` and ``cpu_cache`` each hold one tensor per attention layer,
    in attention-layer order.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
    ) -> None:
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.device_config = device_config

        self.head_size = model_config.get_head_size()
        # Models like Jamba have mixed layer types (e.g. Mamba), so only the
        # attention layers are counted for KV cache allocation.
        self.num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # The configured block counts are divided by the pipeline-parallel
        # size; presumably each pipeline stage owns its own CacheEngine and
        # gets an equal share of the budget (NOTE(review): confirm).
        # A falsy count (None or 0) is left as-is.
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        if self.num_gpu_blocks:
            self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
        self.num_cpu_blocks = cache_config.num_cpu_blocks
        if self.num_cpu_blocks:
            self.num_cpu_blocks //= parallel_config.pipeline_parallel_size

        self.dtype = self._resolve_cache_dtype(cache_config, model_config)

        # Get attention backend. It is selected from the model dtype and the
        # raw ``cache_dtype`` string, not from the resolved ``self.dtype``.
        self.attn_backend = get_attn_backend(self.head_size,
                                             model_config.dtype,
                                             cache_config.cache_dtype,
                                             self.block_size,
                                             model_config.is_attention_free,
                                             use_mla=model_config.use_mla)

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(
            self.num_gpu_blocks, self.device_config.device_type)
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

    @staticmethod
    def _resolve_cache_dtype(cache_config: CacheConfig,
                             model_config: ModelConfig) -> torch.dtype:
        """Return the dtype used to store KV cache entries.

        ``cache_dtype == "auto"`` means "use the model dtype"; any other value
        is looked up in ``STR_DTYPE_TO_TORCH_DTYPE`` (an unknown string raises
        ``KeyError``).
        """
        if cache_config.cache_dtype == "auto":
            return model_config.dtype
        return STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates a zero-filled KV cache on the specified device.

        Args:
            num_blocks: Number of cache blocks to allocate per layer.
            device: Target device. For ``"cpu"`` the memory is pinned when
                ``is_pin_memory_available()`` reports it is supported.

        Returns:
            One tensor per attention layer. Each tensor has the backend's
            generic KV cache shape, but its memory layout follows the
            backend's stride order, so it may be non-contiguous.
        """
        kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        pin_memory = is_pin_memory_available() if device == "cpu" else False

        # Backends that do not define a stride order use the identity order.
        try:
            kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
            )
        except (AttributeError, NotImplementedError):
            kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))

        # The allocation respects the backend-defined stride order to ensure
        # the semantic remains consistent for each backend. We first obtain the
        # generic kv cache shape and then permute it according to the stride
        # order which could result in a non-contiguous tensor.
        kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
                                          for i in kv_cache_stride_order)
        # Allocation dim j holds generic dim stride_order[j]. To expose the
        # generic dim order again we must permute by the INVERSE order.
        # Permuting by the stride order itself only works when it is its own
        # inverse (e.g. identity or a single swap).
        inverse_stride_order = tuple(
            kv_cache_stride_order.index(dim)
            for dim in range(len(kv_cache_stride_order)))

        kv_cache: List[torch.Tensor] = []
        for _ in range(self.num_attention_layers):
            # null block in CpuGpuBlockAllocator requires at least that
            # block to be zeroed-out.
            # We zero-out everything for simplicity.
            layer_kv_cache = torch.zeros(
                kv_cache_allocation_shape,
                dtype=self.dtype,
                pin_memory=pin_memory,
                device=device).permute(*inverse_stride_order)
            kv_cache.append(layer_kv_cache)
        return kv_cache

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Copy blocks from the CPU cache into the GPU cache, layer by layer.

        ``src_to_dst`` is passed unchanged to the backend's ``swap_blocks``;
        presumably it maps CPU block indices to GPU block indices (confirm
        against the backend).
        """
        for cpu_layer, gpu_layer in zip(self.cpu_cache, self.gpu_cache):
            self.attn_backend.swap_blocks(cpu_layer, gpu_layer, src_to_dst)

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Copy blocks from the GPU cache into the CPU cache, layer by layer.

        ``src_to_dst`` is passed unchanged to the backend's ``swap_blocks``;
        presumably it maps GPU block indices to CPU block indices (confirm
        against the backend).
        """
        for gpu_layer, cpu_layer in zip(self.gpu_cache, self.cpu_cache):
            self.attn_backend.swap_blocks(gpu_layer, cpu_layer, src_to_dst)

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy blocks within the GPU cache via the backend's ``copy_blocks``.

        ``src_to_dsts`` is forwarded unchanged; its layout is defined by the
        backend.
        """
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one KV cache block.

        A block holds ``cache_config.block_size`` tokens for every attention
        layer reported for ``parallel_config``. Each token stores a key entry
        of ``num_kv_heads * head_size`` elements plus an equally sized value
        entry (no value entry for MLA). The element count is multiplied by
        the byte size of the cache dtype.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)

        dtype = CacheEngine._resolve_cache_dtype(cache_config, model_config)

        key_cache_entry = num_heads * head_size
        # For MLA there is no value cache, since the latent vector
        # is joint keys and values.
        value_cache_entry = key_cache_entry if not model_config.use_mla else 0
        total = (num_attention_layers * cache_config.block_size *
                 (key_cache_entry + value_cache_entry))

        dtype_size = get_dtype_size(dtype)
        return dtype_size * total