import ctypes;
import math
import os;
import threading
from typing import Optional, Tuple, Union, List, Callable, Dict, Any;
from copy import deepcopy
import json

import platform
if platform.system() == 'Windows':
    fastllm_lib = ctypes.CDLL(os.path.join(os.path.split(os.path.realpath(__file__))[0], "fastllm_tools.dll"), winmode=0)
elif platform.system() == 'Darwin':
    fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "libfastllm_tools.dylib"))
else:
    fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "libfastllm_tools.so"))

fastllm_lib.create_llm_model.argtypes = [ctypes.c_char_p]
fastllm_lib.create_llm_model.restype = ctypes.c_int

fastllm_lib.token_decode.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_char_p]
fastllm_lib.token_decode.restype = ctypes.c_int

fastllm_lib.token_encode_string.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
fastllm_lib.token_encode_string.restype = ctypes.c_int

fastllm_lib.launch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_void_p,
                                                  ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                  ctypes.c_float, ctypes.c_float, ctypes.c_bool,
                                                  ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
fastllm_lib.launch_response_llm_model.restype = ctypes.c_int

fastllm_lib.fetch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
fastllm_lib.fetch_response_llm_model.restype = ctypes.c_int

fastllm_lib.fetch_response_logits_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_float)]
fastllm_lib.fetch_response_logits_llm_model.restype = ctypes.c_int

fastllm_lib.response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p,
                                               ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                               ctypes.c_float, ctypes.c_float, ctypes.c_bool]
# fastllm_lib.response_str_llm_model.restype = ctypes.c_char_p
fastllm_lib.response_str_llm_model.restype = ctypes.POINTER(ctypes.c_char)

fastllm_lib.launch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p,
                                                      ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                      ctypes.c_float, ctypes.c_float, ctypes.c_bool,
                                                      ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
fastllm_lib.launch_response_str_llm_model.restype = ctypes.c_int

fastllm_lib.fetch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
# fastllm_lib.fetch_response_str_llm_model.restype = ctypes.c_char_p
fastllm_lib.fetch_response_str_llm_model.restype = ctypes.POINTER(ctypes.c_char)

fastllm_lib.make_history_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
# fastllm_lib.make_history_llm_model.restype = ctypes.c_char_p
fastllm_lib.make_history_llm_model.restype = ctypes.POINTER(ctypes.c_char)

fastllm_lib.make_input_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p]
# fastllm_lib.make_input_llm_model.restype = ctypes.c_char_p
fastllm_lib.make_input_llm_model.restype = ctypes.POINTER(ctypes.c_char)

fastllm_lib.add_tokenizer_word_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_float, ctypes.c_int]

fastllm_lib.set_device_map.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]

fastllm_lib.get_llm_model_type.argtypes = [ctypes.c_int]
fastllm_lib.get_llm_model_type.restype = ctypes.POINTER(ctypes.c_char)

fastllm_lib.response_batch_str_llm_model.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int,
                                                     ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                     ctypes.c_float, ctypes.c_float, ctypes.c_bool]
fastllm_lib.response_batch_str_llm_model.restype = ctypes.POINTER(ctypes.c_char_p)

fastllm_lib.response_batch_tokens_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int),
                                                        ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                        ctypes.c_float, ctypes.c_float, ctypes.c_bool]
fastllm_lib.response_batch_tokens_llm_model.restype = ctypes.POINTER(ctypes.c_char_p)

fastllm_lib.freeChars.argtypes = [ctypes.POINTER(ctypes.c_char)]
# fastllm_lib.freeChars.restype = ctypes.c_char_p

fastllm_lib.freeCharArray.argtypes = [ctypes.POINTER(ctypes.c_char_p), ctypes.c_int]
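
# Memory-ownership note: functions above that return ctypes.POINTER(ctypes.c_char) (rather than
# c_char_p) hand back a buffer allocated on the C++ side; callers copy it with ctypes.string_at()
# and then release it with freeChars / freeCharArray. A minimal sketch of the pattern used
# throughout this module (illustrative only; `model_handle` is a placeholder):
#
#     ptr = fastllm_lib.get_llm_model_type(model_handle)   # POINTER(c_char)
#     text = ctypes.string_at(ptr).decode()                # copy into a Python str
#     fastllm_lib.freeChars(ptr)                           # free the C++ buffer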

def set_cpu_threads(threads: int):
    fastllm_lib.set_cpu_threads(threads);

def get_cpu_threads() -> int:
    return fastllm_lib.get_cpu_threads();

def print_ins_info():
    fastllm_lib.print_cpu_ins();

def set_cpu_kvcache(cpu_kvcache):
    fastllm_lib.set_kvcache_in_cpu(ctypes.c_bool(cpu_kvcache));

def get_cpu_kvcache():
    return fastllm_lib.get_kvcache_in_cpu();

def set_cpu_low_mem(low_mem):
    fastllm_lib.set_cpu_low_mem(ctypes.c_bool(low_mem));

def get_cpu_low_mem():
    return fastllm_lib.get_cpu_low_mem();
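
# Illustrative sketch of how these global switches are typically used before loading a model
# (the .flm file name below is a hypothetical converted model file):
#
#     from fastllm_pytools import llm
#     llm.set_cpu_threads(8)        # CPU threads used for inference
#     llm.set_cpu_kvcache(True)     # keep the KV cache in CPU memory
#     llm.set_cpu_low_mem(False)    # do not use the low-memory loading path
#     m = llm.model("chatglm2-6b-int8.flm")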

def set_device_map(device_map):
    devices = [];
    values = [];
    if (isinstance(device_map, str)):
        devices.append(device_map);
        values.append(1);
    elif (isinstance(device_map, list)):
        devices = [str(x) for x in device_map];
        values = [1 for x in device_map];
    elif (isinstance(device_map, dict)):
        devices = [str(x) for x in device_map.keys()];
        values = [int(device_map[x]) for x in device_map.keys()];
    else:
        print("set_device_map error.");
        return;
    device_str = ''.join(devices);
    device_len = [len(x) for x in devices];
    fastllm_lib.set_device_map(len(device_len),
                               (ctypes.c_int * len(device_len))(*device_len),
                               device_str.encode(),
                               (ctypes.c_int * len(values))(*values));
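
# set_device_map accepts a single device name, a list, or a {device: weight} dict; all forms are
# flattened into parallel name/weight arrays for the C++ side. Illustrative examples (the exact
# device strings depend on how fastllm was built):
#
#     set_device_map("cuda:0")                  # everything on one device
#     set_device_map(["cuda:0", "cuda:1"])      # split evenly across two devices
#     set_device_map({"cuda:0": 3, "cpu": 1})   # weighted split, 3:1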
def from_hf(model,
            tokenizer = None,
            dtype = "float16"):
    from fastllm_pytools import hf_model;
    return hf_model.create(model, tokenizer, dtype = dtype);
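
# Typical conversion flow (sketch; assumes the Hugging Face model/tokenizer are already loaded
# and that fastllm_pytools.hf_model supports the architecture; the repo name is only an example):
#
#     from transformers import AutoModel, AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
#     hf = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).float()
#     m = from_hf(hf, tok, dtype="float16")   # returns a fastllm `model` instance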

class model:
    def __init__ (self, path : str,
                  id : int = -99999):
        if (id != -99999):
            self.model = id;
        else:
            self.model = fastllm_lib.create_llm_model(path.encode());
        self.direct_query = False;

        # Thread-local object pool of buffers, used to avoid repeatedly allocating and freeing buffer objects
        self.thread_local_obj = threading.local()
        self.thread_local_obj.tokenizer_encode_string__output_buffer = None
        self.thread_local_obj.tokenizer_decode_token__output_buffer = None

        # Static cache of tokenizer_decode_token results; built manually on demand.
        # The vocabulary is finite and not too large, so caching these results to cut down on C calls is a good fit.
        # It is not populated automatically, to avoid locking the cache dict in multi-threaded use
        # and to leave the choice open for different scenarios.
        self.tokenizer_decode_token_cache = None

        model_type_ptr = fastllm_lib.get_llm_model_type(self.model)
        self.model_type = ctypes.string_at(model_type_ptr).decode()
        fastllm_lib.freeChars(model_type_ptr)
        # print("model_type:", self.model_type)

    def get_prompt(self,
                   query: str,
                   history: List[Tuple[str, str]] = None) -> str:
        if (not(history)):
            history = [];
        prompt = "";
        for i, (old_query, response) in enumerate(history):
            history_ptr = fastllm_lib.make_history_llm_model(self.model, prompt.encode(), i, old_query.encode(), response.encode())
            prompt = ctypes.string_at(history_ptr).decode()
            fastllm_lib.freeChars(history_ptr)
        
        input_ptr = fastllm_lib.make_input_llm_model(self.model, prompt.encode(), len(history), query.encode())
        prompt = ctypes.string_at(input_ptr).decode()
        fastllm_lib.freeChars(input_ptr)
        return prompt;

    def save(self, path : str):
        fastllm_lib.save_llm_model(self.model, path.encode());

    def eval(self):
        pass;

    def build_tokenizer_decode_token_cache(self):
        if self.tokenizer_decode_token_cache is not None:
            return

        cache_dict = dict()
        vocab_size = fastllm_lib.get_tokenizer_vocab_size(self.model)
        for token_id in range(vocab_size):
            cache_dict[token_id] = self.tokenizer_decode_token(token_id)

        self.tokenizer_decode_token_cache = cache_dict

    def tokenizer_encode_string(self, content: str) -> List[int]:
        output_buffer_init_len = 1024
        if not hasattr(self.thread_local_obj, 'tokenizer_encode_string__output_buffer') or self.thread_local_obj.tokenizer_encode_string__output_buffer is None:
            self.thread_local_obj.tokenizer_encode_string__output_buffer = (ctypes.c_int * output_buffer_init_len)()

        buffer = self.thread_local_obj.tokenizer_encode_string__output_buffer
        buffer_len = len(buffer)
        result_len = fastllm_lib.token_encode_string(self.model, content.encode(), buffer_len, buffer)
        if result_len > buffer_len:
            if result_len > 10240:
                # The input is too long; use a one-off buffer instead of growing the shared one
                temp_buffer = (ctypes.c_int * result_len)()
                ret = fastllm_lib.token_encode_string(self.model, content.encode(), result_len, temp_buffer)
                return [i for i in temp_buffer]
            else:
                # Grow the shared buffer
                new_buffer_len = round(math.ceil(result_len / 1024.0)) * 1024
                buffer = (ctypes.c_int * new_buffer_len)()
                self.thread_local_obj.tokenizer_encode_string__output_buffer = buffer
                result_len = fastllm_lib.token_encode_string(self.model, content.encode(), new_buffer_len, buffer)

        return [buffer[i] for i in range(result_len)]

    def tokenizer_decode_token(self, token_id: int) -> bytes:
        if self.tokenizer_decode_token_cache is not None:
            cache_result = self.tokenizer_decode_token_cache.get(token_id)
            if cache_result is not None:
                return cache_result

        output_buffer_init_len = 256
        if not hasattr(self.thread_local_obj, 'tokenizer_decode_token__output_buffer') or self.thread_local_obj.tokenizer_decode_token__output_buffer is None:
            self.thread_local_obj.tokenizer_decode_token__output_buffer = ctypes.create_string_buffer(output_buffer_init_len)

        buffer = self.thread_local_obj.tokenizer_decode_token__output_buffer
        ret = fastllm_lib.token_decode(self.model, token_id, len(buffer), buffer)
        if ret > 0:
            # The buffer is too small; grow it and decode again
            new_buffer_len = round(math.ceil(ret / 16.0)) * 16
            buffer = ctypes.create_string_buffer(new_buffer_len)
            self.thread_local_obj.tokenizer_decode_token__output_buffer = buffer
            ret = fastllm_lib.token_decode(self.model, token_id, len(buffer), buffer)
            assert ret == 0

        buffer_bytes = buffer.raw
        result_len = len(buffer_bytes)
        for i in range(len(buffer_bytes)):
            if buffer_bytes[i] == 0:
                result_len = i
                break
        return buffer_bytes[:result_len]
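
    # Round-trip sketch for the two tokenizer helpers above (illustrative; `m` is a `model`
    # instance and the actual ids depend on the loaded vocabulary):
    #
    #     ids = m.tokenizer_encode_string("hello world")           # -> List[int]
    #     raw = b"".join(m.tokenizer_decode_token(t) for t in ids)
    #     text = raw.decode("utf-8", errors="ignore")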

    def stop_token_ctypes(self, stop_token_ids):
        if stop_token_ids is None:
            return 0, None
        else:
            return ctypes.c_int(len(stop_token_ids)), (ctypes.c_int * len(stop_token_ids))(*stop_token_ids)
        
    def response_logits(self,
                        query: str,
                        history: List[Tuple[str, str]] = None,
                        tokenizer = None,
                        stop_token_ids: List[int] = None,
                        ) -> str:
        prompt = query if self.direct_query else self.get_prompt(query, history);
        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
        if (tokenizer == None):
            handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                           ctypes.c_int(1), ctypes.c_bool(False), ctypes.c_float(1), ctypes.c_int(1),
                                                           ctypes.c_float(1), ctypes.c_float(1), ctypes.c_bool(True),
                                                           stop_token_len, stop_token_list);
        else:
            input = tokenizer.encode(prompt);
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           1, False, 1, 1, 1, 1, True, stop_token_len, stop_token_list);
        vocab_size = fastllm_lib.get_tokenizer_vocab_size(self.model);
        logits = list(range(vocab_size))
        array = (ctypes.c_float * (vocab_size * 4))(*logits);
        ret = fastllm_lib.fetch_response_logits_llm_model(self.model, handle, array);
        out = list(array)[:vocab_size];
        while (ret != -1):
            ret = fastllm_lib.fetch_response_logits_llm_model(self.model, handle, array);
        return out;
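
    # Sketch of reading the next-token distribution with response_logits (illustrative; the
    # returned list has one float per vocabulary entry):
    #
    #     logits = m.response_logits("hello")
    #     next_token = max(range(len(logits)), key=lambda i: logits[i])   # greedy argmax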

    def response(self,
                 query: str,
                 history: List[Tuple[str, str]] = None,
                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
                 stop_token_ids: List[int] = None) -> str:
        ret = "";
        for i in self.stream_response(query = query,
                                      history = history,
                                      max_length = max_length,
                                      do_sample = do_sample,
                                      top_p = top_p, top_k = top_k,
                                      temperature = temperature,
                                      repeat_penalty = repeat_penalty,
                                      one_by_one = True,
                                      stop_token_ids = stop_token_ids):
            ret += i;
        return ret;

    def stream_response(self,
                        query: str,
                        history: List[Tuple[str, str]] = None,
                        max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
                        one_by_one = True, stop_token_ids: List[int] = None):
        prompt = query if self.direct_query else self.get_prompt(query, history);
        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
        handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                           ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
                                                           ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
                                                           stop_token_len, stop_token_list);
        res = "";
        ret = b'';
        fail_cnt = 0;
        while True:
            # ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle);
            ret_chararry = fastllm_lib.fetch_response_str_llm_model(self.model, handle);
            ret += ctypes.string_at(ret_chararry)
            fastllm_lib.freeChars(ret_chararry)
            cur = "";
            try:
                cur = ret.decode()
                ret = b'';
            except:
                fail_cnt += 1;
                if (fail_cnt == 20):
                    break;
                else:
                    continue;
            fail_cnt = 0;
            if (cur == "<flmeos>"):
                break;
            if one_by_one:
                yield cur;
            else:
                res += cur;
                yield res;
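
    # Streaming usage sketch (illustrative): with one_by_one=True each yielded item is only the
    # newly generated fragment; with one_by_one=False each item is the full text generated so far.
    #
    #     for piece in m.stream_response("Introduce yourself", history=[]):
    #         print(piece, end="", flush=True)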

    def stream_response_raw(self,
                            input_tokens: List[int],
                            max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
                            one_by_one = True,
                            stop_token_ids: List[int] = None
                            ):
        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
        handle = fastllm_lib.launch_response_llm_model(self.model, len(input_tokens),
                                                       (ctypes.c_int * len(input_tokens))(*input_tokens),
                                                       ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
                                                       ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
                                                       stop_token_len, stop_token_list)

        # Some "long-tail" characters need several tokens to be generated, so this yields raw bytes
        # and leaves the string-decoding strategy to the caller. That makes it easy to count output
        # tokens and to control how incomplete UTF-8 sequences are decoded.

        total_bytes = b''
        while True:
            cur_token = fastllm_lib.fetch_response_llm_model(self.model, handle)
            if cur_token == -1:
                break

            cur_bytes = self.tokenizer_decode_token(cur_token)

            if one_by_one:
                yield cur_bytes
            else:
                total_bytes += cur_bytes
                yield total_bytes
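
    # Because stream_response_raw yields raw bytes, the caller decides how to handle UTF-8
    # characters that span several tokens. A common sketch (illustrative) buffers bytes until
    # they decode cleanly:
    #
    #     prompt = m.get_prompt("hello")
    #     pending = b''
    #     for chunk in m.stream_response_raw(m.tokenizer_encode_string(prompt)):
    #         pending += chunk
    #         try:
    #             print(pending.decode("utf-8"), end="", flush=True)
    #             pending = b''
    #         except UnicodeDecodeError:
    #             continue   # wait for the rest of the multi-byte character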

    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, stop_token_ids: List[int] = None, **kwargs):
        if self.model_type != "chatglm3":
            if (not(history)):
                history = [];
            prompt = query if self.direct_query else self.get_prompt(query, history);
            input = tokenizer.encode(prompt);
            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False, stop_token_len, stop_token_list);

            result = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                result.append(cur);
            response = tokenizer.decode(result);
            history = history + [(query, response)];
            return response, history;
        else:
            if history is None:
                history = []
            role = "user"
            input = self.build_chatglm3_input(tokenizer, query, history=history, role=role)
            history.append({"role": role, "content": query})			
            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False, stop_token_len, stop_token_list);
            tokens = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                tokens.append(cur);
            response = tokenizer.decode(tokens);
            if response and response[-1] != "�":
                response, history = self.process_chatglm3_response(response, history)
            return response, history
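
    # chat() follows the transformers-style chat API and needs the model's own Hugging Face
    # tokenizer for encoding/decoding. Illustrative sketch (the repo name is only an example):
    #
    #     from transformers import AutoTokenizer
    #     tok = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    #     reply, history = m.chat(tok, "Hello", history=[])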

    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
                    max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
                    return_past_key_values = False, stop_token_ids: List[int] = None, **kwargs) -> str:
        if self.model_type != "chatglm3":
            if (not(history)):
                history = [];
            prompt = query if self.direct_query else self.get_prompt(query, history);
            input = tokenizer.encode(prompt);
            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False, stop_token_len, stop_token_list);
            tokens = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                tokens.append(cur);
                response = tokenizer.decode(tokens);
                new_history = history + [(query, response)];
                if return_past_key_values:
                    yield response, new_history, None;
                else:
                    yield response, new_history;
        else:
            if history is None:
                history = []
            role = "user"
            input = self.build_chatglm3_input(tokenizer, query, history=history, role=role)
            history.append({"role": role, "content": query})
            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)

            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False, stop_token_len, stop_token_list);
            tokens = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                tokens.append(cur);
                response = tokenizer.decode(tokens);
                if response and response[-1] != "�":
                    response, new_history = self.process_chatglm3_response(response, history)
                    if return_past_key_values:
                        yield response, new_history, past_key_values
                    else:
                        yield response, new_history


    def set_adapter(self, name: str):
        fastllm_lib.set_adapter(self.model, str(name).encode())

    def disable_adapter(self):
        fastllm_lib.disable_adapter(self.model)

    def process_chatglm3_response(self, output, history):
        content = ""
        history = deepcopy(history)
        for response in output.split("<|assistant|>"):
            metadata, content = response.split("\n", maxsplit=1)
            if not metadata.strip():
                content = content.strip()
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                content = content.replace("[[训练时间]]", "2023年")
            else:
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                if history[0]["role"] == "system" and "tools" in history[0]:
                    content = "\n".join(content.split("\n")[1:-1])
                    def tool_call(**kwargs):
                        return kwargs
                    parameters = eval(content)
                    content = {"name": metadata.strip(), "parameters": parameters}
                else:
                    content = {"name": metadata.strip(), "content": content}
        return content, history

    def build_chatglm3_input(self, tokenizer, query, history=None, role="user"):
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(tokenizer.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(tokenizer.build_single_message(role, "", query))
        input_ids.extend([tokenizer.get_command("<|assistant|>")])
        return input_ids
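
    # Note: build_chatglm3_input relies on ChatGLM3's tokenizer helpers (build_single_message and
    # get_command), so it only works with that tokenizer; the returned ids end with the
    # <|assistant|> command token so generation continues as the assistant's turn.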

    def response_batch_raw(self, querys: List[str],
                       historys: List[List[Tuple[str, str]]] = None,
                       max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
                       **kwargs) -> List[str]:
        query_size = len(querys)
        if (not(historys)):
            historys = [[] for _ in range(query_size)]
        inputs = (ctypes.c_char_p * query_size)()
        for i, query in enumerate(querys):
            prompt = query if self.direct_query else self.get_prompt(query, historys[i])
            inputs[i] = ctypes.c_char_p(prompt.encode())

        outputs = fastllm_lib.response_batch_str_llm_model(self.model, inputs, query_size,
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty, False)

        responses = []
        for i in range(query_size):
            response = ctypes.string_at(outputs[i]).decode()
            responses.append(response)
            historys[i] = historys[i] + [(querys[i], response)]
        fastllm_lib.freeCharArray(outputs, query_size)
        return responses, historys
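
    # Batch usage sketch (illustrative): one prompt per list entry; histories default to empty.
    #
    #     replies, histories = m.response_batch_raw(["Question A", "Question B"],
    #                                               max_length=256, do_sample=False)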

    def chat_batch_raw(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, **kwargs):
        query_size = len(querys)
        if (not(historys)):
            historys = [[] for _ in range(query_size)]

        inputs = []
        inputs_len = []
        for i, query in enumerate(querys):
            prompt = query if self.direct_query else self.get_prompt(query, historys[i])
            input = tokenizer.encode(prompt);
            inputs.extend(input)
            inputs_len.append(len(input))

        outputs = fastllm_lib.response_batch_tokens_llm_model(self.model, query_size,
                                                                (ctypes.c_int * len(inputs_len))(*inputs_len),
                                                                (ctypes.c_int * len(inputs))(*inputs),
                                                                max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                                False)

        responses = []
        for i in range(query_size):
            response = ctypes.string_at(outputs[i]).decode()
            responses.append(response)
            historys[i] = historys[i] + [(querys[i], response)]
        fastllm_lib.freeCharArray(outputs, query_size)
        return responses, historys

    def response_batch(self, querys: List[str],
                       historys: List[List[Tuple[str, str]]] = None,
                       max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
                       stop_token_ids: List[int] = None, **kwargs) -> List[str]:
        query_size = len(querys)
        if (not(historys)):
            historys = [[] for _ in range(query_size)]
        handles = []
        for i, query in enumerate(querys):
            prompt = query if self.direct_query else self.get_prompt(query, historys[i])
            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
            handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                           ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
                                                           ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
                                                           stop_token_len, stop_token_list)
            handles.append(handle)

        responses = []
        for i, handle in enumerate(handles):
            res = ""
            ret = b''
            fail_cnt = 0
            while True:
                # ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle);
                ret_chararry = fastllm_lib.fetch_response_str_llm_model(self.model, handle);
                ret += ctypes.string_at(ret_chararry)
                fastllm_lib.freeChars(ret_chararry)
                cur = ""
                try:
                    cur = ret.decode()
                    ret = b''
                except:
                    fail_cnt += 1
                    if (fail_cnt == 20):
                        break
                    else:
                        continue
                fail_cnt = 0
                if (cur == "<flmeos>"):
                    break;
                res += cur
            responses.append(res)
            historys[i] = historys[i] + [(querys[i], res)]

        return responses, historys
   

    def chat_batch(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, stop_token_ids: List[int] = None, **kwargs):
        query_size = len(querys)
        if (not(historys)):
            historys = [[] for _ in range(query_size)]

        handles = []
        for i, query in enumerate(querys):
            prompt = query if self.direct_query else self.get_prompt(query, historys[i])
            input = tokenizer.encode(prompt);
            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False, stop_token_len, stop_token_list);
            handles.append(handle)

        responses = []
        for i, handle in enumerate(handles):
            result = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                result.append(cur);
            response = tokenizer.decode(result);
            responses.append(response)
            historys[i] = historys[i] + [(querys[i], response)]

        return responses, historys
    
    def release_memory(self):
        fastllm_lib.release_memory(self.model)