Commit f2ad77b8 authored by change

Inference branch

parent ea4c6c26
......@@ -34,14 +34,23 @@ ChatGLM3-6B同样采用Transformer模型结构:
### Docker(方式一)
Running with Docker is recommended; a Docker image is provided for pulling:
```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10-fixpy
```
Enter the Docker container and install the dependencies that the image does not include:
```bash
docker run -dit --network=host --name=chatglm3 --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G -v /opt/hyhal/:/opt/hyhal/:ro --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
# Adjust the -v directory mappings according to your own machine
docker run -dit --network=host --name=chatglm3 --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G -v /opt/hyhal/:/opt/hyhal/:ro -v /物理机目录:/容器内目录 --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10-fixpy
# Enter the running container
docker exec -it chatglm3 /bin/bash
# Replace/reinstall environment packages
## When this command prompts, answer Y
apt remove python3-blinker
## sentence-transformers has dependency conflicts with transformers, numpy, and other libraries, so installation requires two steps
pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
pip install transformers==4.40.0 numpy==1.24.3 nltk==3.9.1 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
cd finetune_demo
pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
```
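After installing the dependencies, a quick sanity check inside the container can confirm that the DCU device is visible to PyTorch (a minimal sketch; it assumes the DCU PyTorch build in the image exposes devices through the `torch.cuda` interface):
```bash
# Check that the DCU PyTorch build can see at least one device
python -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())"
```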
......@@ -54,9 +63,9 @@ conda create -n chatglm python=3.10
2. The toolkits and deep-learning libraries required by this project for DCU GPUs can all be downloaded and installed from the [光合](https://developer.sourcefind.cn/tool/) developer community.
- [DTK 24.04.1](https://cancon.hpccube.com:65024/1/main/DTK-24.04.1)
- [Pytorch 2.1.0](https://cancon.hpccube.com:65024/4/main/pytorch)
- [Deepspeed 0.12.3](https://cancon.hpccube.com:65024/directlink/4/deepspeed/DAS1.1/deepspeed-0.12.3+gita724046.abi1.dtk2404.torch2.1.0-cp310-cp310-manylinux_2_31_x86_64.whl)
- [DTK 25.04](https://download.sourcefind.cn:65024/1/main/latest)
- [Pytorch 2.4.1](https://download.sourcefind.cn:65024/4/main/pytorch/DAS1.5)
- [Deepspeed 0.14.2](https://download.sourcefind.cn:65024/4/main/deepspeed/DAS1.5)
Tips: the versions of the DTK driver, Python, DeepSpeed, and other tools listed above must strictly correspond to one another.
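To double-check that the installed versions line up, you can inspect them directly (a minimal sketch; the DTK version file location is an assumption and may differ on your system):
```bash
# Show the installed PyTorch and DeepSpeed versions
python -c "import torch; print(torch.__version__)"
pip show deepspeed | grep Version
# The DTK version is typically recorded under its install prefix (path may vary)
cat /opt/dtk/VERSION 2>/dev/null || echo "DTK version file not found at /opt/dtk/VERSION"
```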
......@@ -86,20 +95,8 @@ site-packages/transformers/utils/versions.py 文件
```
## Dataset
Single-turn dialogue fine-tuning uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertisement generation) dataset as an example to illustrate how to use the code. The task is to generate an advertisement text (summary) from the input (content). Download links:
- [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1)
Download the preprocessed ADGEN dataset and place the extracted AdvertiseGen directory under [finetune_demo/data](./finetune_demo/data). The dataset directory structure is as follows:
```
AdvertiseGen
├── dev.json
└── train.json
```
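Each line of train.json and dev.json is a standalone JSON object with a content field (structured keywords) and a summary field (the target advertisement text). One record can be inspected like this (the record in the comment is illustrative, not necessarily the actual first line):
```bash
# Print the first record of the training set
head -n 1 AdvertiseGen/train.json
# Expected shape (illustrative example):
# {"content": "类型#裙*颜色#蓝色*风格#清新", "summary": "这款蓝色清新风格的连衣裙..."}
```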
Convert the dataset into the format required by the model as follows:
```bash
cd finetune_demo
python process.py
```
### Model Download
......@@ -111,73 +108,79 @@ python process.py
## Training
### SFT Fine-tuning
#### Single-turn Dialogue Fine-tuning
```bash
cd ./finetune_demo
bash sft.sh
```
Note: set the model path and dataset path according to your needs; parameters such as batch size and learning rate are in ./finetune_demo/configs/sft.yaml.
#### Inference Verification
For fine-tuning in the input-output format, `sft_inf.sh` can be used for basic inference verification.
After fine-tuning completes, a number of `checkpoint-*` folders appear under the `output` directory, corresponding to different points in training. Select the fine-tuned weights of the last checkpoint and load them for inference.
Note: at this point, copy the original `tokenizer_config.json` and `tokenization_chatglm.py` files downloaded from Hugging Face into the checkpoint folder to be tested, e.g. ./finetune_demo/output/checkpoint-3000/
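A minimal sketch of that copy step, assuming the original weights were downloaded to /models/chatglm3/chatglm3-6b and the checkpoint to be tested is checkpoint-3000 (adjust both paths to your setup):
```bash
# Copy the original tokenizer files into the checkpoint directory to be tested (example paths)
cp /models/chatglm3/chatglm3-6b/tokenizer_config.json \
   /models/chatglm3/chatglm3-6b/tokenization_chatglm.py \
   ./finetune_demo/output/checkpoint-3000/
```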
```bash
cd ./finetune_demo
bash sft_inf.sh
## Inference
All of the inference scripts below require the path of a model that has already been downloaded locally; if no path is specified, the model is downloaded automatically from Hugging Face. Specify the model path as follows:
### cli_demo
```
### LoRA Fine-tuning
#### Single-turn Dialogue Fine-tuning
```bash
cd ./finetune_demo
bash lora.sh
# Modify the paths as needed
export MODEL_PATH=/models/chatglm3/chatglm3-6b
export TOKENIZER_PATH=/models/chatglm3/chatglm3-6b
# Run inference
cd basic_demo
export HIP_VISIBLE_DEVICES=0
export PYTHONWARNINGS="ignore"
python cli_demo.py
```
<div align="center">
<img src="./media/cli_demo.png">
</div>
Note: set the model path and dataset path according to your needs; parameters such as batch size and learning rate are in ./finetune_demo/configs/lora.yaml.
#### Inference Verification
After fine-tuning completes, a number of `checkpoint-*` folders appear under the `output` directory, corresponding to different points in training. Select the fine-tuned weights of the last checkpoint and load them for inference.
Note: checkpoints produced by LoRA fine-tuning do not require copying the original GLM3 tokenizer files into their directory.
```bash
cd ./finetune_demo
bash lora_inf.sh
### web_demo_gradio
```
# Modify the paths as needed
export MODEL_PATH=/models/chatglm3/chatglm3-6b
export TOKENIZER_PATH=/models/chatglm3/chatglm3-6b
# Run inference
cd basic_demo
export HIP_VISIBLE_DEVICES=0
export PYTHONWARNINGS="ignore"
python web_demo_gradio.py
```
## Result
### SFT Fine-tuning
#### Single-turn Dialogue Fine-tuning Inference Results
<div align="center">
<img src="./media/result1.jpg">
<img src="./media/gradio.png">
</div>
### LoRA Fine-tuning
#### Single-turn Dialogue Fine-tuning Inference Results
### web_demo_streamlit
```
# Modify the paths as needed
export MODEL_PATH=/models/chatglm3/chatglm3-6b
export TOKENIZER_PATH=/models/chatglm3/chatglm3-6b
# Run inference
cd basic_demo
export HIP_VISIBLE_DEVICES=0
export PYTHONWARNINGS="ignore"
streamlit run web_demo_streamlit.py
```
<div align="center">
<img src="./media/result2.jpg">
<img src="./media/streamlit.png">
</div>
### api_server
#### Server Startup Command
```
# Modify the paths as needed
export MODEL_PATH=/models/chatglm3/chatglm3-6b
export TOKENIZER_PATH=/models/chatglm3/chatglm3-6b
export EMBEDDING_PATH=/home/model/BAAI/bge-m3/
# Run inference (this example changed the port number at line 536 of api_server.py; the port elsewhere in the project was not updated and remains 8000)
python api_server.py
```
<div align="center">
<img src="./media/server.png">
</div>
#### API Call Example
```
curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{\"model\": \"chatglm3-6b\", \"messages\": [{\"role\": \"system\", \"content\": \"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\"}, {\"role\": \"user\", \"content\": \"你好,给我讲一个故事,大概100字\"}], \"stream\": false, \"max_tokens\": 100, \"temperature\": 0.8, \"top_p\": 0.8}"
```
<div align="center">
<img src="./media/client.png">
</div>
### Accuracy
......
import os
import platform
from typing import Optional, Union
from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
......@@ -9,32 +8,6 @@ TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
os_name = platform.system()
clear_command = "cls" if os_name == "Windows" else "clear"
stop_stream = False
welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
def build_prompt(history):
prompt = welcome_prompt
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM3-6B:{response}"
return prompt
def process_model_outputs(outputs, tokenizer):
responses = []
for output in outputs:
response = tokenizer.decode(output, skip_special_tokens=True)
response = response.replace("[gMASK]sop", "").strip()
responses.append(response)
return responses
def batch(
model,
tokenizer,
......
......@@ -35,10 +35,13 @@ from transformers import (
TextIteratorStreamer
)
import socket
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
# MODEL_PATH = 'chatglm3-6b'
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
......@@ -109,10 +112,12 @@ def parse_text(text):
text = "".join(lines)
return text
def predict(history, max_length, top_p, temperature):
def predict(history, max_length, top_p, temperature, system_prompt):
stop = StopOnTokens()
messages = []
if system_prompt != "":
messages.append({"role": "system", "content": system_prompt})
for idx, (user_msg, model_msg) in enumerate(history):
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
......@@ -147,31 +152,32 @@ def predict(history, max_length, top_p, temperature):
yield history
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")
chatbot = gr.Chatbot()
with gr.Blocks(title="ChatGLM") as demo:
gr.Markdown("## ChatGLM3-6B")
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(layout="panel")
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
user_input = gr.Textbox(show_label=False, placeholder="Input to chat...", lines=3, container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit")
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
max_length = gr.Slider(0, 32768, value=16384, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
gr.HTML("""<span>System Prompt</span>""")
system_prompt = gr.Textbox(show_label=False, placeholder="System Prompt", lines=6, container=False)
def user(query, history):
return "", history + [[parse_text(query), ""]]
submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
predict, [chatbot, max_length, top_p, temperature], chatbot
predict, [chatbot, max_length, top_p, temperature, system_prompt], chatbot
)
emptyBtn.click(lambda: None, None, chatbot, queue=False)
demo.queue()
demo.launch(server_name="127.0.0.1", server_port=7870, inbrowser=True, share=False)
demo.launch(server_name=socket.gethostbyname(socket.gethostname()), server_port=7870, inbrowser=True, share=False)
......@@ -33,7 +33,7 @@ import time
import tiktoken
import torch
import uvicorn
import json
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
......@@ -44,7 +44,7 @@ from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModel
from utils import process_response, generate_chatglm3, generate_stream_chatglm3
from sentence_transformers import SentenceTransformer
from tools.schema import tool_class, tool_def, tool_param_start_with, tool_define_param_name
from sse_starlette.sse import EventSourceResponse
# Set up limit request time
......@@ -146,7 +146,6 @@ class ChatCompletionRequest(BaseModel):
stream: Optional[bool] = False
tools: Optional[Union[dict, List[dict]]] = None
repetition_penalty: Optional[float] = 1.1
agent: Optional[bool] = False
class ChatCompletionResponseChoice(BaseModel):
......@@ -238,17 +237,16 @@ async def create_chat_completion(request: ChatCompletionRequest):
echo=False,
stream=request.stream,
repetition_penalty=request.repetition_penalty,
agent=request.agent
tools=request.tools,
)
logger.debug(f"==== request ====\n{gen_params}")
gen_params["tools"] = tool_def if gen_params["agent"] else []
if request.stream:
# Use stream mode to read the first few characters; if it is not a function call, stream the output directly
predict_stream_generator = predict_stream(request.model, gen_params)
output = next(predict_stream_generator)
if not contains_custom_function(output, gen_params["tools"]):
if not contains_custom_function(output):
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
# Obtain the result directly at one time and determine whether tools needs to be called.
......@@ -269,24 +267,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
In this demo, we did not register any tools.
You can use the tools that have been implemented in our `tools_using_demo` and implement your own streaming tool implementation here.
Similar to the following method:
function_args = json.loads(function_call.arguments)
tool_response = dispatch_tool(tool_name: str, tool_params: dict)
"""
if tool_param_start_with in output:
tool = tool_class.get(function_call.name)
if tool:
this_tool_define_param_name = tool_define_param_name.get(function_call.name)
if this_tool_define_param_name:
tool_param = json.loads(function_call.arguments).get(this_tool_define_param_name)
if tool().parameter_validation(tool_param):
observation = str(tool().run(tool_param))
tool_response = observation
else:
tool_response = "Tool parameter values error, please tell the user about this situation."
else:
tool_response = "Tool parameter is not defined in tools schema, please tell the user about this situation."
else:
tool_response = "No available tools found, please tell the user about this situation."
else:
tool_response = "Tool parameter content error, please tell the user about this situation."
tool_response = ""
if not gen_params.get("messages"):
gen_params["messages"] = []
......@@ -447,7 +431,7 @@ def predict_stream(model_id, gen_params):
if not is_function_call and len(output) > 7:
# Determine whether a function is called
is_function_call = contains_custom_function(output, gen_params["tools"])
is_function_call = contains_custom_function(output)
if is_function_call:
continue
......@@ -528,18 +512,18 @@ async def parse_output_text(model_id: str, value: str):
yield '[DONE]'
def contains_custom_function(value: str, tools: list) -> bool:
def contains_custom_function(value: str) -> bool:
"""
Determine whether the output contains a 'function_call', based on a special function prefix.
For example, the functions defined in "tools_using_demo/tool_register.py" are all "get_xxx" and start with "get_"
[Note] This is not a rigorous judgment method, only for reference.
:param value:
:param tools:
:return:
"""
for tool in tools:
if value and tool["name"] in value:
return True
return value and 'get_' in value
if __name__ == "__main__":
......@@ -549,4 +533,4 @@ if __name__ == "__main__":
# load Embedding
embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
uvicorn.run(app, host='0.0.0.0', port=9000, workers=1)
\ No newline at end of file
......@@ -9,16 +9,11 @@ allowing the user to input questions and receive AI answers.
3. This demo does not support streaming.
"""
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema.messages import HumanMessage, SystemMessage, AIMessage
from langchain_community.llms.chatglm3 import ChatGLM3
def initialize_llm_chain(messages: list):
template = "{input}"
prompt = PromptTemplate.from_template(template)
def get_ai_response(messages, user_input):
endpoint_url = "http://127.0.0.1:8000/v1/chat/completions"
llm = ChatGLM3(
endpoint_url=endpoint_url,
......@@ -26,11 +21,7 @@ def initialize_llm_chain(messages: list):
prefix_messages=messages,
top_p=0.9
)
return LLMChain(prompt=prompt, llm=llm)
def get_ai_response(llm_chain, user_message):
ai_response = llm_chain.invoke({"input": user_message})
ai_response = llm.invoke(user_input)
return ai_response
......@@ -42,12 +33,11 @@ def continuous_conversation():
user_input = input("Human (or 'exit' to quit): ")
if user_input.lower() == 'exit':
break
llm_chain = initialize_llm_chain(messages=messages)
ai_response = get_ai_response(llm_chain, user_input)
print("ChatGLM3: ", ai_response["text"])
ai_response = get_ai_response(messages, user_input)
print("ChatGLM3: ", ai_response)
messages += [
HumanMessage(content=user_input),
AIMessage(content=ai_response["text"]),
AIMessage(content=ai_response),
]
......
......@@ -26,7 +26,6 @@ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
else:
if use_tool:
content = "\n".join(content.split("\n")[1:-1])
def tool_call(**kwargs):
return kwargs
......
"""
This script is an example of using the Zhipu API to create various interactions with a ChatGLM3 model. It includes
functions to:
1. Conduct a basic chat session, asking about weather conditions in multiple cities.
2. Initiate a simple chat in Chinese, asking the model to tell a short story.
3. Retrieve and print embeddings for a given text input.
Each function demonstrates a different aspect of the API's capabilities,
showcasing how to make requests and handle responses.
Note: Make sure your Zhipu API key is set as an environment
variable, formatted as xxx.xxx (only the format is checked; a real key is not needed).
"""
from zhipuai import ZhipuAI
base_url = "http://127.0.0.1:8000/v1/"
client = ZhipuAI(api_key="EMP.TY", base_url=base_url)
def function_chat():
messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
response = client.chat.completions.create(
model="chatglm3_6b",
messages=messages,
tools=tools,
tool_choice="auto",
)
if response:
content = response.choices[0].message.content
print(content)
else:
print("Error:", response.status_code)
def simple_chat(use_stream=True):
messages = [
{
"role": "system",
"content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow "
"the user's instructions carefully. Respond using markdown.",
},
{
"role": "user",
"content": "你好,请你介绍一下chatglm3-6b这个模型"
}
]
response = client.chat.completions.create(
model="chatglm3_",
messages=messages,
stream=use_stream,
max_tokens=256,
temperature=0.8,
top_p=0.8)
if response:
if use_stream:
for chunk in response:
print(chunk.choices[0].delta.content)
else:
content = response.choices[0].message.content
print(content)
else:
print("Error:", response.status_code)
def embedding():
response = client.embeddings.create(
model="bge-large-zh-1.5",
input=["ChatGLM3-6B 是一个大型的中英双语模型。"],
)
embeddings = response.data[0].embedding
print("嵌入完成,维度:", len(embeddings))
if __name__ == "__main__":
simple_chat(use_stream=False)
simple_chat(use_stream=True)
embedding()
function_chat()