# cache_engine.py (4.92 KB) - recovered from a blame view; commits by Woosuk Kwon.
from typing import Dict, List, Tuple
import torch
from cacheflow import cache_ops
# Per-layer cache entry: a (key_blocks, value_blocks) pair of tensors.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Allocates and manages the per-layer key/value block caches of a worker.

    For each of ``num_layers`` layers, one key tensor and one value tensor
    are allocated on GPU ``gpu_id`` (``num_gpu_blocks`` blocks each) and in
    pinned CPU memory (``num_cpu_blocks`` blocks each). Block transfers
    (``swap_in``, ``swap_out``, ``copy``) are issued inside a dedicated CUDA
    stream, ``cache_stream``. After a layer's transfers have been issued,
    that layer's entry in ``events`` is recorded on the stream.
    """

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
    ) -> None:
        """Validate the configuration, then allocate both caches.

        Args:
            worker_id: Identifier of the owning worker. It is stored but
                not used by this class.
            gpu_id: CUDA device ordinal for the GPU cache and cache stream.
            num_layers: Number of layers; one KVCache pair is kept per layer.
            num_heads: Leading dimension of every key/value block.
            head_size: Per-head size; must be a multiple of 16.
            block_size: Block length dimension (presumably tokens per
                block - confirm with the block manager).
            num_gpu_blocks: Number of blocks in each per-layer GPU tensor.
            num_cpu_blocks: Number of blocks in each per-layer CPU tensor.
            dtype: Element dtype of every cache tensor.

        Raises:
            ValueError: If ``head_size`` is not a multiple of 16.
        """
        # The key layout splits head_size into groups of 16 // element_size
        # elements. A multiple of 16 makes that split exact for any element
        # size that divides 16.
        if head_size % 16 != 0:
            raise ValueError(
                f'head_size ({head_size}) must be a multiple of 16.')

        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.dtype = dtype

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations.
        self.cache_stream = torch.cuda.Stream(device=gpu_id)
        # Sanity check: cache transfers must not share the current stream.
        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
        # One event per layer, recorded after that layer's transfers are
        # issued on cache_stream.
        self.events = [torch.cuda.Event() for _ in range(num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        """Return the shape of one key block, excluding the block dimension.

        The head dimension is split into ``head_size // x`` groups of ``x``
        elements. Here ``x = 16 // element_size`` is the number of ``dtype``
        elements that fit in 16 bytes.
        NOTE(review): presumably this matches the key layout the attention
        kernels read - confirm against the kernel.
        """
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        """Return the shape of one value block, excluding the block dimension."""
        return (
            self.num_heads,
            self.head_size,
            self.block_size,
        )

    def _allocate_cache(
        self,
        num_blocks: int,
        device: torch.device,
        *,
        pin_memory: bool = False,
    ) -> List[KVCache]:
        """Allocate one uninitialized (key_blocks, value_blocks) pair per layer.

        Shared by the GPU and CPU allocators so both caches use the same
        shapes, dtype and per-layer allocation order.

        Args:
            num_blocks: Leading (block) dimension of every allocated tensor.
            device: Device the tensors are allocated on.
            pin_memory: Allocate in pinned memory; only valid for CPU.

        Returns:
            A list of ``num_layers`` KVCache pairs. The contents come from
            ``torch.empty`` and are uninitialized.
        """
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
        cache: List[KVCache] = []
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(num_blocks, *key_block_shape),
                dtype=self.dtype,
                device=device,
                pin_memory=pin_memory,
            )
            value_blocks = torch.empty(
                size=(num_blocks, *value_block_shape),
                dtype=self.dtype,
                device=device,
                pin_memory=pin_memory,
            )
            cache.append((key_blocks, value_blocks))
        return cache

    def allocate_gpu_cache(self) -> List[KVCache]:
        """Allocate ``num_gpu_blocks`` blocks per layer on GPU ``gpu_id``."""
        return self._allocate_cache(
            self.num_gpu_blocks, torch.device('cuda', self.gpu_id))

    def allocate_cpu_cache(self) -> List[KVCache]:
        """Allocate ``num_cpu_blocks`` blocks per layer in pinned CPU memory."""
        # The device is explicit so that a non-CPU default device (e.g. one
        # set via torch.set_default_device) cannot redirect this pinned
        # allocation.
        return self._allocate_cache(
            self.num_cpu_blocks, torch.device('cpu'), pin_memory=True)

    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        """Move blocks from ``src`` to ``dst`` layer by layer on cache_stream.

        ``src_to_dst`` is passed unchanged to ``cache_ops.swap_blocks``.
        Presumably it maps a source block index to a destination block
        index - see cache_ops. After layer ``i``'s key and value calls are
        issued, ``self.events[i]`` is recorded on ``cache_stream``.
        """
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.swap_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                cache_ops.swap_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Swap blocks from the CPU cache into the GPU cache."""
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Swap blocks from the GPU cache out to the CPU cache."""
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def _copy(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dsts: Dict[int, List[int]],
    ) -> None:
        """Copy blocks from ``src`` to ``dst`` layer by layer on cache_stream.

        ``src_to_dsts`` is passed unchanged to ``cache_ops.copy_blocks``.
        Presumably each source block index maps to a list of destination
        block indices - see cache_ops. After layer ``i``'s calls are issued,
        ``self.events[i]`` is recorded on ``cache_stream``.
        """
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.copy_blocks(
                    src_key_cache, dst_key_cache, src_to_dsts)
                # Copy the value blocks.
                cache_ops.copy_blocks(
                    src_value_cache, dst_value_cache, src_to_dsts)
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        """Copy blocks within the GPU cache (source and destination are both GPU)."""
        self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)