# infer_task.py
import janus


class InferTask:
    """Bookkeeping for one inference request.

    Attributes:
        id: Identifier supplied by the caller; only stored and logged here.
        finish_reason: ``None`` while the task is running; ``output`` sets it
            to ``"stop"`` or ``"length"`` once the task is finished.
        tokens: Tokens still to be fed for the next step. ``bind_kvcache``
            drops the first ``pos`` entries and ``output`` replaces the list
            with the single newly produced token.
        max_tokens: Upper bound that ``pos`` is compared against in
            ``output``.
        temperature, topk, topp: Stored unchanged (presumably sampling
            settings read by the caller; not used inside this class).
        end_tokens: Container of token values that end the task.
        output_queue: ``janus.Queue`` on whose ``sync_q`` side every produced
            token is published.
        pos: Number of tokens accounted for so far (cache offset).
    """

    # NOTE: the parameter name ``id`` shadows the builtin, but it is kept
    # because callers may pass it by keyword.
    def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
        """Store the request parameters and create the output queue."""
        self.id = id
        self.finish_reason = None
        self.tokens = tokens
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.topk = topk
        self.topp = topp
        self.end_tokens = end_tokens
        self.output_queue = janus.Queue()
        # Filled in by bind_kvcache(); kvcache() and output() need it set.
        self._kv_cache_pool_item = None
        self.pos = 0
        print(f"[INFO] Create InferTask {self.id}")

    def bind_kvcache(self, kv_cache_pool_item, pos):
        """Attach a KV-cache pool item and skip the first ``pos`` tokens.

        Args:
            kv_cache_pool_item: Object exposing a ``kvcache`` attribute and
                an ``update_tokens(tokens, pos)`` method.
            pos: Starting offset; ``self.tokens`` is trimmed to
                ``self.tokens[pos:]`` (presumably the part of the prompt the
                cache already covers -- confirm against the caller).
        """
        self._kv_cache_pool_item = kv_cache_pool_item
        self.pos = pos
        self.tokens = self.tokens[pos:]

    def kvcache(self):
        """Return the ``kvcache`` attribute of the bound pool item.

        Raises:
            AttributeError: If ``bind_kvcache`` has not been called yet.
        """
        return self._kv_cache_pool_item.kvcache

    def output(self, out_token):
        """Record one produced token and advance the task state.

        The pending tokens are first reported to the bound pool item, then
        ``pos`` advances by their count. The task finishes with ``"stop"``
        when ``out_token`` is ``None`` or one of ``end_tokens``, and with
        ``"length"`` when ``pos`` has reached ``max_tokens``; otherwise
        ``out_token`` becomes the only pending token. In every case
        ``out_token`` is put on ``output_queue.sync_q``.

        Args:
            out_token: The newly produced token, or ``None``.
        """
        self._kv_cache_pool_item.update_tokens(self.tokens, self.pos)

        self.pos += len(self.tokens)
        # Identity test: ``None`` is a singleton, and ``==`` could be
        # overridden by the token's type.
        if out_token is None or out_token in self.end_tokens:
            self.finish_reason = "stop"
        elif self.pos >= self.max_tokens:
            self.finish_reason = "length"
        else:
            self.tokens = [out_token]

        # Published even when the task has just finished.
        self.output_queue.sync_q.put(out_token)