from typing import Dict, List, Tuple

import torch

from cacheflow import cache_ops

# One layer's cache: a (key_blocks, value_blocks) pair of tensors.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Owns the per-layer key/value block caches and the ops that move blocks.

    On construction the engine allocates two lists of ``num_layers``
    ``(key_blocks, value_blocks)`` pairs: ``gpu_cache`` (created with
    ``device="cuda"``) and ``cpu_cache`` (created with ``pin_memory=True``).
    Swaps between the two are issued on a dedicated CUDA stream, and one
    CUDA event per layer is recorded after that layer's swap is enqueued.
    """

    def __init__(
        self,
        worker_id: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
    ) -> None:
        """Validate the configuration and allocate caches, stream and events.

        Args:
            worker_id: Identifier stored on the engine; not otherwise used
                in this class.
            num_layers: Number of (key, value) cache pairs to allocate.
            num_heads: Leading dimension of every key/value block.
            head_size: Per-head size; must be a multiple of 16.
            block_size: Per-block size used in the key/value block shapes.
            num_gpu_blocks: Number of blocks in each GPU cache tensor.
            num_cpu_blocks: Number of blocks in each CPU cache tensor.
            dtype: Element type of all cache tensors.

        Raises:
            ValueError: If ``head_size`` is not a multiple of 16.
        """
        if head_size % 16 != 0:
            raise ValueError(
                f'head_size ({head_size}) must be a multiple of 16.')

        self.worker_id = worker_id
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.dtype = dtype

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations. It must be distinct
        # from the stream that is current at construction time.
        self.cache_stream = torch.cuda.Stream()
        assert self.cache_stream != torch.cuda.current_stream()
        # Initialize the events for stream synchronization (one per layer;
        # each is recorded on the cache stream by `_swap`).
        self.events = [torch.cuda.Event() for _ in range(num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        """Return the shape of one key block.

        The head dimension is split into ``head_size // x`` groups of ``x``
        elements, where ``x`` is the number of ``dtype`` elements that fit
        in 16 bytes.

        Returns:
            ``(num_heads, head_size // x, block_size, x)``.
        """
        # Bytes per element of the configured dtype.
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        """Return the shape of one value block.

        Returns:
            ``(num_heads, head_size, block_size)``.
        """
        return (
            self.num_heads,
            self.head_size,
            self.block_size,
        )

    def allocate_gpu_cache(self) -> List[KVCache]:
        """Allocate uninitialized key/value tensors on CUDA for every layer.

        Returns:
            A list of ``num_layers`` ``(key_blocks, value_blocks)`` pairs,
            each tensor having ``num_gpu_blocks`` as its leading dimension.
        """
        gpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_gpu_blocks, *key_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            value_blocks = torch.empty(
                size=(self.num_gpu_blocks, *value_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            gpu_cache.append((key_blocks, value_blocks))
        return gpu_cache

    def allocate_cpu_cache(self) -> List[KVCache]:
        """Allocate uninitialized pinned-memory key/value tensors per layer.

        Returns:
            A list of ``num_layers`` ``(key_blocks, value_blocks)`` pairs,
            each tensor having ``num_cpu_blocks`` as its leading dimension.
        """
        cpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_cpu_blocks, *key_block_shape),
                dtype=self.dtype,
                pin_memory=True,
            )
            value_blocks = torch.empty(
                size=(self.num_cpu_blocks, *value_block_shape),
                dtype=self.dtype,
                pin_memory=True,
            )
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        """Enqueue block swaps from `src` to `dst` for every layer.

        All work is issued on ``self.cache_stream``. After a layer's key and
        value swaps are enqueued, that layer's event is recorded on the same
        stream.

        Args:
            src: Per-layer (key, value) caches to read from.
            dst: Per-layer (key, value) caches to write to.
            src_to_dst: Mapping forwarded unchanged to
                ``cache_ops.swap_blocks`` (presumably source block number to
                destination block number -- confirm against cache_ops).
        """
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.swap_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                cache_ops.swap_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Swap blocks from the CPU cache into the GPU cache."""
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Swap blocks from the GPU cache out to the CPU cache."""
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        """Copy blocks within the GPU cache across all layers.

        Args:
            src_to_dsts: Mapping forwarded unchanged to
                ``cache_ops.copy_blocks`` (presumably one source block
                number to several destination block numbers -- confirm
                against cache_ops).
        """
        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
        value_caches = [value_cache for _, value_cache in self.gpu_cache]
        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)