"""CacheEngine class for managing the KV cache."""
from typing import Dict, List, Tuple

import torch

from cacheflow import cache_ops

# A single layer's KV cache: a (key_blocks, value_blocks) tensor pair.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Manages the KV cache.

    Owns the per-layer GPU and CPU KV caches and exposes the block-level
    operations on them: swapping blocks between devices and copying blocks
    within the GPU cache.
    """

    def __init__(
        self,
        worker_id: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
    ) -> None:
        # The key-block layout packs head elements into 16-byte groups (see
        # get_key_block_shape), so head_size must divide evenly.
        if head_size % 16 != 0:
            raise ValueError(
                f'head_size ({head_size}) must be a multiple of 16.')

        self.worker_id = worker_id
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.dtype = dtype

        # Allocate both caches up front, one (key, value) pair per layer.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Cache transfers run on their own stream so they don't have to
        # serialize with work on the default stream.
        self.cache_stream = torch.cuda.Stream()
        assert self.cache_stream != torch.cuda.current_stream()
        # One event per layer; recorded after that layer's blocks are swapped.
        self.events = [torch.cuda.Event() for _ in range(num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        """Return the shape of one key block.

        The shape is (num_heads, head_size // x, block_size, x), where ``x``
        is the number of dtype elements that fit into 16 bytes.
        """
        elem_bytes = torch.tensor([], dtype=self.dtype).element_size()
        packing = 16 // elem_bytes
        return (self.num_heads, self.head_size // packing, self.block_size,
                packing)

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        """Return the shape of one value block:
        (num_heads, head_size, block_size)."""
        return (self.num_heads, self.head_size, self.block_size)

    def allocate_gpu_cache(self) -> List[KVCache]:
        """Allocate the per-layer (key, value) block tensors on the GPU."""
        key_shape = self.get_key_block_shape()
        value_shape = self.get_value_block_shape()
        return [
            (
                torch.empty(
                    size=(self.num_gpu_blocks, *key_shape),
                    dtype=self.dtype,
                    device="cuda",
                ),
                torch.empty(
                    size=(self.num_gpu_blocks, *value_shape),
                    dtype=self.dtype,
                    device="cuda",
                ),
            )
            for _ in range(self.num_layers)
        ]

    def allocate_cpu_cache(self) -> List[KVCache]:
        """Allocate the per-layer (key, value) block tensors in pinned host
        memory (pinned so device<->host copies can run on the cache stream)."""
        key_shape = self.get_key_block_shape()
        value_shape = self.get_value_block_shape()
        return [
            (
                torch.empty(
                    size=(self.num_cpu_blocks, *key_shape),
                    dtype=self.dtype,
                    pin_memory=True,
                ),
                torch.empty(
                    size=(self.num_cpu_blocks, *value_shape),
                    dtype=self.dtype,
                    pin_memory=True,
                ),
            )
            for _ in range(self.num_layers)
        ]

    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        """Copy the blocks in ``src_to_dst`` from ``src`` to ``dst``.

        ``src_to_dst`` maps source block numbers to destination block numbers.
        Runs entirely on the cache stream; a per-layer event is recorded after
        each layer's transfer is issued.
        """
        with torch.cuda.stream(self.cache_stream):
            # self.events holds exactly num_layers entries, so this visits
            # every layer once.
            for layer, event in enumerate(self.events):
                src_key, src_value = src[layer]
                dst_key, dst_value = dst[layer]
                # Keys first, then values.
                cache_ops.swap_blocks(src_key, dst_key, src_to_dst)
                cache_ops.swap_blocks(src_value, dst_value, src_to_dst)
                event.record(stream=self.cache_stream)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Bring blocks from the CPU cache into the GPU cache."""
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Evict blocks from the GPU cache into the CPU cache."""
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        """Copy GPU blocks in place: each source block number maps to a list
        of destination block numbers."""
        key_caches = [kv[0] for kv in self.gpu_cache]
        value_caches = [kv[1] for kv in self.gpu_cache]
        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)