kvcache_pool.py 3.27 KB
Newer Older
1
2
from infer_task import KVCache

Pan Zezhong's avatar
Pan Zezhong committed
3
4
import asyncio
from typing import List
Pan Zezhong's avatar
Pan Zezhong committed
5
6
import threading

Pan Zezhong's avatar
Pan Zezhong committed
7
8
9
10
11

class KVCachePool:
    """Thread-safe pool of reusable ``KVCache`` objects.

    Caches are created lazily, up to ``max_caches``.  When a task asks for a
    cache and idle ones exist, the idle cache whose ``tokens`` share the
    longest common prefix with the task's ``tokens`` is handed out.  When no
    cache is idle and the limit is reached, the caller blocks until another
    task returns one.

    A single condition variable (``_not_empty``) is shared by blocked
    acquirers and by ``finalize``; every state change therefore uses
    ``notify_all`` so that no kind of waiter can miss its wakeup.
    """

    def __init__(self, model, max_caches: int = 32):
        """Create an empty pool.

        Args:
            model: Object passed to ``KVCache(model)`` on creation and to
                ``kvcache.drop(model)`` on finalization.
            max_caches: Upper bound on the number of caches that may exist
                at once (idle plus handed out).
        """
        self.max_caches = max_caches
        self.model = model
        # Idle caches that are ready to be handed out.
        self._available: List[KVCache] = []
        # Number of caches created so far (idle + currently bound to tasks).
        self.num_caches = len(self._available)
        self._lock = threading.Lock()
        self._not_empty = threading.Condition(self._lock)
        self._shutdown = False

    def acquire_sync(self, infer_task):
        """Bind a cache to ``infer_task``, blocking until one is available.

        Args:
            infer_task: Object exposing ``id``, ``tokens`` and
                ``bind_kvcache(kvcache, pos)``.

        Returns:
            Whatever ``infer_task.bind_kvcache`` returns.  The second argument
            passed to it is ``0`` for a freshly created cache, otherwise the
            match length reported by ``find_most_matching_cache``.

        Raises:
            RuntimeError: If the pool is shutting down (``finalize`` was
                called), either before or while waiting.
        """
        with self._not_empty:
            while True:
                if self._shutdown:
                    raise RuntimeError(
                        "KVCachePool is shutting down; cannot acquire new cache."
                    )

                if self._available:
                    max_match, max_match_index = self.find_most_matching_cache(
                        infer_task.tokens
                    )
                    kvcache = self._available.pop(max_match_index)
                    print(
                        f"[INFO] Task {infer_task.id} reused KVCachePoolItem {max_match_index} with {max_match} matches"
                    )
                    return infer_task.bind_kvcache(kvcache, max_match)

                if self.num_caches < self.max_caches:
                    # Reserve the slot first, but give it back if construction
                    # fails; otherwise finalize() would wait forever for a
                    # cache that never existed.
                    self.num_caches += 1
                    try:
                        kvcache = KVCache(self.model)
                    except BaseException:
                        self.num_caches -= 1
                        self._not_empty.notify_all()
                        raise
                    print(
                        f"[INFO] Task {infer_task.id} created new KVCachePoolItem"
                    )
                    return infer_task.bind_kvcache(kvcache, 0)

                # Nothing idle and the limit is reached: wait for a release
                # (or for shutdown) and re-check from the top.
                self._not_empty.wait()

    def release_sync(self, infer_task):
        """Return the cache held by ``infer_task`` to the pool.

        The cache is obtained from ``infer_task.release_kvcache()`` and
        appended to the idle list; all waiters are then woken.
        """
        with self._not_empty:
            print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool")
            self._available.append(infer_task.release_kvcache())
            # notify_all, not notify: acquirers and finalize() wait on the same
            # condition, and a single wakeup could be consumed by the wrong
            # kind of waiter, leaving the other blocked forever.
            self._not_empty.notify_all()

    async def acquire(self, infer_task):
        """Async wrapper: run ``acquire_sync`` in the loop's default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.acquire_sync, infer_task)

    async def release(self, infer_task):
        """Async wrapper: run ``release_sync`` in the loop's default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.release_sync, infer_task)

    def find_most_matching_cache(self, tokens: List[int]):
        """Find the idle cache sharing the longest common prefix with ``tokens``.

        This method does not take the lock itself; ``acquire_sync`` calls it
        while already holding it.

        Args:
            tokens: Token sequence of the requesting task.

        Returns:
            A tuple ``(match_len, index)``.  ``index`` is the position in the
            idle list of the best cache (``0`` when nothing matches; the first
            cache wins ties).  ``match_len`` is the common prefix length,
            capped at ``len(tokens) - 1`` so at least one token is left
            unmatched, and never negative.
        """

        def first_different_index(a_, b_):
            # Length of the common prefix of the two sequences.
            for i_, (x_, y_) in enumerate(zip(a_, b_)):
                if x_ != y_:
                    return i_
            return min(len(a_), len(b_))

        max_match = 0
        max_match_index = 0
        for i, kvcache in enumerate(self._available):
            common_elements = first_different_index(tokens, kvcache.tokens)
            if common_elements > max_match:
                max_match = common_elements
                max_match_index = i

        # Clamp at zero: an empty ``tokens`` would otherwise yield -1.
        return (max(0, min(max_match, len(tokens) - 1)), max_match_index)

    def finalize(self):
        """Shut the pool down and drop every cache.

        New and pending ``acquire_sync`` calls fail with ``RuntimeError``.
        This call blocks until every cache that was handed out has been
        released, then calls ``drop(self.model)`` on each cache and resets
        the pool's counters to zero.
        """
        with self._not_empty:
            self._shutdown = True
            # Wake acquirers that are already blocked so they observe the
            # shutdown flag instead of waiting for a release.
            self._not_empty.notify_all()

            while len(self._available) < self.num_caches:
                self._not_empty.wait()

            for kvcache in self._available:
                if kvcache is not None:
                    kvcache.drop(self.model)

            self._available.clear()
            self.max_caches = 0
            self.num_caches = 0