infer_task.py 1.62 KB
Newer Older
Pan Zezhong's avatar
Pan Zezhong committed
1
class InferTask:
    """A single inference request: its prompt tokens, sampling settings,
    and the KV cache bound to it while the request is being served.

    Lifecycle: construct -> bind_kvcache() -> repeated next(out_token)
    until finish_reason is set -> release_kvcache().
    """

    def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
        # NOTE: `id` shadows the builtin, but it is part of the public
        # constructor signature, so it is kept for caller compatibility.
        self.id = id
        self.finish_reason = None   # None while running; "stop" or "length" when finished
        self.tokens = tokens        # tokens still to be fed to the model this step
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.topk = topk
        self.topp = topp
        self.end_tokens = end_tokens  # token ids that terminate generation
        self._kv_cache = None
        self.pos = 0                # number of tokens already written to the KV cache

    def bind_kvcache(self, kv_cache, pos=0):
        """Attach *kv_cache* to this task.

        *pos* is the number of leading tokens already present in the cache
        (e.g. from prefix reuse); those tokens are dropped from the pending
        list so they are not recomputed.
        """
        self._kv_cache = kv_cache
        self.pos = pos
        self.tokens = self.tokens[pos:]

    def release_kvcache(self):
        """Detach and return the bound KV cache (None if none was bound)."""
        cache = self._kv_cache
        self._kv_cache = None
        return cache

    def kvcache(self):
        """Return the currently bound KV cache without detaching it."""
        return self._kv_cache

    def next(self, out_token):
        """Advance one decoding step with the model's sampled *out_token*.

        Commits the pending tokens to the KV cache, then either marks the
        task finished ("stop" on an end token or a None output, "length"
        when max_tokens is reached) or queues *out_token* as the sole input
        for the next step. Requires a cache bound via bind_kvcache().
        """
        self._kv_cache.update_tokens(self.tokens, self.pos)
        self.pos += len(self.tokens)
        # Fix: identity check `is None` instead of `== None` (PEP 8); `==`
        # would call a custom __eq__ if out_token is not a plain int.
        if out_token is None or out_token in self.end_tokens:
            self.finish_reason = "stop"
        elif self.pos >= self.max_tokens:
            self.finish_reason = "length"
        else:
            self.tokens = [out_token]

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

class KVCache:
    """Wrapper around a model-owned KV cache plus a shadow token buffer
    of fixed length model.max_context_len()."""

    def __init__(self, model):
        self._kvcache = model.create_kv_cache()
        # Shadow copy of the token ids stored in the cache, fixed length.
        self.tokens = [0] * model.max_context_len()

    def data(self):
        """Return the underlying model KV-cache handle."""
        return self._kvcache

    def drop(self, model):
        """Release the underlying cache back to *model*."""
        model.drop_kv_cache(self._kvcache)

    def update_tokens(self, tokens, pos):
        """Write *tokens* into the shadow buffer starting at *pos*.

        Tokens that would extend past the context length are discarded;
        the buffer never grows beyond its fixed size.
        """
        max_len = len(self.tokens)
        # Fix: the old code computed `tokens[: max_len - pos]`, which for
        # pos > max_len used a negative bound (truncating the wrong end)
        # and then *inserted* via an empty slice, growing the fixed-size
        # buffer. Guard the out-of-range case explicitly instead.
        if pos >= max_len:
            return
        end = min(pos + len(tokens), max_len)
        self.tokens[pos:end] = tokens[: end - pos]