import ctypes
import os
from typing import Optional, Tuple, Union, List, Callable, Dict, Any

import platform
if platform.system() == 'Windows':
    fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "fastllm_tools.dll"))
else:
    fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "libfastllm_tools.so"))

fastllm_lib.create_llm_model.argtypes = [ctypes.c_char_p]
fastllm_lib.create_llm_model.restype = ctypes.c_int

fastllm_lib.launch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_void_p,
                                                  ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                  ctypes.c_float, ctypes.c_float, ctypes.c_bool]
fastllm_lib.launch_response_llm_model.restype = ctypes.c_int

fastllm_lib.fetch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
fastllm_lib.fetch_response_llm_model.restype = ctypes.c_int

fastllm_lib.fetch_response_logits_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_float)]
fastllm_lib.fetch_response_logits_llm_model.restype = ctypes.c_int

fastllm_lib.response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p,
                                               ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                               ctypes.c_float, ctypes.c_float, ctypes.c_bool]
fastllm_lib.response_str_llm_model.restype = ctypes.c_char_p

fastllm_lib.launch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p,
                                                      ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                      ctypes.c_float, ctypes.c_float, ctypes.c_bool]
fastllm_lib.launch_response_str_llm_model.restype = ctypes.c_int

fastllm_lib.fetch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
fastllm_lib.fetch_response_str_llm_model.restype = ctypes.c_char_p

fastllm_lib.make_history_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
fastllm_lib.make_history_llm_model.restype = ctypes.c_char_p

fastllm_lib.make_input_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p]
fastllm_lib.make_input_llm_model.restype = ctypes.c_char_p

fastllm_lib.add_tokenizer_word_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_float, ctypes.c_int]

fastllm_lib.set_device_map.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
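
# The launch_*/fetch_* pairs expose a polling protocol: launch_* starts a
# generation job and returns an integer handle, and fetch_* is then polled
# with that handle until it signals end of stream. A sketch at the raw token
# level, mirroring how model.chat below drives it (the argument values are
# arbitrary examples):
#
#     handle = fastllm_lib.launch_response_llm_model(model_id, len(ids),
#                                                    (ctypes.c_int * len(ids))(*ids),
#                                                    512, True, 0.8, 1, 1.0, 1.0, False)
#     while True:
#         token = fastllm_lib.fetch_response_llm_model(model_id, handle)
#         if (token == -1):   # -1 marks end of stream
#             break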

def set_cpu_threads(threads: int):
    fastllm_lib.set_cpu_threads(threads)

def get_cpu_threads() -> int:
    return fastllm_lib.get_cpu_threads()

def print_ins_info():
    fastllm_lib.print_cpu_ins()

def set_cpu_kvcache(cpu_kvcache):
    fastllm_lib.set_kvcache_in_cpu(ctypes.c_bool(cpu_kvcache))

def get_cpu_kvcache():
    return fastllm_lib.get_kvcache_in_cpu()

def set_cpu_low_mem(low_mem):
    fastllm_lib.set_cpu_low_mem(ctypes.c_bool(low_mem))

def get_cpu_low_mem():
    return fastllm_lib.get_cpu_low_mem()
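
# Illustrative use of the runtime knobs above (the values are arbitrary
# examples, not recommended defaults):
#
#     set_cpu_threads(8)       # number of CPU threads used for inference
#     set_cpu_kvcache(True)    # keep the KV cache in CPU memory
#     set_cpu_low_mem(True)    # trade speed for a smaller memory footprint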

def set_device_map(device_map):
    devices = []
    values = []
    if (isinstance(device_map, str)):
        devices.append(device_map)
        values.append(1)
    elif (isinstance(device_map, list)):
        devices = [str(x) for x in device_map]
        values = [1 for x in device_map]
    elif (isinstance(device_map, dict)):
        devices = [str(x) for x in device_map.keys()]
        values = [int(device_map[x]) for x in device_map.keys()]
    else:
        print("set_device_map error.")
        return
    device_str = ''.join(devices)
    device_len = [len(x) for x in devices]
    fastllm_lib.set_device_map(len(device_len),
                               (ctypes.c_int * len(device_len))(*device_len),
                               device_str.encode(),
                               (ctypes.c_int * len(values))(*values))
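
# Illustrative calls to set_device_map (a sketch; the device names are
# placeholders and valid values depend on how fastllm was built):
#
#     set_device_map("cuda:0")                  # everything on one device
#     set_device_map(["cuda:0", "cuda:1"])      # split evenly across two devices
#     set_device_map({"cuda:0": 3, "cpu": 1})   # 3:1 weighted split
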
def from_hf(model,
            tokenizer = None,
            dtype = "float16"):
    from fastllm_pytools import hf_model
    return hf_model.create(model, tokenizer, dtype = dtype)
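
# A sketch of converting a Hugging Face model in place (the model name and the
# transformers calls are illustrative, not part of this module):
#
#     from transformers import AutoModel, AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code = True)
#     hf = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code = True)
#     flm = from_hf(hf, tok, dtype = "float16")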

class model:
    def __init__(self, path: str,
                 id: int = -99999):
        if (id != -99999):
            self.model = id
        else:
            self.model = fastllm_lib.create_llm_model(path.encode())
        self.direct_query = False
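
    # Example: load a converted .flm model from disk ("model.flm" is a
    # placeholder path); `id` is normally left at its sentinel default so a
    # new native model is created from the file:
    #
    #     m = model("model.flm")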

    def get_prompt(self,
                   query: str,
                   history: List[Tuple[str, str]] = None) -> str:
        if (not(history)):
            history = []
        prompt = ""
        for i, (old_query, response) in enumerate(history):
            prompt = fastllm_lib.make_history_llm_model(self.model, prompt.encode(), i, old_query.encode(), response.encode()).decode()
        prompt = fastllm_lib.make_input_llm_model(self.model, prompt.encode(), len(history), query.encode()).decode()
        return prompt
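
    # Example: render a query plus prior turns into the model's native
    # prompt format:
    #
    #     m.get_prompt("How are you?", history = [("Hello", "Hi!")])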

    def save(self, path : str):
        fastllm_lib.save_llm_model(self.model, path.encode())
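
    # Example ("out.flm" is a placeholder path):
    #
    #     m.save("out.flm")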

    def eval(self):
        # Intentionally a no-op; lets callers written for torch-style models
        # call eval() without special-casing fastllm models.
        pass

    def response_logits(self,
                        query: str,
                        history: List[Tuple[str, str]] = None,
                        tokenizer = None) -> List[float]:
        prompt = query if self.direct_query else self.get_prompt(query, history)
        if (tokenizer is None):
            handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                           ctypes.c_int(1), ctypes.c_bool(False), ctypes.c_float(1), ctypes.c_int(1),
                                                           ctypes.c_float(1), ctypes.c_float(1), ctypes.c_bool(True))
        else:
            input_ids = tokenizer.encode(prompt)
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input_ids), (ctypes.c_int * len(input_ids))(*input_ids),
                                                           1, False, 1, 1, 1, 1, True)
        vocab_size = fastllm_lib.get_tokenizer_vocab_size(self.model)
        # Output buffer the C side fills with one float per vocab entry
        # (over-allocated to 4x vocab_size floats).
        array = (ctypes.c_float * (vocab_size * 4))()
        ret = fastllm_lib.fetch_response_logits_llm_model(self.model, handle, array)
        out = list(array)[:vocab_size]
        # Drain the remaining outputs until the stream reports completion (-1).
        while (ret != -1):
            ret = fastllm_lib.fetch_response_logits_llm_model(self.model, handle, array)
        return out
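
    # Example: fetch the logits for the first generated token and pick the
    # argmax (a sketch; passing a tokenizer is optional):
    #
    #     logits = m.response_logits("Hello")
    #     next_token = logits.index(max(logits))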

    def response(self,
                 query: str,
                 history: List[Tuple[str, str]] = None,
                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0) -> str:
        ret = ""
        for i in self.stream_response(query = query,
                                      history = history,
                                      max_length = max_length,
                                      do_sample = do_sample,
                                      top_p = top_p, top_k = top_k,
                                      temperature = temperature,
                                      repeat_penalty = repeat_penalty,
                                      one_by_one = True):
            ret += i
        return ret
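
    # Example (the sampling arguments shown are this method's defaults):
    #
    #     answer = m.response("Introduce yourself.", top_p = 0.8, temperature = 1.0)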

    def stream_response(self,
                        query: str,
                        history: List[Tuple[str, str]] = None,
                        max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
                        one_by_one = True):
        prompt = query if self.direct_query else self.get_prompt(query, history)
        handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                           ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
                                                           ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False))
        res = ""
        ret = b''
        fail_cnt = 0
        while True:
            ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle)
            cur = ""
            try:
                cur = ret.decode()
                ret = b''
            except UnicodeDecodeError:
                # Partial multi-byte UTF-8 sequence; keep accumulating bytes.
                fail_cnt += 1
                if (fail_cnt == 20):
                    break
                else:
                    continue
            fail_cnt = 0
            if (cur == "<flmeos>"):
                break
            if one_by_one:
                yield cur
            else:
                res += cur
                yield res
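
    # Example of incremental printing; with one_by_one = True each iteration
    # yields only the newly decoded fragment:
    #
    #     for piece in m.stream_response("Tell me a story."):
    #         print(piece, end = "", flush = True)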

    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
        if (not(history)):
            history = []
        prompt = query if self.direct_query else self.get_prompt(query, history)
        input_ids = tokenizer.encode(prompt)
        handle = fastllm_lib.launch_response_llm_model(self.model, len(input_ids), (ctypes.c_int * len(input_ids))(*input_ids),
                                                       max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                       False)

        result = []
        while True:
            cur = fastllm_lib.fetch_response_llm_model(self.model, handle)
            if (cur == -1):   # -1 marks end of the token stream
                break
            result.append(cur)
        response = tokenizer.decode(result)
        history = history + [(query, response)]
        return response, history
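
    # Example with an external tokenizer (a sketch; `tok` is assumed to be a
    # tokenizer matching the converted model):
    #
    #     reply, history = m.chat(tok, "Hello", history = [])
    #     reply, history = m.chat(tok, "And you?", history = history)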

    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
                    max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
                    return_past_key_values = False, **kwargs):
        if (not(history)):
            history = []
        prompt = query if self.direct_query else self.get_prompt(query, history)
        input_ids = tokenizer.encode(prompt)
        handle = fastllm_lib.launch_response_llm_model(self.model, len(input_ids), (ctypes.c_int * len(input_ids))(*input_ids),
                                                       max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                       False)
        tokens = []
        while True:
            cur = fastllm_lib.fetch_response_llm_model(self.model, handle)
            if (cur == -1):   # -1 marks end of the token stream
                break
            tokens.append(cur)
            response = tokenizer.decode(tokens)
            new_history = history + [(query, response)]
            if return_past_key_values:
                yield response, new_history, None
            else:
                yield response, new_history
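
    # Example of streaming a chat turn; each yield carries the full response
    # decoded so far plus the updated history:
    #
    #     for response, history in m.stream_chat(tok, "Hello", history = []):
    #         print(response)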

    def set_adapter(self, name: str):
        fastllm_lib.set_adapter(self.model, str(name).encode())
    
    def disable_adapter(self):
        fastllm_lib.disable_adapter(self.model)
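
    # Example of switching LoRA-style adapters (a sketch; "my_adapter" is a
    # placeholder for an adapter name present in the loaded model):
    #
    #     m.set_adapter("my_adapter")
    #     m.response("Hello")      # generated with the adapter applied
    #     m.disable_adapter()      # back to the base weights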