"""CacheEngine class for managing the KV cache."""
from typing import List

import torch

from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
                        get_dtype_size, is_pin_memory_available)

logger = init_logger(__name__)
class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
    ) -> None:
        """Record the configs, pick an attention backend and allocate caches.

        Args:
            cache_config: Supplies the block size, the GPU/CPU block counts
                and the cache dtype string.
            model_config: Supplies head size, KV-head count, attention-layer
                count, model dtype and the MLA / attention-free flags.
            parallel_config: Passed to the model_config getters and supplies
                the pipeline-parallel size.
            device_config: Its ``device_type`` is the device the "GPU" cache
                is allocated on.
        """
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.device_config = device_config

        self.head_size = model_config.get_head_size()
        # Models like Jamba have mixed-type layers (e.g. Mamba); only the
        # attention layers get a KV cache tensor.
        self.num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # The configured block counts are divided by the pipeline-parallel
        # size. A falsy count (None / 0) is left untouched.
        # NOTE(review): presumably each pipeline stage / virtual engine owns
        # one share of the blocks -- confirm against the caller.
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        if self.num_gpu_blocks:
            self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
        self.num_cpu_blocks = cache_config.num_cpu_blocks
        if self.num_cpu_blocks:
            self.num_cpu_blocks //= parallel_config.pipeline_parallel_size

        self.dtype = self._resolve_cache_dtype(cache_config, model_config)

        # Get attention backend. Note that the raw cache_dtype string (which
        # may be "auto") is what the backend selector receives.
        self.attn_backend = get_attn_backend(self.head_size,
                                             model_config.dtype,
                                             cache_config.cache_dtype,
                                             self.block_size,
                                             model_config.is_attention_free,
                                             use_mla=model_config.use_mla)

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(
            self.num_gpu_blocks, self.device_config.device_type)
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device.

        Returns one zero-filled tensor per attention layer. Each tensor has
        the shape reported by ``attn_backend.get_kv_cache_shape`` and dtype
        ``self.dtype``. CPU tensors are pinned when pinned memory is
        available.
        """
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        # The null block in CpuGpuBlockAllocator requires at least that block
        # to be zeroed out; for simplicity every block is zeroed. Each layer
        # gets its own, independently allocated tensor.
        return [
            torch.zeros(kv_cache_shape,
                        dtype=self.dtype,
                        pin_memory=pin_memory,
                        device=device)
            for _ in range(self.num_attention_layers)
        ]

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks from the CPU cache into the GPU cache, layer by layer.

        Calls ``attn_backend.swap_blocks(cpu_layer, gpu_layer, src_to_dst)``
        for every attention layer. The layout of ``src_to_dst`` (presumably
        source/destination block-index pairs) is defined by the attention
        backend -- confirm there.
        """
        for cpu_layer, gpu_layer in zip(self.cpu_cache, self.gpu_cache):
            self.attn_backend.swap_blocks(cpu_layer, gpu_layer, src_to_dst)

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks from the GPU cache out to the CPU cache, layer by layer.

        Mirror image of :meth:`swap_in`: the GPU layer is passed as the
        source and the CPU layer as the destination.
        """
        for gpu_layer, cpu_layer in zip(self.gpu_cache, self.cpu_cache):
            self.attn_backend.swap_blocks(gpu_layer, cpu_layer, src_to_dst)

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy blocks within the GPU cache via ``attn_backend.copy_blocks``.

        Only the GPU cache is touched; ``src_to_dsts`` is forwarded unchanged
        to the backend.
        """
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size of one cache block summed over attention layers.

        Computed as ``get_dtype_size(dtype) * num_attention_layers *
        (key_elems + value_elems)``, where ``key_elems`` is
        ``block_size * num_kv_heads * head_size`` and ``value_elems`` equals
        it (or 0 for MLA models). The unit is whatever ``get_dtype_size``
        returns (presumably bytes per element).
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)

        key_cache_block = cache_config.block_size * num_heads * head_size
        # For MLA there is no value cache, since the latent vector
        # is joint keys and values.
        value_cache_block = 0 if model_config.use_mla else key_cache_block
        total = num_attention_layers * (key_cache_block + value_cache_block)

        dtype = CacheEngine._resolve_cache_dtype(cache_config, model_config)
        return get_dtype_size(dtype) * total

    @staticmethod
    def _resolve_cache_dtype(cache_config: CacheConfig,
                             model_config: ModelConfig) -> torch.dtype:
        """Return the torch dtype used to store the KV cache.

        ``cache_config.cache_dtype == "auto"`` means "same as the model
        dtype"; any other value is looked up in ``STR_DTYPE_TO_TORCH_DTYPE``.
        Shared by ``__init__`` and ``get_cache_block_size`` so allocation and
        size estimation always agree.
        """
        if cache_config.cache_dtype == "auto":
            return model_config.dtype
        return STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]