Commit f554b7d6 authored by zhouxiang

Fix uneven streaming output for baichuan

parent 53e02bdd
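For context on the streaming path this commit touches, here is a hedged usage sketch of llm.py's stream_chat generator (illustrative only, not part of the diff; the tokenizer repo and .flm model path are placeholders):

# Illustrative only: consume the stream_chat generator and print the newly decoded suffix each step.
from transformers import AutoTokenizer
from fastllm_pytools import llm

tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-13B-Chat", trust_remote_code = True)  # placeholder
model = llm.model("baichuan2-13b-chat-int8.flm")  # placeholder path

printed = 0
for response, history in model.stream_chat(tokenizer, "Hello!"):
    print(response[printed:], end = "", flush = True)
    printed = len(response)
print()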
@@ -26,6 +26,8 @@ def create(model,
        exit(0);
    # 0.1 model info
    if model.config.model_type == "chatglm" and model.config.transformers_version == "4.30.2":
        model.config.model_type = "chatglm3"
    modelInfo = model.config.__dict__
    if model.generation_config is not None:
        modelInfo.update(model.generation_config.__dict__)
@@ -3,6 +3,7 @@ import math
import os;
import threading
from typing import Optional, Tuple, Union, List, Callable, Dict, Any;
from copy import deepcopy
import platform
if platform.system() == 'Windows':
@@ -53,6 +54,19 @@ fastllm_lib.add_tokenizer_word_llm_model.argtype = [ctypes.c_int, ctypes.c_char_
fastllm_lib.set_device_map.argtype = [ctypes.c_int, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]

fastllm_lib.get_llm_model_type.argtype = [ctypes.c_int]
fastllm_lib.get_llm_model_type.restype = ctypes.c_char_p

fastllm_lib.response_batch_str_llm_model.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int,
                                                     ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                     ctypes.c_float, ctypes.c_float, ctypes.c_bool]
fastllm_lib.response_batch_str_llm_model.restype = ctypes.POINTER(ctypes.c_char_p)

fastllm_lib.response_batch_tokens_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int),
                                                        ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
                                                        ctypes.c_float, ctypes.c_float, ctypes.c_bool]
fastllm_lib.response_batch_tokens_llm_model.restype = ctypes.POINTER(ctypes.c_char_p)

def set_cpu_threads(threads: int):
    fastllm_lib.set_cpu_threads(threads);
@@ -120,6 +134,9 @@ class model:
        # Not implemented as an automatic cache, to avoid locking the cache dict in multi-threaded calls
        # and to leave the choice open for different scenarios
        self.tokenizer_decode_token_cache = None
        self.model_type = fastllm_lib.get_llm_model_type(self.model).decode()
        # print("model_type:", self.model_type)

    def get_prompt(self,
                   query: str,
                   history: List[Tuple[str, str]] = None) -> str:
@@ -206,8 +223,8 @@ class model:
        prompt = query if self.direct_query else self.get_prompt(query, history);
        if (tokenizer == None):
            handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                               ctypes.c_int(1), ctypes.c_bool(False), ctypes.c_float(1), ctypes.c_int(1),
                                                               ctypes.c_float(1), ctypes.c_float(1), ctypes.c_bool(True));
        else:
            input = tokenizer.encode(prompt);
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
@@ -299,49 +316,180 @@ class model:
    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
        if self.model_type != "chatglm3":
            if (not(history)):
                history = [];
            prompt = query if self.direct_query else self.get_prompt(query, history);
            input = tokenizer.encode(prompt);
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False);
            result = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                result.append(cur);
            response = tokenizer.decode(result);
            history = history + [(query, response)];
            return response, history;
        else:
            if history is None:
                history = []
            role = "user"
            input = self.build_chatglm3_input(tokenizer, query, history=history, role=role)
            history.append({"role": role, "content": query})
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False);
            tokens = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                tokens.append(cur);
            response = tokenizer.decode(tokens);
            if response and response[-1] != "�":
                response, new_history = self.process_chatglm3_response(response, history)
                return response, new_history

    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
                    max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
                    return_past_key_values = False, **kwargs) -> str:
        if self.model_type != "chatglm3":
            if (not(history)):
                history = [];
            prompt = query if self.direct_query else self.get_prompt(query, history);
            input = tokenizer.encode(prompt);
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False);
            tokens = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                tokens.append(cur);
                response = tokenizer.decode(tokens);
                new_history = history + [(query, response)];
                if return_past_key_values:
                    yield response, new_history, None;
                else:
                    yield response, new_history;
        else:
            if history is None:
                history = []
            role = "user"
            input = self.build_chatglm3_input(tokenizer, query, history=history, role=role)
            history.append({"role": role, "content": query})
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                           False);
            tokens = [];
            while True:
                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
                if (cur == -1):
                    break;
                tokens.append(cur);
                response = tokenizer.decode(tokens);
                # only yield once the accumulated tokens decode to complete characters,
                # so a partially decoded multi-byte sequence ("�") is never emitted
                if response and response[-1] != "�":
                    response, new_history = self.process_chatglm3_response(response, history)
                    if return_past_key_values:
                        yield response, new_history, past_key_values
                    else:
                        yield response, new_history

    def set_adapter(self, name: str):
        fastllm_lib.set_adapter(self.model, str(name).encode())

    def disable_adapter(self):
        fastllm_lib.disable_adapter(self.model)

    def process_chatglm3_response(self, output, history):
        content = ""
        history = deepcopy(history)
        for response in output.split("<|assistant|>"):
            metadata, content = response.split("\n", maxsplit=1)
            if not metadata.strip():
                content = content.strip()
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                content = content.replace("[[训练时间]]", "2023年")
            else:
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                if history[0]["role"] == "system" and "tools" in history[0]:
                    content = "\n".join(content.split("\n")[1:-1])
                    def tool_call(**kwargs):
                        return kwargs
                    parameters = eval(content)
                    content = {"name": metadata.strip(), "parameters": parameters}
                else:
                    content = {"name": metadata.strip(), "content": content}
        return content, history

    def build_chatglm3_input(self, tokenizer, query, history=None, role="user"):
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(tokenizer.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(tokenizer.build_single_message(role, "", query))
        input_ids.extend([tokenizer.get_command("<|assistant|>")])
        return input_ids

    def response_batch(self, querys: List[str],
                       historys: List[List[Tuple[str, str]]] = None,
                       max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
                       **kwargs) -> List[str]:
        query_size = len(querys)
        if (not(historys)):
            historys = [[] for _ in range(query_size)]
        inputs = (ctypes.c_char_p * query_size)()
        for i, query in enumerate(querys):
            prompt = query if self.direct_query else self.get_prompt(query, historys[i])
            inputs[i] = ctypes.c_char_p(prompt.encode())
        outputs = fastllm_lib.response_batch_str_llm_model(self.model, inputs, query_size,
                                                           max_length, do_sample, top_p, top_k, temperature, repeat_penalty, False)
        responses = []
        for i in range(query_size):
            response = ctypes.string_at(outputs[i]).decode()
            responses.append(response)
            historys[i] = historys[i] + [(querys[i], response)]
        return responses, historys
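
A minimal usage sketch for the new response_batch path (illustrative only, not part of the diff; the .flm file name is a placeholder). It builds each prompt via get_prompt on the C++ side of the string API and reads the returned char** with ctypes.string_at, as the method above does:

# Illustrative usage of response_batch; the model file name is a placeholder.
from fastllm_pytools import llm

model = llm.model("baichuan2-13b-chat-int8.flm")
querys = ["Hello!", "What can you do?"]
responses, historys = model.response_batch(querys, max_length = 256)
for query, response in zip(querys, responses):
    print(query, "=>", response)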
    def chat_batch(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
        query_size = len(querys)
        if (not(historys)):
            historys = [[] for _ in range(query_size)]
        inputs = []
        inputs_len = []
        for i, query in enumerate(querys):
            prompt = query if self.direct_query else self.get_prompt(query, historys[i])
            input = tokenizer.encode(prompt);
            inputs.extend(input)
            inputs_len.append(len(input))
        outputs = fastllm_lib.response_batch_tokens_llm_model(self.model, query_size,
                                                              (ctypes.c_int * len(inputs_len))(*inputs_len),
                                                              (ctypes.c_int * len(inputs))(*inputs),
                                                              max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                              False)
        responses = []
        for i in range(query_size):
            response = ctypes.string_at(outputs[i]).decode()
            responses.append(response)
            historys[i] = historys[i] + [(querys[i], response)]
        return responses, historys
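
chat_batch, in contrast, tokenizes each prompt on the Python side and passes the flattened token ids plus per-query lengths to response_batch_tokens_llm_model. A hedged sketch of a call (the tokenizer repo and model path are placeholders, not part of the commit):

# Illustrative usage of chat_batch with an external HuggingFace tokenizer (placeholder names).
from transformers import AutoTokenizer
from fastllm_pytools import llm

tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-13B-Chat", trust_remote_code = True)
model = llm.model("baichuan2-13b-chat-int8.flm")

querys = ["Hello!", "Summarize fastllm in one sentence."]
historys = [[], []]   # one (query, response) history list per query
responses, historys = model.chat_batch(tokenizer, querys, historys, max_length = 256)
for response in responses:
    print(response)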
@@ -80,6 +80,8 @@ def tofile(exportPath,
        fo.write(struct.pack('i', 2))
        # 0.1 model info
        if model.config.model_type == "chatglm" and model.config.transformers_version == "4.30.2":
            model.config.model_type = "chatglm3"
        modelInfo = model.config.__dict__
        if model.generation_config is not None:
            modelInfo.update(model.generation_config.__dict__)