Commit 5a6195b7 authored by zhouxiang

Change the default repeat_penalty value and add batch inference interfaces

parent 2cec992f
@@ -258,7 +258,7 @@ class model:
def response(self,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1) -> str:
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01) -> str:
ret = "";
for i in self.stream_response(query = query,
history = history,
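This hunk lowers the default `repeat_penalty` of `response()` from 1.1 to 1.01, i.e. close to the neutral value 1.0 at which repetition is not penalised at all. Below is a minimal usage sketch; it assumes the model is loaded with `llm.model(...)` from `fastllm_pytools`, as in the upstream fastllm examples, and the `.flm` path is only a placeholder.

```python
from fastllm_pytools import llm  # assumed import path, as in upstream fastllm

# Placeholder model file; export one with the fastllm conversion tools first.
model = llm.model("chatglm2-6b-int8.flm")

# Uses the new default repeat_penalty = 1.01.
print(model.response("Hello, who are you?"))

# Callers that relied on the old, stronger penalty can still pass it explicitly.
print(model.response("Hello, who are you?", repeat_penalty = 1.1))
```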
@@ -274,7 +274,7 @@ class model:
def stream_response(self,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
one_by_one = True):
prompt = query if self.direct_query else self.get_prompt(query, history);
handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
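`stream_response()` gets the same default change. It is a generator that yields the reply piece by piece, driving the same `launch_response_str_llm_model` / `fetch_response_str_llm_model` pair that the new `response_batch()` below uses per query. A short sketch, reusing the `model` object from the previous snippet:

```python
# stream_response yields decoded text fragments as they are produced,
# so the reply can be printed incrementally.
for piece in model.stream_response("Write a haiku about the sea."):
    print(piece, end = "", flush = True)
print()
```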
@@ -309,7 +309,7 @@ class model:
def stream_response_raw(self,
input_tokens: List[int],
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
one_by_one = True
):
handle = fastllm_lib.launch_response_llm_model(self.model, len(input_tokens),
@@ -335,7 +335,7 @@ class model:
yield total_bytes
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1, **kwargs):
do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, **kwargs):
if self.model_type != "chatglm3":
if (not(history)):
history = [];
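`chat()` keeps the ChatGLM-style `(response, history)` interface and takes a Hugging Face tokenizer; only its default `repeat_penalty` changes here. A hedged sketch, assuming the `.flm` model was converted from a ChatGLM checkpoint and the matching tokenizer is available (the model id is a placeholder):

```python
from transformers import AutoTokenizer

# Tokenizer of the checkpoint the .flm file was converted from (placeholder id).
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code = True)

history = []
response, history = model.chat(tokenizer, "What is the capital of France?", history = history)
print(response)

# The returned history is fed back in to continue the conversation.
response, history = model.chat(tokenizer, "And roughly how many people live there?", history = history)
print(response)
```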
@@ -376,7 +376,7 @@ class model:
return response, new_history
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
return_past_key_values = False, **kwargs) -> str:
if self.model_type != "chatglm3":
if (not(history)):
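`stream_chat()` is the streaming counterpart. Judging from the ChatGLM API it mirrors, each yield should be a `(response, history)` pair with the reply accumulated so far; this is an assumption, since only the signature appears in this hunk.

```python
# Assumes each yielded `response` is the full reply generated so far,
# so the line is overwritten in place as generation progresses.
for response, new_history in model.stream_chat(tokenizer, "Tell me a short story.", history = []):
    print("\r" + response, end = "", flush = True)
print()
```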
@@ -463,9 +463,9 @@ class model:
input_ids.extend([tokenizer.get_command("<|assistant|>")])
return input_ids
def response_batch(self, querys: List[str],
def response_batch_raw(self, querys: List[str],
historys: List[List[Tuple[str, str]]] = None,
max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1,
max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
**kwargs) -> List[str]:
query_size = len(querys)
if (not(historys)):
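The existing packed-batch helper is renamed from `response_batch()` to `response_batch_raw()`; apart from the new `repeat_penalty` default its signature is unchanged, so callers only need to update the method name. A sketch, again reusing the `model` from the first snippet:

```python
querys = ["Introduce yourself in one sentence.", "What is 2 + 2?"]

# Same call pattern as the old response_batch, under its new name.
responses, historys = model.response_batch_raw(querys, max_length = 256)
for q, r in zip(querys, responses):
    print(q, "->", r)
```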
@@ -486,32 +486,105 @@ class model:
fastllm_lib.freeCharArray(outputs, query_size)
return responses, historys
def chat_batch_raw(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, **kwargs):
query_size = len(querys)
if (not(historys)):
historys = [[] for _ in range(query_size)]
inputs = []
inputs_len = []
for i, query in enumerate(querys):
prompt = query if self.direct_query else self.get_prompt(query, historys[i])
input = tokenizer.encode(prompt);
inputs.extend(input)
inputs_len.append(len(input))
outputs = fastllm_lib.response_batch_tokens_llm_model(self.model, query_size,
(ctypes.c_int * len(inputs_len))(*inputs_len),
(ctypes.c_int * len(inputs))(*inputs),
max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
False)
responses = []
for i in range(query_size):
response = ctypes.string_at(outputs[i]).decode()
responses.append(response)
historys[i] = historys[i] + [(querys[i], response)]
fastllm_lib.freeCharArray(outputs, query_size)
return responses, historys
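The new `chat_batch_raw()` encodes every prompt with the supplied tokenizer, packs all token ids into one flat array plus a per-query length array, and hands the whole batch to `fastllm_lib.response_batch_tokens_llm_model` in a single call; the length array is what lets the native side split the flat buffer back into individual prompts. A usage sketch with the tokenizer from the `chat()` example above:

```python
querys = ["Summarize the plot of Hamlet in one sentence.",
          "List three prime numbers greater than 100."]
historys = [[], []]  # one (query, response) history list per query

responses, historys = model.chat_batch_raw(tokenizer, querys, historys = historys,
                                           max_length = 512)
for r in responses:
    print(r)
```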
def response_batch(self, querys: List[str],
historys: List[List[Tuple[str, str]]] = None,
max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
**kwargs) -> List[str]:
query_size = len(querys)
if (not(historys)):
historys = [[] for _ in range(query_size)]
handles = []
for i, query in enumerate(querys):
prompt = query if self.direct_query else self.get_prompt(query, historys[i])
handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False))
handles.append(handle)
responses = []
for i, handle in enumerate(handles):
res = ""
ret = b''
fail_cnt = 0
while True:
# ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle);
ret_chararry = fastllm_lib.fetch_response_str_llm_model(self.model, handle);
ret += ctypes.string_at(ret_chararry)
fastllm_lib.freeChars(ret_chararry)
cur = ""
try:
cur = ret.decode()
ret = b''
except:
fail_cnt += 1
if (fail_cnt == 20):
break
else:
continue
fail_cnt = 0
if (cur == "<flmeos>"):
break;
res += cur
responses.append(res)
historys[i] = historys[i] + [(querys[i], res)]
return responses, historys
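The reworked `response_batch()` takes a different route: it launches one generation handle per query with `launch_response_str_llm_model` (with `one_by_one` disabled), then drains each handle through `fetch_response_str_llm_model`, accumulating raw bytes and retrying the UTF-8 decode when a multi-byte character is split across chunks; the `<flmeos>` marker ends a reply. Launching every handle before fetching means the native library can process the prompts together rather than one by one. A usage sketch:

```python
querys = ["Explain what a hash map is.", "Explain what a linked list is."]

# All handles are launched up front, then each reply is collected in turn.
responses, historys = model.response_batch(querys, max_length = 512)
for r in responses:
    print(r)
```

Callers that prefer the single packed-batch call can keep using `response_batch_raw()`.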
def chat_batch(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.1, **kwargs):
do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, **kwargs):
query_size = len(querys)
if (not(historys)):
historys = [[] for _ in range(query_size)]
inputs = []
inputs_len = []
for i, query in enumerate(querys):
prompt = query if self.direct_query else self.get_prompt(query, historys[i])
input = tokenizer.encode(prompt);
inputs.extend(input)
inputs_len.append(len(input))
outputs = fastllm_lib.response_batch_tokens_llm_model(self.model, query_size,
(ctypes.c_int * len(inputs_len))(*inputs_len),
(ctypes.c_int * len(inputs))(*inputs),
max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
False)
responses = []
for i in range(query_size):
response = ctypes.string_at(outputs[i]).decode()
responses.append(response)
historys[i] = historys[i] + [(querys[i], response)]
fastllm_lib.freeCharArray(outputs, query_size)
return responses, historys
handles = []
for i, query in enumerate(querys):
prompt = query if self.direct_query else self.get_prompt(query, historys[i])
input = tokenizer.encode(prompt);
handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
False);
handles.append(handle)
responses = []
for i, handle in enumerate(handles):
result = [];
while True:
cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
if (cur == -1):
break;
result.append(cur);
response = tokenizer.decode(result);
responses.append(response)
historys[i] = historys[i] + [(querys[i], response)]
return responses, historys
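`chat_batch()` is rewritten along the same lines: it tokenizes each prompt, launches a token-level handle per query via `launch_response_llm_model`, polls `fetch_response_llm_model` until it returns -1, and decodes the collected ids with the tokenizer; the previous packed-batch implementation lives on as `chat_batch_raw()`. A final usage sketch:

```python
querys = ["Give me three uses for a paper clip.",
          "Name two rivers in South America."]

responses, historys = model.chat_batch(tokenizer, querys, max_length = 512)
for q, r in zip(querys, responses):
    print("Q:", q)
    print("A:", r)
```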