cache_engine.py 4.17 KB
Newer Older
1
"""CacheEngine class for managing the KV cache."""
2
from typing import List
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4

import torch
5

6
from vllm.attention import get_attn_backend
Woosuk Kwon's avatar
Woosuk Kwon committed
7
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
8
from vllm.logger import init_logger
9
10
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
                        is_pin_memory_available)
11
12

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
13
14
15


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
        """Read the cache geometry from the configs and allocate both caches.

        Args:
            cache_config: Supplies the block size, the GPU/CPU block counts
                and the cache dtype string (``"auto"`` or a key of
                ``STR_DTYPE_TO_TORCH_DTYPE``).
            model_config: Supplies head size, layer count, KV-head count,
                attention-head count, sliding window and model dtype.
            parallel_config: Forwarded to the ``model_config`` getters
                (presumably so the counts reflect this worker's share of the
                model — confirm against ``ModelConfig``).
        """
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        # Shared with get_cache_block_size() so the allocated cache and the
        # per-block size estimate always agree on the element dtype.
        self.dtype = self._resolve_cache_dtype(cache_config, model_config)

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            model_config.get_num_attention_heads(parallel_config),
            self.head_size,
            self.num_kv_heads,
            model_config.get_sliding_window(),
            model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
        )

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

    @staticmethod
    def _resolve_cache_dtype(
        cache_config: CacheConfig,
        model_config: ModelConfig,
    ) -> torch.dtype:
        """Return the dtype used to store the KV cache.

        ``cache_dtype == "auto"`` selects the model's own dtype; any other
        value is looked up in ``STR_DTYPE_TO_TORCH_DTYPE`` (an unknown string
        raises ``KeyError``).
        """
        if cache_config.cache_dtype == "auto":
            return model_config.dtype
        return STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device.

        Args:
            num_blocks: Number of cache blocks to allocate per layer.
            device: Torch device string (``"cuda"`` or ``"cpu"`` here).

        Returns:
            One zero-filled tensor per layer, each with the shape reported by
            the attention backend's ``get_kv_cache_shape``.
        """
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        # Page-locked host memory is only requested for the CPU cache, and
        # only when the platform supports it.
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        kv_cache: List[torch.Tensor] = []
        for _ in range(self.num_layers):
            # null block in CpuGpuBlockAllocator requires at least that
            # block to be zeroed-out.
            # We zero-out everything for simplicity.
            kv_cache.append(
                torch.zeros(kv_cache_shape,
                            dtype=self.dtype,
                            pin_memory=pin_memory,
                            device=device))
        return kv_cache

    def _swap_all_layers(
        self,
        src_cache: List[torch.Tensor],
        dst_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """Run the backend's ``swap_blocks`` on every layer's tensor pair.

        Both cache lists are built with ``num_layers`` entries in
        ``__init__``, so zipping them visits every layer exactly once.
        """
        for src_layer, dst_layer in zip(src_cache, dst_cache):
            self.attn_backend.swap_blocks(src_layer, dst_layer, src_to_dst)

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Copy blocks from the CPU cache into the GPU cache, layer by layer.

        ``src_to_dst`` is passed unchanged to the backend's ``swap_blocks``;
        presumably it maps CPU block indices to GPU block indices — its exact
        layout is defined by the attention backend.
        """
        self._swap_all_layers(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Copy blocks from the GPU cache into the CPU cache, layer by layer.

        ``src_to_dst`` is passed unchanged to the backend's ``swap_blocks``;
        presumably it maps GPU block indices to CPU block indices — its exact
        layout is defined by the attention backend.
        """
        self._swap_all_layers(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy blocks within the GPU cache via the backend's ``copy_blocks``.

        The whole per-layer GPU cache list is handed to the backend in one
        call; the layout of ``src_to_dsts`` is backend-defined.
        """
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the storage size of one cache block summed over all layers.

        The result is ``dtype_size * num_layers * 2 * block_size *
        num_kv_heads * head_size``, where the factor 2 accounts for the key
        and value caches and ``dtype_size`` comes from ``get_dtype_size``
        (presumably bytes per element).
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype = CacheEngine._resolve_cache_dtype(cache_config, model_config)
        dtype_size = get_dtype_size(dtype)
        return dtype_size * total