Commit 5a6195b7 authored by zhouxiang

Change the default repeat_penalty value; add a batch inference interface.

parent 2cec992f
@@ -258,7 +258,7 @@ class model:
     def response(self,
                  query: str,
                  history: List[Tuple[str, str]] = None,
-                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1) -> str:
+                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01) -> str:
         ret = "";
         for i in self.stream_response(query = query,
                                       history = history,
@@ -274,7 +274,7 @@ class model:
     def stream_response(self,
                         query: str,
                         history: List[Tuple[str, str]] = None,
-                        max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
+                        max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
                         one_by_one = True):
         prompt = query if self.direct_query else self.get_prompt(query, history);
         handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
@@ -309,7 +309,7 @@ class model:
     def stream_response_raw(self,
                             input_tokens: List[int],
-                            max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
+                            max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
                             one_by_one = True
                             ):
         handle = fastllm_lib.launch_response_llm_model(self.model, len(input_tokens),
@@ -335,7 +335,7 @@ class model:
             yield total_bytes
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
-             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1, **kwargs):
+             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, **kwargs):
         if self.model_type != "chatglm3":
             if (not(history)):
                 history = [];
@@ -376,7 +376,7 @@ class model:
         return response, new_history
     def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
-                    max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
+                    max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
                     return_past_key_values = False, **kwargs) -> str:
         if self.model_type != "chatglm3":
             if (not(history)):
@@ -463,9 +463,9 @@ class model:
             input_ids.extend([tokenizer.get_command("<|assistant|>")])
         return input_ids
-    def response_batch(self, querys: List[str],
+    def response_batch_raw(self, querys: List[str],
                        historys: List[List[Tuple[str, str]]] = None,
-                       max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
+                       max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
                        **kwargs) -> List[str]:
         query_size = len(querys)
         if (not(historys)):
@@ -486,8 +486,8 @@ class model:
         fastllm_lib.freeCharArray(outputs, query_size)
         return responses, historys
-    def chat_batch(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
-                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1, **kwargs):
+    def chat_batch_raw(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
+                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, **kwargs):
         query_size = len(querys)
         if (not(historys)):
             historys = [[] for _ in range(query_size)]
@@ -514,4 +514,77 @@ class model:
         fastllm_lib.freeCharArray(outputs, query_size)
         return responses, historys
+    def response_batch(self, querys: List[str],
+                       historys: List[List[Tuple[str, str]]] = None,
+                       max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
+                       **kwargs) -> List[str]:
+        query_size = len(querys)
+        if (not(historys)):
+            historys = [[] for _ in range(query_size)]
+        handles = []
+        for i, query in enumerate(querys):
+            prompt = query if self.direct_query else self.get_prompt(query, historys[i])
+            handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
+                        ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
+                        ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False))
+            handles.append(handle)
+        responses = []
+        for i, handle in enumerate(handles):
+            res = ""
+            ret = b''
+            fail_cnt = 0
+            while True:
+                # ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle);
+                ret_chararry = fastllm_lib.fetch_response_str_llm_model(self.model, handle);
+                ret += ctypes.string_at(ret_chararry)
+                fastllm_lib.freeChars(ret_chararry)
+                cur = ""
+                try:
+                    cur = ret.decode()
+                    ret = b''
+                except:
+                    fail_cnt += 1
+                    if (fail_cnt == 20):
+                        break
+                    else:
+                        continue
+                fail_cnt = 0
+                if (cur == "<flmeos>"):
+                    break;
+                res += cur
+            responses.append(res)
+            historys[i] = historys[i] + [(querys[i], res)]
+        return responses, historys
+    def chat_batch(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
+                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, **kwargs):
+        query_size = len(querys)
+        if (not(historys)):
+            historys = [[] for _ in range(query_size)]
+        handles = []
+        for i, query in enumerate(querys):
+            prompt = query if self.direct_query else self.get_prompt(query, historys[i])
+            input = tokenizer.encode(prompt);
+            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
+                        max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
+                        False);
+            handles.append(handle)
+        responses = []
+        for i, handle in enumerate(handles):
+            result = [];
+            while True:
+                cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
+                if (cur == -1):
+                    break;
+                result.append(cur);
+            response = tokenizer.decode(result);
+            responses.append(response)
+            historys[i] = historys[i] + [(querys[i], response)]
+        return responses, historys
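
For reference, here is a minimal usage sketch of the batch interface added by this commit. It assumes the fastllm Python bindings that define this model class (imported here as fastllm_pytools.llm); the model path, queries, and printed output are placeholders for illustration, not part of the commit.

    # Sketch: batch inference via the new response_batch method (placeholder model path and queries).
    from fastllm_pytools import llm

    model = llm.model("model.flm")   # placeholder path to a converted fastllm model file
    querys = ["Hello, who are you?", "Summarize fastllm in one sentence."]

    # response_batch launches one generation handle per query up front, then drains each
    # handle in turn; it returns the responses plus the updated per-query histories.
    responses, historys = model.response_batch(querys, max_length = 256, repeat_penalty = 1.01)
    for q, r in zip(querys, responses):
        print(q, "->", r)

chat_batch works the same way but takes a tokenizer and goes through token IDs (launch_response_llm_model / fetch_response_llm_model) rather than the string interface.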