# kvcache_pool.py
# Author (per commit history): Pan Zezhong
import asyncio
from typing import List


class KVCachePoolItem:
    """One KV-cache handle plus the token values last written into it.

    ``tokens`` is a fixed-length list of ``model.max_context_len()`` entries,
    zero-initialised. ``KVCachePool.find_most_matching_cache`` compares it
    against incoming token sequences to pick the best item to reuse.
    """

    def __init__(self, model):
        # Handle created by the model; released again in drop().
        self.kvcache = model.create_kv_cache()
        self.tokens = [0] * model.max_context_len()

    def drop(self, model):
        """Release the underlying cache via ``model.drop_kv_cache``."""
        model.drop_kv_cache(self.kvcache)

    def update_tokens(self, tokens, pos):
        """Record ``tokens`` at positions ``pos, pos + 1, ...``.

        Tokens that would land past the end of the context window are
        dropped, so ``self.tokens`` never changes length.

        Args:
            tokens: sequence of token values to write.
            pos: 0-based start position in the context window.

        Raises:
            ValueError: if ``pos`` is negative.
        """
        if pos < 0:
            raise ValueError(f"pos must be non-negative, got {pos}")
        max_len = len(self.tokens)
        # Clamp the end to the window. If the start is already at/after the
        # end there is nothing to write; a slice assignment with start > stop
        # would otherwise insert elements and grow the list.
        end = min(pos + len(tokens), max_len)
        if end <= pos:
            return
        self.tokens[pos:end] = tokens[: end - pos]


import threading


class KVCachePool:
    """Lock-protected pool of at most ``max_caches`` KVCachePoolItem objects.

    Items are created lazily. On acquire, the pool prefers the released item
    whose recorded tokens share the longest prefix with the task's tokens and
    passes that prefix length to ``infer_task.bind_kvcache`` (presumably so
    the caller can skip recomputing it). ``acquire_sync`` blocks while the
    pool is exhausted; ``finalize`` waits for every item to come back and
    then drops them all.
    """

    def __init__(self, model, max_caches: int = 32):
        self.max_caches = max_caches
        self.model = model
        # Released items ready for reuse.
        self._available: List["KVCachePoolItem"] = []
        # Number of items created so far (checked out + available).
        self.num_caches = len(self._available)
        self._lock = threading.Lock()
        # Signalled when an item is returned, a creation slot frees up, or
        # shutdown begins; finalize() also waits on it.
        self._not_empty = threading.Condition(self._lock)
        self._shutdown = False

    def acquire_sync(self, infer_task):
        """Bind a cache item to ``infer_task``, blocking until one is free.

        Reuses the available item with the longest matching token prefix;
        otherwise creates a new item while fewer than ``max_caches`` exist;
        otherwise waits for a release.

        Returns:
            Whatever ``infer_task.bind_kvcache(item, pos)`` returns.

        Raises:
            RuntimeError: if the pool is shutting down, including when
                shutdown begins while this call is waiting.
        """
        with self._not_empty:
            while True:
                if self._shutdown:
                    raise RuntimeError(
                        "KVCachePool is shutting down; cannot acquire new cache."
                    )
                if self._available:
                    max_match, max_match_index = self.find_most_matching_cache(
                        infer_task.tokens
                    )
                    kvcache = self._available.pop(max_match_index)
                    print(
                        f"[INFO] Task {infer_task.id} reused KVCachePoolItem {max_match_index} with {max_match} matches"
                    )
                    return infer_task.bind_kvcache(kvcache, max_match)
                if self.num_caches < self.max_caches:
                    # Reserve the slot, but hand it back if creation fails:
                    # otherwise num_caches would overcount forever and
                    # finalize() would wait for an item that never existed.
                    self.num_caches += 1
                    try:
                        kvcache = KVCachePoolItem(self.model)
                    except BaseException:
                        self.num_caches -= 1
                        # A slot just freed up; let a waiting acquirer use it.
                        self._not_empty.notify()
                        raise
                    print(f"[INFO] Task {infer_task.id} created new KVCachePoolItem")
                    return infer_task.bind_kvcache(kvcache, 0)
                self._not_empty.wait()

    def release_sync(self, infer_task):
        """Return ``infer_task``'s cache item to the pool and wake one waiter.

        Clears ``infer_task._kv_cache_pool_item``. If the task holds no item
        (e.g. it was already released), only a warning is printed, so a
        double release cannot put ``None`` into the pool.
        """
        with self._not_empty:
            kvcache = getattr(infer_task, "_kv_cache_pool_item", None)
            if kvcache is None:
                print(f"[WARN] Task {infer_task.id} has no KVCachePoolItem to release")
                return
            print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool")
            self._available.append(kvcache)
            infer_task._kv_cache_pool_item = None
            self._not_empty.notify()

    async def acquire(self, infer_task):
        """Async wrapper: run ``acquire_sync`` in the loop's default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.acquire_sync, infer_task)

    async def release(self, infer_task):
        """Async wrapper around ``release_sync``.

        Runs inline instead of in the executor. ``release_sync`` never waits
        on the condition; at most it briefly blocks the loop while another
        thread holds the lock. Submitting it to the default executor could
        deadlock: once every executor thread is blocked in ``acquire_sync``
        waiting for a release, the queued release would never run.
        """
        return self.release_sync(infer_task)

    def find_most_matching_cache(self, tokens: List[int]):
        """Find the available item sharing the longest prefix with ``tokens``.

        Reads ``self._available`` without locking; ``acquire_sync`` calls it
        with the lock held.

        Returns:
            ``(match_len, index)``. ``index`` points into the available list
            (0 when nothing matches). ``match_len`` is the common-prefix
            length capped at ``len(tokens) - 1`` (presumably so at least the
            last token is always recomputed) and floored at 0.
        """
        max_match = 0
        max_match_index = 0
        for i, kvcache in enumerate(self._available):
            common_elements = self._common_prefix_len(tokens, kvcache.tokens)
            if common_elements > max_match:
                max_match = common_elements
                max_match_index = i
        # NOTE(review): kvcache.tokens is zero-filled and update_tokens only
        # overwrites from `pos`, so entries past the last write may be 0 or
        # stale; a prompt containing those values can match beyond the valid
        # cache contents. Confirm whether callers guard against this.
        return (max(0, min(max_match, len(tokens) - 1)), max_match_index)

    @staticmethod
    def _common_prefix_len(a, b):
        """Return the length of the longest common prefix of ``a`` and ``b``."""
        for i, (x, y) in enumerate(zip(a, b)):
            if x != y:
                return i
        return min(len(a), len(b))

    def finalize(self):
        """Shut the pool down and drop every cache item.

        Marks the pool as shutting down (later acquires raise RuntimeError),
        wakes acquirers already blocked so they raise too, then waits until
        every created item has been released before dropping them. Blocks
        indefinitely if some task never releases its item.
        """
        with self._not_empty:
            self._shutdown = True
            # Blocked acquirers would otherwise sleep until some release
            # happened to notify them (and each release wakes only one).
            self._not_empty.notify_all()
            while len(self._available) < self.num_caches:
                self._not_empty.wait()

            for kvcache in self._available:
                if kvcache is not None:
                    kvcache.drop(self.model)

            self._available.clear()
            self.max_caches = 0
            self.num_caches = 0