Commit 15d855b9 authored by zhouxiang

Fix a fastllm memory issue and change the repeat_penalty default value

parent f554b7d6
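Background for the memory fix: the C API functions patched below return heap-allocated char* buffers. With restype = ctypes.c_char_p, ctypes converts the result to a Python bytes object and drops the original pointer, so the buffer can never be handed back to the library and leaks on every call. Declaring restype = ctypes.POINTER(ctypes.c_char) keeps the raw pointer; the bytes are copied out with ctypes.string_at() and the pointer is then released through the library's freeChars / freeCharArray helpers. A minimal self-contained sketch of the same pattern against glibc's strdup/free (assumes Linux with libc.so.6; libc is only a stand-in for illustration, not part of fastllm):

```python
import ctypes

# Assumption: glibc is loadable as libc.so.6 (Linux). strdup() returns a
# heap-allocated copy of its argument that the caller must free().
libc = ctypes.CDLL("libc.so.6")

libc.strdup.argtypes = [ctypes.c_char_p]
# POINTER(c_char) keeps the raw pointer instead of auto-converting to bytes,
# which is the same reason this commit switches restype away from c_char_p.
libc.strdup.restype = ctypes.POINTER(ctypes.c_char)
libc.free.argtypes = [ctypes.c_void_p]
libc.free.restype = None

ptr = libc.strdup(b"hello fastllm")
text = ctypes.string_at(ptr).decode()  # copy the C string into Python memory
libc.free(ptr)                         # release the C buffer (fastllm uses freeChars here)
print(text)
```

The diff below applies exactly this pattern to the fastllm bindings and, independently, bumps the default repeat_penalty from 1.0 (no penalty) to 1.1.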
@@ -34,7 +34,8 @@ fastllm_lib.fetch_response_logits_llm_model.restype = ctypes.c_int
 fastllm_lib.response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p,
                                                ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                ctypes.c_float, ctypes.c_float, ctypes.c_bool]
-fastllm_lib.response_str_llm_model.restype = ctypes.c_char_p
+# fastllm_lib.response_str_llm_model.restype = ctypes.c_char_p
+fastllm_lib.response_str_llm_model.restype = ctypes.POINTER(ctypes.c_char)
 fastllm_lib.launch_response_str_llm_model.argtype = [ctypes.c_int, ctypes.c_char_p,
                                                      ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
@@ -42,20 +43,24 @@ fastllm_lib.launch_response_str_llm_model.argtype = [ctypes.c_int, ctypes.c_char
 fastllm_lib.launch_response_str_llm_model.restype = ctypes.c_int
 fastllm_lib.fetch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
-fastllm_lib.fetch_response_str_llm_model.restype = ctypes.c_char_p
+# fastllm_lib.fetch_response_str_llm_model.restype = ctypes.c_char_p
+fastllm_lib.fetch_response_str_llm_model.restype = ctypes.POINTER(ctypes.c_char)
 fastllm_lib.make_history_llm_model.argtype = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
-fastllm_lib.make_history_llm_model.restype = ctypes.c_char_p
+# fastllm_lib.make_history_llm_model.restype = ctypes.c_char_p
+fastllm_lib.make_history_llm_model.restype = ctypes.POINTER(ctypes.c_char)
 fastllm_lib.make_input_llm_model.argtype = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p]
-fastllm_lib.make_input_llm_model.restype = ctypes.c_char_p
+# fastllm_lib.make_input_llm_model.restype = ctypes.c_char_p
+fastllm_lib.make_input_llm_model.restype = ctypes.POINTER(ctypes.c_char)
 fastllm_lib.add_tokenizer_word_llm_model.argtype = [ctypes.c_int, ctypes.c_char_p, ctypes.c_float, ctypes.c_int]
 fastllm_lib.set_device_map.argtype = [ctypes.c_int, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
 fastllm_lib.get_llm_model_type.argtype = [ctypes.c_int]
-fastllm_lib.get_llm_model_type.restype = ctypes.c_char_p
+# fastllm_lib.get_llm_model_type.restype = ctypes.c_char_p
+fastllm_lib.get_llm_model_type.restype = ctypes.POINTER(ctypes.c_char)
 fastllm_lib.response_batch_str_llm_model.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int,
                                                      ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
@@ -67,6 +72,11 @@ fastllm_lib.response_batch_tokens_llm_model.argtypes = [ctypes.c_int, ctypes.c_i
                                                         ctypes.c_float, ctypes.c_float, ctypes.c_bool]
 fastllm_lib.response_batch_tokens_llm_model.restype = ctypes.POINTER(ctypes.c_char_p)
+fastllm_lib.freeChars.argtype = [ctypes.POINTER(ctypes.c_char)]
+# fastllm_lib.freeChars.restype = ctypes.c_char_p
+fastllm_lib.freeCharArray.argtype = [ctypes.POINTER(ctypes.c_char_p)]
 def set_cpu_threads(threads: int):
     fastllm_lib.set_cpu_threads(threads);
@@ -134,7 +144,9 @@ class model:
         # Deliberately not an automatic cache, to avoid locking the cache dict in multi-threaded calls and to leave the choice open for different scenarios
         self.tokenizer_decode_token_cache = None
-        self.model_type = fastllm_lib.get_llm_model_type(self.model).decode()
+        model_type_ptr = fastllm_lib.get_llm_model_type(self.model)
+        self.model_type = ctypes.string_at(model_type_ptr).decode()
+        fastllm_lib.freeChars(model_type_ptr)
         # print("model_type:", self.model_type)
     def get_prompt(self,
@@ -144,8 +156,13 @@ class model:
             history = [];
         prompt = "";
         for i, (old_query, response) in enumerate(history):
-            prompt = fastllm_lib.make_history_llm_model(self.model, prompt.encode(), i, old_query.encode(), response.encode()).decode();
-        prompt = fastllm_lib.make_input_llm_model(self.model, prompt.encode(), len(history), query.encode()).decode();
+            history_ptr = fastllm_lib.make_history_llm_model(self.model, prompt.encode(), i, old_query.encode(), response.encode())
+            prompt = ctypes.string_at(history_ptr).decode()
+            fastllm_lib.freeChars(history_ptr)
+        input_ptr = fastllm_lib.make_input_llm_model(self.model, prompt.encode(), len(history), query.encode())
+        prompt = ctypes.string_at(input_ptr).decode()
+        fastllm_lib.freeChars(input_ptr)
         return prompt;
     def save(self, path : str):
@@ -241,7 +258,7 @@ class model:
     def response(self,
                  query: str,
                  history: List[Tuple[str, str]] = None,
-                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0) -> str:
+                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1) -> str:
         ret = "";
         for i in self.stream_response(query = query,
                                       history = history,
@@ -257,7 +274,7 @@ class model:
     def stream_response(self,
                         query: str,
                         history: List[Tuple[str, str]] = None,
-                        max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
+                        max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
                         one_by_one = True):
         prompt = query if self.direct_query else self.get_prompt(query, history);
         handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
@@ -267,10 +284,13 @@ class model:
         ret = b'';
         fail_cnt = 0;
         while True:
-            ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle);
+            # ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle);
+            ret_chararry = fastllm_lib.fetch_response_str_llm_model(self.model, handle);
+            ret += ctypes.string_at(ret_chararry)
+            fastllm_lib.freeChars(ret_chararry)
            cur = "";
            try:
-                cur = ret.decode();
+                cur = ret.decode()
                ret = b'';
            except:
                fail_cnt += 1;
@@ -289,7 +309,7 @@ class model:
     def stream_response_raw(self,
                             input_tokens: List[int],
-                            max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
+                            max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
                             one_by_one = True
                             ):
         handle = fastllm_lib.launch_response_llm_model(self.model, len(input_tokens),
@@ -315,7 +335,7 @@ class model:
             yield total_bytes
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
-             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
+             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1, **kwargs):
         if self.model_type != "chatglm3":
             if (not(history)):
                 history = [];
@@ -356,7 +376,7 @@ class model:
         return response, new_history
     def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
-                    max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
+                    max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
                     return_past_key_values = False, **kwargs) -> str:
         if self.model_type != "chatglm3":
             if (not(history)):
@@ -445,7 +465,7 @@ class model:
     def response_batch(self, querys: List[str],
                        historys: List[List[Tuple[str, str]]] = None,
-                       max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
+                       max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
                        **kwargs) -> List[str]:
         query_size = len(querys)
         if (not(historys)):
@@ -463,10 +483,11 @@ class model:
             response = ctypes.string_at(outputs[i]).decode()
             responses.append(response)
             historys[i] = historys[i] + [(querys[i], response)]
+        fastllm_lib.freeCharArray(outputs)
         return responses, historys
     def chat_batch(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
-                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
+                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1, **kwargs):
         query_size = len(querys)
         if (not(historys)):
             historys = [[] for _ in range(query_size)]
@@ -490,6 +511,7 @@ class model:
             response = ctypes.string_at(outputs[i]).decode()
             responses.append(response)
             historys[i] = historys[i] + [(querys[i], response)]
+        fastllm_lib.freeCharArray(outputs)
         return responses, historys
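For callers of the Python binding, the only visible behaviour change is the new sampling default. A usage sketch (the import path, model file name, and loader call are illustrative assumptions about how the package is typically used, not confirmed by this commit):

```python
from fastllm_pytools import llm  # assumed package layout

model = llm.model("model.flm")  # illustrative path to a converted fastllm model

# repeat_penalty now defaults to 1.1 in response()/stream_response()/chat();
# pass 1.0 explicitly to keep the previous "no penalty" behaviour.
print(model.response("Hello", repeat_penalty = 1.0))
```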