Commit f554b7d6 authored by zhouxiang

Fix uneven streaming output for baichuan

parent 53e02bdd
@@ -26,6 +26,8 @@ def create(model,
        exit(0);

    # 0.1 model info
    if model.config.model_type == "chatglm" and model.config.transformers_version == "4.30.2":
        model.config.model_type = "chatglm3"
    modelInfo = model.config.__dict__
    if model.generation_config is not None:
        modelInfo.update(model.generation_config.__dict__)
...
@@ -3,6 +3,7 @@ import math
import os;
import threading
from typing import Optional, Tuple, Union, List, Callable, Dict, Any;
from copy import deepcopy
import platform
if platform.system() == 'Windows':
@@ -53,6 +54,19 @@ fastllm_lib.add_tokenizer_word_llm_model.argtype = [ctypes.c_int, ctypes.c_char_
fastllm_lib.set_device_map.argtype = [ctypes.c_int, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]

fastllm_lib.get_llm_model_type.argtype = [ctypes.c_int]
fastllm_lib.get_llm_model_type.restype = ctypes.c_char_p

fastllm_lib.response_batch_str_llm_model.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int,
                                                     ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                     ctypes.c_float, ctypes.c_float, ctypes.c_bool]
fastllm_lib.response_batch_str_llm_model.restype = ctypes.POINTER(ctypes.c_char_p)

fastllm_lib.response_batch_tokens_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int),
                                                        ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                        ctypes.c_float, ctypes.c_float, ctypes.c_bool]
fastllm_lib.response_batch_tokens_llm_model.restype = ctypes.POINTER(ctypes.c_char_p)

def set_cpu_threads(threads: int):
    fastllm_lib.set_cpu_threads(threads);
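Both new batch entry points return a char **, with the prompts passed either as an array of C strings or as flattened token ids. Below is a minimal standalone ctypes sketch of the shapes these prototypes imply; the names prompts/demo_inputs and the commented-out call are illustrative, not part of the library.

import ctypes

prompts = ["hello", "world"]
# An array of C strings, matching the ctypes.POINTER(ctypes.c_char_p) argument declared above.
demo_inputs = (ctypes.c_char_p * len(prompts))()
for i, p in enumerate(prompts):
    demo_inputs[i] = ctypes.c_char_p(p.encode())

# A real call would look roughly like:
#   outs = fastllm_lib.response_batch_str_llm_model(model_handle, demo_inputs, len(prompts),
#                                                   max_length, do_sample, top_p, top_k,
#                                                   temperature, repeat_penalty, False)
# and each element of the returned char ** is read back with ctypes.string_at(outs[i]).decode().
print([demo_inputs[i].decode() for i in range(len(prompts))])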
@@ -120,6 +134,9 @@ class model:
        # Deliberately not an automatic cache, both to avoid locking the cache dict in
        # multi-threaded calls and to leave different use cases room to choose.
        self.tokenizer_decode_token_cache = None
        self.model_type = fastllm_lib.get_llm_model_type(self.model).decode()
        # print("model_type:", self.model_type)

    def get_prompt(self,
                   query: str,
                   history: List[Tuple[str, str]] = None) -> str:
@@ -206,8 +223,8 @@ class model:
        prompt = query if self.direct_query else self.get_prompt(query, history);
        if (tokenizer == None):
            handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                               ctypes.c_int(1), ctypes.c_bool(False), ctypes.c_float(1), ctypes.c_int(1),
                                                               ctypes.c_float(1), ctypes.c_float(1), ctypes.c_bool(True));
        else:
            input = tokenizer.encode(prompt);
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
@@ -299,49 +316,180 @@ class model:
    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
        if self.model_type != "chatglm3":
            if (not(history)):
                history = [];
            prompt = query if self.direct_query else self.get_prompt(query, history);
            input = tokenizer.encode(prompt);
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False);
            result = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                result.append(cur);
            response = tokenizer.decode(result);
            history = history + [(query, response)];
            return response, history;
        else:
            if history is None:
                history = []
            role = "user"
            input = self.build_chatglm3_input(tokenizer, query, history=history, role=role)
            history.append({"role": role, "content": query})
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False);
            tokens = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                tokens.append(cur);
            response = tokenizer.decode(tokens);
            if response and response[-1] != "�":
                response, new_history = self.process_chatglm3_response(response, history)
                return response, new_history
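A usage sketch for the new dispatch on self.model_type (not part of the diff; the model id and .flm path are placeholders, and the model object is assumed to be loaded from a converted file via fastllm's llm.model):

from transformers import AutoTokenizer
from fastllm_pytools import llm

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)  # placeholder id
model = llm.model("chatglm3-6b-fp16.flm")                                               # placeholder path

# For chatglm3 the history is a list of {"role", "metadata", "content"} dicts;
# for every other model_type it stays a list of (query, response) tuples.
response, history = model.chat(tokenizer, "Hello, who are you?", history=[])
print(response)
response, history = model.chat(tokenizer, "Summarize that in five words.", history=history)
print(response)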
    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
                    max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
                    return_past_key_values = False, **kwargs) -> str:
        if self.model_type != "chatglm3":
            if (not(history)):
                history = [];
            prompt = query if self.direct_query else self.get_prompt(query, history);
            input = tokenizer.encode(prompt);
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False);
            tokens = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                tokens.append(cur);
                response = tokenizer.decode(tokens);
                new_history = history + [(query, response)];
                if return_past_key_values:
                    yield response, new_history, None;
                else:
                    yield response, new_history;
        else:
            if history is None:
                history = []
            role = "user"
            input = self.build_chatglm3_input(tokenizer, query, history=history, role=role)
            history.append({"role": role, "content": query})
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False);
            tokens = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                tokens.append(cur);
                response = tokenizer.decode(tokens);
                if response and response[-1] != "�":
                    response, new_history = self.process_chatglm3_response(response, history)
                    if return_past_key_values:
                        yield response, new_history, past_key_values
                    else:
                        yield response, new_history
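Each yield carries the full text decoded so far rather than a delta, so a streaming consumer typically prints only the newly appended tail. A sketch of that pattern for plain-text responses (placeholders as above; the incremental-print logic is the caller's choice, not part of the diff):

from transformers import AutoTokenizer
from fastllm_pytools import llm

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)  # placeholder id
model = llm.model("chatglm3-6b-fp16.flm")                                               # placeholder path

printed = 0
for response, history in model.stream_chat(tokenizer, "Tell me a short story.", history=[]):
    # response is the whole partial answer so far; emit only what has not been printed yet.
    print(response[printed:], end="", flush=True)
    printed = len(response)
print()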
    def set_adapter(self, name: str):
        fastllm_lib.set_adapter(self.model, str(name).encode())

    def disable_adapter(self):
        fastllm_lib.disable_adapter(self.model)
    def process_chatglm3_response(self, output, history):
        content = ""
        history = deepcopy(history)
        for response in output.split("<|assistant|>"):
            metadata, content = response.split("\n", maxsplit=1)
            if not metadata.strip():
                content = content.strip()
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                content = content.replace("[[训练时间]]", "2023年")
            else:
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                if history[0]["role"] == "system" and "tools" in history[0]:
                    content = "\n".join(content.split("\n")[1:-1])
                    def tool_call(**kwargs):
                        return kwargs
                    parameters = eval(content)
                    content = {"name": metadata.strip(), "parameters": parameters}
                else:
                    content = {"name": metadata.strip(), "content": content}
        return content, history
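The parser splits the raw output on "<|assistant|>" and treats the first line of each segment as metadata: an empty metadata line means a plain text reply, a non-empty one is taken as a tool name whose body becomes parameters (when the leading system message carries "tools") or raw content. The segments below are illustrative only, not actual model output:

segment = "\nHello! How can I help you today?"        # empty metadata line -> plain text reply
metadata, content = segment.split("\n", maxsplit=1)
print(repr(metadata.strip()), repr(content.strip()))  # '' 'Hello! How can I help you today?'

segment = "get_weather\ntool_call(city='Beijing')"    # first line names a tool
metadata, content = segment.split("\n", maxsplit=1)
print(metadata.strip())                               # 'get_weather' -> returned as {"name": ..., ...}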
    def build_chatglm3_input(self, tokenizer, query, history=None, role="user"):
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(tokenizer.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(tokenizer.build_single_message(role, "", query))
        input_ids.extend([tokenizer.get_command("<|assistant|>")])
        return input_ids
    def response_batch(self, querys: List[str],
                       historys: List[List[Tuple[str, str]]] = None,
                       max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
                       **kwargs) -> List[str]:
        query_size = len(querys)
        if (not(historys)):
            historys = [[] for _ in range(query_size)]
        inputs = (ctypes.c_char_p * query_size)()
        for i, query in enumerate(querys):
            prompt = query if self.direct_query else self.get_prompt(query, historys[i])
            inputs[i] = ctypes.c_char_p(prompt.encode())
        outputs = fastllm_lib.response_batch_str_llm_model(self.model, inputs, query_size,
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty, False)
        responses = []
        for i in range(query_size):
            response = ctypes.string_at(outputs[i]).decode()
            responses.append(response)
            historys[i] = historys[i] + [(querys[i], response)]
        return responses, historys
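response_batch needs no HuggingFace tokenizer: it formats each query with get_prompt, hands all prompts to the C side in one call, and decodes the returned strings. A usage sketch (the .flm path is a placeholder, and the model is assumed to come from llm.model):

from fastllm_pytools import llm

model = llm.model("baichuan2-7b-chat-fp16.flm")  # placeholder path
responses, historys = model.response_batch(["What is fastllm?",
                                            "Write one sentence about GPUs."])
for r in responses:
    print(r)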
    def chat_batch(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
        query_size = len(querys)
        if (not(historys)):
            historys = [[] for _ in range(query_size)]
        inputs = []
        inputs_len = []
        for i, query in enumerate(querys):
            prompt = query if self.direct_query else self.get_prompt(query, historys[i])
            input = tokenizer.encode(prompt);
            inputs.extend(input)
            inputs_len.append(len(input))
        outputs = fastllm_lib.response_batch_tokens_llm_model(self.model, query_size,
                                                              (ctypes.c_int * len(inputs_len))(*inputs_len),
                                                              (ctypes.c_int * len(inputs))(*inputs),
                                                              max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                              False)
        responses = []
        for i in range(query_size):
            response = ctypes.string_at(outputs[i]).decode()
            responses.append(response)
            historys[i] = historys[i] + [(querys[i], response)]
        return responses, historys
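chat_batch tokenizes on the Python side and flattens all token ids into a single int array plus a per-query length array, which avoids passing a ragged 2D structure through ctypes. A usage sketch (model id and .flm path are placeholders):

from transformers import AutoTokenizer
from fastllm_pytools import llm

tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True)  # placeholder
model = llm.model("baichuan2-7b-chat-fp16.flm")                                                      # placeholder

querys = ["Introduce yourself in one sentence.", "What is 17 * 24?"]
responses, historys = model.chat_batch(tokenizer, querys, max_length=256)
for q, r in zip(querys, responses):
    print(q, "->", r)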
@@ -80,6 +80,8 @@ def tofile(exportPath,
    fo.write(struct.pack('i', 2))

    # 0.1 model info
    if model.config.model_type == "chatglm" and model.config.transformers_version == "4.30.2":
        model.config.model_type = "chatglm3"
    modelInfo = model.config.__dict__
    if model.generation_config is not None:
        modelInfo.update(model.generation_config.__dict__)
...