Commit 52251123 authored by zhouxiang

Update dtk to the latest version, add support for the K100 card, and add an apiserver demo for reference.

parent 37f12b90
-# Baichuan-13B_fastllm
+# Baichuan-13B
 ## Paper
@@ -33,7 +33,7 @@
 The inference docker image can be pulled from 光源 (SourceFind) as follows:
 ```
-docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest
+docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk23.10.1-py38
 ```
 ### Start the container
@@ -43,7 +43,7 @@
 ```
 # <container_name>  custom container name
 # <project_path>    path to this project
-docker run -it --name=<container_name> -v <project_path>:/work -w /work --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --cap-add=SYS_PTRACE --shm-size=16G --group-add 39 image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest /bin/bash
+docker run -it --name=<container_name> -v <project_path>:/work -w /work --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --cap-add=SYS_PTRACE --shm-size=16G --group-add 39 image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk23.10.1-py38 /bin/bash
 ```
 ### Load the environment
@@ -51,7 +51,7 @@
 After entering the container, run the following command to load the runtime environment variables:
 ```
-source /opt/dtk-23.04/cuda/env.sh
+source /opt/dtk-23.10/cuda/env.sh
 ```
 ### Installation
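Before starting the demo below, it can help to confirm that the updated container actually sees the DCU devices. The short check below is an illustrative sketch, not part of this commit; it assumes it is run inside the container after sourcing the dtk environment script above.

```
# Minimal environment sanity check (illustrative sketch, not part of this commit).
import torch

print("torch version:", torch.__version__)             # expected: the 2.1.0 build shipped with dtk 23.10
print("device available:", torch.cuda.is_available())  # DCU devices are exposed through the CUDA API
print("device count:", torch.cuda.device_count())
```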
# coding=utf-8
# Implements an OpenAI-format API server (https://platform.openai.com/docs/api-reference/chat)
# for a fastllm model such as Baichuan2; adapted from the ChatGLM3-6B demo.
# Usage: python openai_api.py
# Visit http://localhost:8100/docs for the documentation.
import time
import json
import torch
import uvicorn
import argparse
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
#from transformers import AutoTokenizer, AutoModel
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
from fastllm_pytools import llm
@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    global device_map
    if torch.cuda.is_available():
        for device in device_map:
            with torch.cuda.device(device):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class Usage(BaseModel):
    prompt_tokens: int = None
    total_tokens: int = None
    completion_tokens: int = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]


class ChatCompletionResponse(BaseModel):
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    usage: Usage = None
@app.get("/v1/models", response_model=ModelList)
def list_models():
    global model_list
    # One ModelCard per served model name.
    model_cards = [ModelCard(id=name) for name in model_list]
    return ModelList(data=model_cards)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
def create_chat_completion(request: ChatCompletionRequest):
    if request.model not in model_list:
        raise HTTPException(status_code=400, detail="Invalid Model Name")
    global model
    id = "chatcmpl-A"
    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    if request.max_length is not None:
        max_length = request.max_length
    else:
        max_length = 1024
    if request.temperature is not None:
        temperature = request.temperature
    else:
        temperature = 0.1
    if request.top_p is not None:
        top_p = request.top_p
    else:
        top_p = 0.8

    prev_messages = request.messages[:-1]
    # print(prev_messages)
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query

    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                history.append([prev_messages[i].content, prev_messages[i+1].content])

    if request.stream:
        generate = predict(id=id, query=query, history=history, max_length=max_length, top_p=top_p, temperature=temperature, model_id=request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response = model.response(query=query, history=history, max_length=max_length, top_p=top_p, temperature=temperature)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )
    prompt_tokens = len(model.tokenizer_encode_string(query))
    completion_tokens = len(model.tokenizer_encode_string(response))
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    return ChatCompletionResponse(id=id, model=request.model, choices=[choice_data], object="chat.completion", usage=usage)
def predict(id: str, query: str, history: List[List[str]], model_id: str, max_length: int, top_p: float, temperature: float):
    global model
    creat_time = int(time.time())
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(id=id, created=creat_time, model=model_id, choices=[choice_data], object="chat.completion.chunk")
    # yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    # pydantic no longer supports dumps_kwargs (such as ensure_ascii) since 1.8.0, see https://github.com/THUDM/ChatGLM2-6B/issues/308
    yield json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)

    for new_response in model.stream_response(query=query, history=history, max_length=max_length, top_p=top_p, temperature=temperature):
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_response),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(id=id, created=creat_time, model=model_id, choices=[choice_data], object="chat.completion.chunk")
        # yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
        yield json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(id=id, created=creat_time, model=model_id, choices=[choice_data], object="chat.completion.chunk")
    # yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    yield json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)
    yield '[DONE]'
def args_parser():
    parser = argparse.ArgumentParser(description='baichuan2_chat_demo')
    parser.add_argument('-p', '--path', type=str, default="/model", help='path to the model files')
    parser.add_argument('-g', '--gpus', type=str, default="0", help='GPU cards to run on, e.g. "0,1"')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = args_parser()
    global model_list
    model_list = ["baichuan2-fastllm"]
    global device_map
    device_map = ["cuda:" + num for num in args.gpus.split(',')]
    llm.set_device_map(device_map)
    model = llm.model(args.path)
    uvicorn.run(app, host='127.0.0.1', port=8100)
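As a quick smoke test of the server above, a single non-streaming request is enough; the threaded benchmark client that follows exercises the streaming path. This is a minimal sketch, assuming the server is already running on 127.0.0.1:8100 and the openai==0.28 client pinned in the requirements is installed.

```
# Minimal smoke test for the API server above (sketch; assumes a running server and openai==0.28).
import openai

openai.api_base = "http://127.0.0.1:8100/v1"
openai.api_key = "none"  # the demo server does not check the key

resp = openai.ChatCompletion.create(
    model="baichuan2-fastllm",                        # must match model_list in the server
    messages=[{"role": "user", "content": "你好"}],
    stream=False,
)
print(resp.choices[0].message.content)
print(resp.usage)                                     # prompt/completion/total token counts from the server
```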
import openai
import time
import threading
import queue
from concurrent.futures import ThreadPoolExecutor, as_completed


def jls_extract_def(model, messages, temperature, max_length, stream, index):
    openai.api_base = "http://127.0.0.1:8100/v1"
    openai.api_key = "none"
    output_tokens = 0
    ret = ""
    t0 = time.time()
    result = openai.ChatCompletion.create(model=model, messages=messages, temperature=temperature, max_length=max_length, stream=stream)
    for chunk in result:
        # print(chunk)
        output_tokens += 1
        if hasattr(chunk.choices[0].delta, "content"):
            if (index == 0):
                print(chunk.choices[0].delta.content, end="", flush=True)
            ret += chunk.choices[0].delta.content
    t1 = time.time()
    # print("\ntoken/s: {:.2f}, output_tokens: {}".format(output_tokens/(t1-t0), output_tokens))
    result = output_tokens, ret, output_tokens/(t1-t0)
    return result
if __name__ == "__main__":
    prompt = "满江红全文"  # test prompt: "the full text of Man Jiang Hong"
    concurrencys = [1]
    temperature = 0.1
    max_length = 4096
    stream = True
    prompts = [prompt]
    model = "baichuan2-fastllm"
    messages = [{"role": "user", "content": "你好"}]
    pool = ThreadPoolExecutor(max_workers=32)
    for i in range(len(concurrencys)):
        cur_prompts = prompts * concurrencys[i]
        token_count = 0
        threads = []
        t0 = time.time()
        for index, prompt in enumerate(cur_prompts):
            messages[0]["content"] = prompt
            t = pool.submit(jls_extract_def, model, messages, temperature, max_length, stream, index)
            t.index = index
            threads.append(t)
        for future in as_completed(threads):
            result = future.result()
            print(future.index)
            print(result)
            print("\n")
            token_count += result[0]
        t1 = time.time()
        print("\n---------------------------------------------\n")
        print("\nconcurrency: {}".format(concurrencys[i]))
        print("\ntotal use: {:.2f}".format(t1 - t0))
        print("\ntoken/s: {:.2f}, token_count: {}".format(token_count/(t1 - t0), token_count))
        print("\n---------------------------------------------\n")
uvicorn==0.23.2
pydantic==2.5.1
fastapi==0.103.1
sse_starlette
openai==0.28
@@ -26,8 +26,9 @@ def create(model,
         exit(0);
     # 0.1 model info
-    if model.config.model_type == "chatglm" and model.config.transformers_version == "4.30.2":
-        model.config.model_type = "chatglm3"
+    # if model.config.model_type == "chatglm" and model.config.transformers_version == "4.30.2":
+    #     model.config.model_type = "chatglm3"
+    #     print("model.config.model_type: chatglm3!")
     modelInfo = model.config.__dict__
     if model.generation_config is not None:
         modelInfo.update(model.generation_config.__dict__)
@@ -50,6 +51,12 @@ def create(model,
     if modelInfo["chat_format"] == "chatml":
         modelInfo["im_end_id"] = tokenizer.im_end_id
         modelInfo["im_start_id"] = tokenizer.im_start_id
+    if (modelInfo["model_type"] == "chatglm" and hasattr(tokenizer, "build_chat_input")):
+        # chatglm3
+        modelInfo["pre_prompt"] = "";
+        modelInfo["user_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.get_command("<|user|>")) + ">\n");
+        modelInfo["bot_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.get_command("<|assistant|>")) + ">");
+        modelInfo["history_sep"] = "";
     weight_type_dict = {};
@@ -4,10 +4,13 @@ import os;
 import threading
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any;
 from copy import deepcopy
+import json
 import platform
 if platform.system() == 'Windows':
-    fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "fastllm_tools.dll"))
+    fastllm_lib = ctypes.CDLL(os.path.join(os.path.split(os.path.realpath(__file__))[0], "fastllm_tools.dll"), winmode=0)
+elif platform.system() == 'Darwin':
+    fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "libfastllm_tools.dylib"))
 else:
     fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "libfastllm_tools.so"))
@@ -22,7 +25,8 @@ fastllm_lib.token_encode_string.restype = ctypes.c_int
 fastllm_lib.launch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_void_p,
                                                   ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
-                                                  ctypes.c_float, ctypes.c_float, ctypes.c_bool]
+                                                  ctypes.c_float, ctypes.c_float, ctypes.c_bool,
+                                                  ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
 fastllm_lib.launch_response_llm_model.restype = ctypes.c_int
 fastllm_lib.fetch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
@@ -39,7 +43,8 @@ fastllm_lib.response_str_llm_model.restype = ctypes.POINTER(ctypes.c_char)
 fastllm_lib.launch_response_str_llm_model.argtype = [ctypes.c_int, ctypes.c_char_p,
                                                      ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
-                                                     ctypes.c_float, ctypes.c_float, ctypes.c_bool]
+                                                     ctypes.c_float, ctypes.c_float, ctypes.c_bool,
+                                                     ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
 fastllm_lib.launch_response_str_llm_model.restype = ctypes.c_int
 fastllm_lib.fetch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
@@ -59,7 +64,6 @@ fastllm_lib.add_tokenizer_word_llm_model.argtype = [ctypes.c_int, ctypes.c_char_
 fastllm_lib.set_device_map.argtype = [ctypes.c_int, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
 fastllm_lib.get_llm_model_type.argtype = [ctypes.c_int]
-# fastllm_lib.get_llm_model_type.restype = ctypes.c_char_p
 fastllm_lib.get_llm_model_type.restype = ctypes.POINTER(ctypes.c_char)
 fastllm_lib.response_batch_str_llm_model.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int,
@@ -233,19 +237,29 @@ class model:
                 break
         return buffer_bytes[:result_len]
+    def stop_token_ctypes(self, stop_token_ids):
+        if stop_token_ids is None:
+            return 0, None
+        else:
+            return ctypes.c_int(len(stop_token_ids)), (ctypes.c_int * len(stop_token_ids))(*stop_token_ids)
     def response_logits(self,
                         query: str,
                         history: List[Tuple[str, str]] = None,
-                        tokenizer = None) -> str:
+                        tokenizer = None,
+                        stop_token_ids: List[int] = None,
+                        ) -> str:
         prompt = query if self.direct_query else self.get_prompt(query, history);
+        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
         if (tokenizer == None):
             handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                                ctypes.c_int(1), ctypes.c_bool(False), ctypes.c_float(1), ctypes.c_int(1),
-                                                               ctypes.c_float(1), ctypes.c_float(1), ctypes.c_bool(True));
+                                                               ctypes.c_float(1), ctypes.c_float(1), ctypes.c_bool(True),
+                                                               stop_token_len, stop_token_list);
         else:
             input = tokenizer.encode(prompt);
             handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
-                                                           1, False, 1, 1, 1, 1, True);
+                                                           1, False, 1, 1, 1, 1, True, stop_token_len, stop_token_list);
         vocab_size = fastllm_lib.get_tokenizer_vocab_size(self.model);
         logits = list(range(vocab_size))
         array = (ctypes.c_float * (vocab_size * 4))(*logits);
@@ -258,7 +272,8 @@ class model:
     def response(self,
                  query: str,
                  history: List[Tuple[str, str]] = None,
-                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01) -> str:
+                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
+                 stop_token_ids: List[int] = None) -> str:
         ret = "";
         for i in self.stream_response(query = query,
                                       history = history,
@@ -267,7 +282,8 @@ class model:
                                       top_p = top_p, top_k = top_k,
                                       temperature = temperature,
                                       repeat_penalty = repeat_penalty,
-                                      one_by_one = True):
+                                      one_by_one = True,
+                                      stop_token_ids = stop_token_ids):
             ret += i;
         return ret;
@@ -275,11 +291,13 @@ class model:
                         query: str,
                         history: List[Tuple[str, str]] = None,
                         max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
-                        one_by_one = True):
+                        one_by_one = True, stop_token_ids: List[int] = None):
         prompt = query if self.direct_query else self.get_prompt(query, history);
+        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
         handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                            ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
-                                                           ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False));
+                                                           ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
+                                                           stop_token_len, stop_token_list);
         res = "";
         ret = b'';
         fail_cnt = 0;
@@ -310,12 +328,15 @@ class model:
     def stream_response_raw(self,
                             input_tokens: List[int],
                             max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
-                            one_by_one = True
+                            one_by_one = True,
+                            stop_token_ids: List[int] = None
                             ):
+        stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
         handle = fastllm_lib.launch_response_llm_model(self.model, len(input_tokens),
                                                        (ctypes.c_int * len(input_tokens))(*input_tokens),
                                                        ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
-                                                       ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False))
+                                                       ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
+                                                       stop_token_len, stop_token_list)
         # A long-tail character may need several tokens before it can be produced, so only bytes are returned
         # and the string.decode policy is left to the caller; this also makes it easier to count output tokens
         # and to control decoding when the utf-8 sequence is still incomplete.
@@ -335,15 +356,16 @@ class model:
             yield total_bytes
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
-             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, **kwargs):
+             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, stop_token_ids: List[int] = None, **kwargs):
         if self.model_type != "chatglm3":
             if (not(history)):
                 history = [];
             prompt = query if self.direct_query else self.get_prompt(query, history);
             input = tokenizer.encode(prompt);
+            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
             handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                            max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
-                                                           False);
+                                                           False, stop_token_len, stop_token_list);
             result = [];
             while True:
@@ -359,11 +381,11 @@ class model:
                 history = []
             role = "user"
             input = self.build_chatglm3_input(tokenizer, query, history=history, role=role)
             history.append({"role": role, "content": query})
+            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
             handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                            max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
-                                                           False);
+                                                           False, stop_token_len, stop_token_list);
             tokens = [];
             while True:
                 cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
@@ -377,15 +399,16 @@ class model:
     def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
                     max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
-                    return_past_key_values = False, **kwargs) -> str:
+                    return_past_key_values = False, stop_token_ids: List[int] = None, **kwargs) -> str:
         if self.model_type != "chatglm3":
             if (not(history)):
                 history = [];
             prompt = query if self.direct_query else self.get_prompt(query, history);
             input = tokenizer.encode(prompt);
+            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
             handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                            max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                            False);
-                                                           False);
+                                                           False, stop_token_len, stop_token_list);
             tokens = [];
             while True:
                 cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
@@ -404,10 +427,11 @@ class model:
             role = "user"
             input = self.build_chatglm3_input(tokenizer, query, history=history, role=role)
             history.append({"role": role, "content": query})
+            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
             handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                            max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
-                                                           False);
+                                                           False, stop_token_len, stop_token_list);
             tokens = [];
             while True:
                 cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
@@ -517,16 +541,18 @@ class model:
     def response_batch(self, querys: List[str],
                        historys: List[List[Tuple[str, str]]] = None,
                        max_length: int = 1024, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01,
-                       **kwargs) -> List[str]:
+                       stop_token_ids: List[int] = None, **kwargs) -> List[str]:
         query_size = len(querys)
         if (not(historys)):
             historys = [[] for _ in range(query_size)]
         handles = []
         for i, query in enumerate(querys):
             prompt = query if self.direct_query else self.get_prompt(query, historys[i])
+            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
             handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
                                                                ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
-                                                               ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False))
+                                                               ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
+                                                               stop_token_len, stop_token_list)
             handles.append(handle)
         responses = []
@@ -560,7 +586,7 @@ class model:
     def chat_batch(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]] = None, max_length: int = 1024,
-                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, **kwargs):
+                   do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.01, stop_token_ids: List[int] = None, **kwargs):
         query_size = len(querys)
         if (not(historys)):
             historys = [[] for _ in range(query_size)]
@@ -569,9 +595,10 @@ class model:
         for i, query in enumerate(querys):
             prompt = query if self.direct_query else self.get_prompt(query, historys[i])
             input = tokenizer.encode(prompt);
+            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
             handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                            max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
-                                                           False);
+                                                           False, stop_token_len, stop_token_list);
             handles.append(handle)
         responses = []
@@ -588,3 +615,6 @@ class model:
         return responses, historys
+    def release_memory(self):
+        fastllm_lib.release_memory(self.model)
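For reference, a minimal sketch of how the new stop_token_ids argument and release_memory() helper added to llm.py above could be used from Python. The model path and the stop token id below are placeholders; the actual ids depend on the model's tokenizer.

```
# Illustrative use of the stop_token_ids argument and release_memory() introduced in this change
# (the path and the token id 2 are placeholders, not values from the commit).
from fastllm_pytools import llm

model = llm.model("/path/to/model.flm")   # hypothetical path to an exported fastllm model

# Generation stops early if one of these token ids is produced.
reply = model.response("你好", history=[], max_length=512, stop_token_ids=[2])
print(reply)

# Streaming works the same way; stop_token_ids is forwarded to the C library.
for piece in model.stream_response("满江红全文", history=[], stop_token_ids=[2]):
    print(piece, end="", flush=True)

model.release_memory()                    # frees the model's memory when done
```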
@@ -80,8 +80,8 @@ def tofile(exportPath,
     fo.write(struct.pack('i', 2))
     # 0.1 model info
-    if model.config.model_type == "chatglm" and model.config.transformers_version == "4.30.2":
-        model.config.model_type = "chatglm3"
+    #if model.config.model_type == "chatglm" and model.config.transformers_version == "4.30.2":
+    #    model.config.model_type = "chatglm3"
     modelInfo = model.config.__dict__
     if model.generation_config is not None:
         modelInfo.update(model.generation_config.__dict__)
@@ -114,6 +114,13 @@ def tofile(exportPath,
     if modelInfo["chat_format"] == "chatml":
         modelInfo["im_end_id"] = tokenizer.im_end_id
         modelInfo["im_start_id"] = tokenizer.im_start_id
+    if (modelInfo["model_type"] == "chatglm" and hasattr(tokenizer, "build_chat_input")):
+        print("chatglm3")
+        # chatglm3
+        modelInfo["pre_prompt"] = "";
+        modelInfo["user_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.get_command("<|user|>")) + ">\n");
+        modelInfo["bot_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.get_command("<|assistant|>")) + ">");
+        modelInfo["history_sep"] = "";
     modelInfo["tokenizer_use_score"] = "1"  # tokenization with scores
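The pre_prompt/user_role/bot_role/history_sep fields exported above are consumed by fastllm when it assembles prompts on the C++ side. The sketch below only approximates that assembly for illustration; the numeric token ids stand in for whatever tokenizer.get_command("<|user|>") and tokenizer.get_command("<|assistant|>") return and are placeholders, not values taken from the commit.

```
# Rough illustration of how the exported chatglm3 template fields could be combined into a prompt
# (approximation only; the real assembly happens inside fastllm's C++ code, and the ids are placeholders).
def build_prompt(query, history, pre_prompt="", user_role="<FLM_FIX_TOKEN_64795>\n",
                 bot_role="<FLM_FIX_TOKEN_64796>", history_sep=""):
    prompt = pre_prompt
    for user_turn, bot_turn in history:
        prompt += user_role + user_turn + bot_role + bot_turn + history_sep
    return prompt + user_role + query + bot_role

print(build_prompt("你好", history=[]))
```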