Commit 4bd96acc authored by lvzhen

Update tools_using_demo/cli_demo_tool.py, tools_using_demo/openai_api_demo.py, tools_using_demo/README.md, tools_using_demo/README_en.md, tools_using_demo/tool_register.py, tensorrt_llm_demo/README.md, tensorrt_llm_demo/tensorrt_llm_cli_demo.py, resources/cli-demo.png, resources/web-demo2.png, resources/tool_en.png, resources/tool.png, resources/heart.png, resources/wechat.jpg, resources/web-demo.gif, resources/web-demo2.gif, resources/WECHAT.md, resources/code_en.gif, openai_api_demo/api_server.py, openai_api_demo/.env, openai_api_demo/openai_api_request.py, openai_api_demo/docker-compose.yml, openai_api_demo/utils.py, openai_api_demo/zhipu_api_request.py, openai_api_demo/langchain_openai_api.py, langchain_demo/ChatGLM3.py, langchain_demo/main.py, langchain_demo/tools/Calculator.py, langchain_demo/tools/DistanceConversion.py, langchain_demo/tools/Weather.py, Intel_device_demo/README.md, Intel_device_demo/ipex_llm_cpu_demo/api_server.py, Intel_device_demo/ipex_llm_cpu_demo/chatglm3_infer.py, Intel_device_demo/ipex_llm_cpu_demo/chatglm3_web_demo.py, Intel_device_demo/ipex_llm_cpu_demo/openai_api_request.py, Intel_device_demo/ipex_llm_cpu_demo/generate.py, Intel_device_demo/ipex_llm_cpu_demo/utils.py, Intel_device_demo/openvino_demo/openvino_cli_demo.py, Intel_device_demo/openvino_demo/README.md, finetune_demo/lora_finetune.ipynb, finetune_demo/finetune_hf.py, finetune_demo/inference_hf.py, finetune_demo/README.md, finetune_demo/README_en.md, finetune_demo/requirements.txt, finetune_demo/configs/ds_zero_3.json, finetune_demo/configs/ds_zero_2.json, finetune_demo/configs/ptuning_v2.yaml, finetune_demo/configs/lora.yaml, finetune_demo/configs/sft.yaml, composite_demo/assets/emojis.png, composite_demo/assets/demo.png, composite_demo/assets/heart.png, composite_demo/assets/tool.png, composite_demo/.streamlit/config.toml, composite_demo/client.py, composite_demo/conversation.py, composite_demo/README_en.md, composite_demo/main.py, composite_demo/demo_chat.py, composite_demo/README.md, composite_demo/requirements.txt, composite_demo/demo_tool.py, composite_demo/tool_registry.py, composite_demo/demo_ci.py, basic_demo/cli_demo_bad_word_ids.py, basic_demo/cli_demo.py, basic_demo/cli_batch_request_demo.py, basic_demo/web_demo_gradio.py, basic_demo/web_demo_streamlit.py, .github/ISSUE_TEMPLATE/bug_report.yaml, .github/ISSUE_TEMPLATE/feature-request.yaml, .github/PULL_REQUEST_TEMPLATE/pr_template.md, MODEL_LICENSE, .gitignore, DEPLOYMENT.md, DEPLOYMENT_en.md, LICENSE, PROMPT.md, README_en.md, requirements.txt, README.md, PROMPT_en.md, update_requirements.sh files
parent d0572507
import os
import platform
from typing import Optional, Union
from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
os_name = platform.system()
clear_command = "cls" if os_name == "Windows" else "clear"
stop_stream = False
welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
def build_prompt(history):
prompt = welcome_prompt
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM3-6B:{response}"
return prompt
def process_model_outputs(outputs, tokenizer):
responses = []
for output in outputs:
response = tokenizer.decode(output, skip_special_tokens=True)
response = response.replace("[gMASK]sop", "").strip()
        responses.append(response)
return responses
def batch(
model,
tokenizer,
prompts: Union[str, list[str]],
max_length: int = 8192,
num_beams: int = 1,
do_sample: bool = True,
top_p: float = 0.8,
temperature: float = 0.8,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
):
tokenizer.encode_special_tokens = True
if isinstance(prompts, str):
prompts = [prompts]
batched_inputs = tokenizer(prompts, return_tensors="pt", padding="longest")
batched_inputs = batched_inputs.to(model.device)
eos_token_id = [
tokenizer.eos_token_id,
tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|assistant|>"),
]
gen_kwargs = {
"max_length": max_length,
"num_beams": num_beams,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
"eos_token_id": eos_token_id,
}
batched_outputs = model.generate(**batched_inputs, **gen_kwargs)
batched_response = []
for input_ids, output_ids in zip(batched_inputs.input_ids, batched_outputs):
decoded_text = tokenizer.decode(output_ids[len(input_ids):])
batched_response.append(decoded_text.strip())
return batched_response
def main(batch_queries):
gen_kwargs = {
"max_length": 2048,
"do_sample": True,
"top_p": 0.8,
"temperature": 0.8,
"num_beams": 1,
}
batch_responses = batch(model, tokenizer, batch_queries, **gen_kwargs)
return batch_responses
if __name__ == "__main__":
batch_queries = [
"<|user|>\n讲个故事\n<|assistant|>",
"<|user|>\n讲个爱情故事\n<|assistant|>",
"<|user|>\n讲个开心故事\n<|assistant|>",
"<|user|>\n讲个睡前故事\n<|assistant|>",
"<|user|>\n讲个励志的故事\n<|assistant|>",
"<|user|>\n讲个少壮不努力的故事\n<|assistant|>",
"<|user|>\n讲个青春校园恋爱故事\n<|assistant|>",
"<|user|>\n讲个工作故事\n<|assistant|>",
"<|user|>\n讲个旅游的故事\n<|assistant|>",
]
batch_responses = main(batch_queries)
for response in batch_responses:
print("=" * 10)
print(response)
import os
import platform
from transformers import AutoTokenizer, AutoModel
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
# add .quantize(bits=4, device="cuda").cuda() before .eval() to use int4 model
# must use cuda to load int4 model
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False
welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
def build_prompt(history):
prompt = welcome_prompt
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM3-6B:{response}"
return prompt
def main():
past_key_values, history = None, []
global stop_stream
print(welcome_prompt)
while True:
query = input("\n用户:")
if query.strip() == "stop":
break
if query.strip() == "clear":
past_key_values, history = None, []
os.system(clear_command)
print(welcome_prompt)
continue
print("\nChatGLM:", end="")
current_length = 0
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
temperature=0.01,
past_key_values=past_key_values,
return_past_key_values=True):
if stop_stream:
stop_stream = False
break
else:
print(response[current_length:], end="", flush=True)
current_length = len(response)
print("")
if __name__ == "__main__":
main()
"""
This script demonstrates how to use the `bad_words_ids` argument in the context of a conversational AI model to filter out unwanted words or phrases from the model's responses. It's designed to showcase a fundamental method of content moderation within AI-generated text, particularly useful in scenarios where maintaining the decorum of the conversation is essential.
Usage:
- Interact with the model by typing queries. The model will generate responses while avoiding the specified bad words.
- Use 'clear' to clear the conversation history and 'stop' to exit the program.
Requirements:
- The script requires the Transformers library and an appropriate model checkpoint.
Note: The `bad_words_ids` feature is an essential tool for controlling the output of language models, particularly in user-facing applications where content moderation is crucial.
"""
import os
import platform
from transformers import AutoTokenizer, AutoModel
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False
welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
# Note: aggressive bad-word filtering can invalidate the sampling distribution and raise
# "probability tensor contains either `inf`, `nan` or element < 0" during generation
bad_words = ["你好", "ChatGLM"]
bad_word_ids = [tokenizer.encode(bad_word, add_special_tokens=False) for bad_word in bad_words]
def build_prompt(history):
prompt = welcome_prompt
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM3-6B:{response}"
return prompt
def main():
past_key_values, history = None, []
global stop_stream
print(welcome_prompt)
while True:
query = input("\n用户:")
if query.strip().lower() == "stop":
break
if query.strip().lower() == "clear":
past_key_values, history = None, []
os.system(clear_command)
print(welcome_prompt)
continue
# Attempt to generate a response
try:
print("\nChatGLM:", end="")
current_length = 0
response_generated = False
for response, history, past_key_values in model.stream_chat(
tokenizer, query, history=history, top_p=1,
temperature=0.01,
past_key_values=past_key_values,
return_past_key_values=True,
bad_words_ids=bad_word_ids # assuming this is implemented correctly
):
response_generated = True
# Check if the response contains any bad words
if any(bad_word in response for bad_word in bad_words):
print("我的回答涉嫌了 bad word")
break # Break the loop if a bad word is detected
# Otherwise, print the generated response
print(response[current_length:], end="", flush=True)
current_length = len(response)
if not response_generated:
print("没有生成任何回答。")
except RuntimeError as e:
print(f"生成文本时发生错误:{e},这可能是涉及到设定的敏感词汇")
print("")
if __name__ == "__main__":
main()
"""
This script creates an interactive web demo for the ChatGLM3-6B model using Gradio,
a Python library for building quick and easy UI components for machine learning models.
It's designed to showcase the capabilities of the ChatGLM3-6B model in a user-friendly interface,
allowing users to interact with the model through a chat-like interface.
Usage:
- Run the script to start the Gradio web server.
- Interact with the model by typing questions and receiving responses.
Requirements:
- Gradio should be installed (version 4.13.0 or later is required; 3.x is no longer supported).
Note: The script includes a modification to the Chatbot's postprocess method to handle markdown to HTML conversion,
ensuring that the chat interface displays formatted text correctly.
"""
import os
import gradio as gr
import torch
from threading import Thread
from typing import Union, Annotated
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer
)
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code
)
return model, tokenizer
model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
stop_ids = [0, 2]
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
def parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
                    line = line.replace("`", "\\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text
def predict(history, max_length, top_p, temperature):
stop = StopOnTokens()
messages = []
for idx, (user_msg, model_msg) in enumerate(history):
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
break
if user_msg:
messages.append({"role": "user", "content": user_msg})
if model_msg:
messages.append({"role": "assistant", "content": model_msg})
print("\n\n====conversation====\n", messages)
model_inputs = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt").to(next(model.parameters()).device)
streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
"input_ids": model_inputs,
"streamer": streamer,
"max_new_tokens": max_length,
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
"stopping_criteria": StoppingCriteriaList([stop]),
"repetition_penalty": 1.2,
}
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
for new_token in streamer:
if new_token != '':
history[-1][1] += new_token
yield history
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
def user(query, history):
return "", history + [[parse_text(query), ""]]
submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
predict, [chatbot, max_length, top_p, temperature], chatbot
)
emptyBtn.click(lambda: None, None, chatbot, queue=False)
demo.queue()
demo.launch(server_name="127.0.0.1", server_port=7870, inbrowser=True, share=False)
"""
This script is a simple web demo based on Streamlit, showcasing the use of the ChatGLM3-6B model. For a more comprehensive web demo,
it is recommended to use 'composite_demo'.
Usage:
- Run the script using Streamlit: `streamlit run web_demo_streamlit.py`
- Adjust the model parameters from the sidebar.
- Enter questions in the chat input box and interact with the ChatGLM3-6B model.
Note: Ensure 'streamlit' and 'transformers' libraries are installed and the required model checkpoints are available.
"""
import os
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
st.set_page_config(
page_title="ChatGLM3-6B Streamlit Simple Demo",
page_icon=":robot:",
layout="wide"
)
@st.cache_resource
def get_model():
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
return tokenizer, model
# Load the ChatGLM3 model and tokenizer
tokenizer, model = get_model()
if "history" not in st.session_state:
st.session_state.history = []
if "past_key_values" not in st.session_state:
st.session_state.past_key_values = None
max_length = st.sidebar.slider("max_length", 0, 32768, 8192, step=1)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.6, step=0.01)
buttonClean = st.sidebar.button("清理会话历史", key="clean")
if buttonClean:
st.session_state.history = []
st.session_state.past_key_values = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
st.rerun()
for i, message in enumerate(st.session_state.history):
if message["role"] == "user":
with st.chat_message(name="user", avatar="user"):
st.markdown(message["content"])
else:
with st.chat_message(name="assistant", avatar="assistant"):
st.markdown(message["content"])
with st.chat_message(name="user", avatar="user"):
input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
message_placeholder = st.empty()
prompt_text = st.chat_input("请输入您的问题")
if prompt_text:
input_placeholder.markdown(prompt_text)
history = st.session_state.history
past_key_values = st.session_state.past_key_values
for response, history, past_key_values in model.stream_chat(
tokenizer,
prompt_text,
history,
past_key_values=past_key_values,
max_length=max_length,
top_p=top_p,
temperature=temperature,
return_past_key_values=True,
):
message_placeholder.markdown(response)
st.session_state.history = history
st.session_state.past_key_values = past_key_values
[theme]
font = "monospace"
# ChatGLM3 Web Demo
![Demo webpage](assets/demo.png)
## Installation
We recommend managing the environment with [Conda](https://docs.conda.io/en/latest/).
Run the following commands to create a new conda environment and install the required dependencies:
```bash
conda create -n chatglm3-demo python=3.10
conda activate chatglm3-demo
pip install -r requirements.txt
```
Please note that this project requires Python 3.10 or higher.
In addition, using the Code Interpreter requires installing a Jupyter kernel:
```bash
ipython kernel install --name chatglm3-demo --user
```
## Running
Run the following command to load the model locally and start the demo:
```bash
streamlit run main.py
```
The demo's address will then be shown on the command line; click it to open the demo. The first visit requires downloading and loading the model, which may take some time.
If the model has already been downloaded locally, you can load it from a local path by setting `export MODEL_PATH=/path/to/model`. To use a custom Jupyter kernel, set `export IPYKERNEL=<kernel_name>`.
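As a concrete illustration (the model path below is only a placeholder; substitute your own local checkpoint directory):

```bash
# Load the model from a local directory instead of downloading it on first use
export MODEL_PATH=/path/to/chatglm3-6b   # placeholder path, adjust to your environment
# Optionally select the Jupyter kernel registered above (defaults to chatglm3-demo)
export IPYKERNEL=chatglm3-demo
streamlit run main.py
```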
## Usage
The ChatGLM3 Demo has three modes:
- Chat: dialogue mode, in which you can converse with the model.
- Tool: tool mode, in which the model can call tools to perform operations beyond plain dialogue.
- Code Interpreter: code interpreter mode, in which the model can execute code in a Jupyter environment and use the results to complete complex tasks.
### Dialogue Mode
In dialogue mode, users can adjust the model's behavior by modifying parameters such as top_p, temperature, and the System Prompt directly in the sidebar. For example:
![The model responds following the system prompt](assets/emojis.png)
### Tool Mode
You can extend the model's capabilities by registering new tools in `tool_registry.py`. Simply decorate a function with `@register_tool` to register it. For the tool declaration, the function name is the tool's name and the function docstring is the tool's description; for tool parameters, use `Annotated[typ: type, description: str, required: bool]` to annotate each parameter's type, description, and whether it is required.
For example, the `get_weather` tool is registered as follows:
```python
@register_tool
def get_weather(
city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
"""
Get the weather for `city_name` in the following week
"""
...
```
![The model uses a tool to query the weather of Paris.](assets/tool.png)
In addition, you can switch to `Manual mode` on the page. In this mode you specify the tool list directly in YAML, but you must feed the tools' outputs back to the model manually; an example tool definition is shown below.
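For reference, a tool list in this YAML format might look like the following sketch; it mirrors the `get_current_weather` definition (`EXAMPLE_TOOL`) that `demo_tool.py` pre-fills into the manual-mode text area:

```yaml
- name: get_current_weather
  description: Get the current weather in a given location
  parameters:
    type: object
    properties:
      location:
        type: string
        description: The city and state, e.g. San Francisco, CA
      unit:
        type: string
        enum:
        - celsius
        - fahrenheit
    required:
    - location
```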
### Code Interpreter Mode
Because it has a code execution environment, the model in this mode can carry out more complex tasks, such as drawing charts or performing symbolic computation. Based on its own judgment of how the task is progressing, the model automatically runs multiple code blocks in succession until the task is complete. Therefore, in this mode you only need to state the task you want the model to perform.
For example, we can ask ChatGLM3 to draw a heart:
![The code interpreter draws a heart according to the user's instructions.](assets/heart.png)
### Additional Tips
- While the model is generating text, you can interrupt it with the `Stop` button at the top right of the page.
- Refresh the page to clear the conversation history.
# Enjoy!
# ChatGLM3 Web Demo
![Demo webpage](assets/demo.png)
## Installation
We recommend managing environments through [Conda](https://docs.conda.io/en/latest/).
Execute the following commands to create a new conda environment and install the necessary dependencies:
```bash
conda create -n chatglm3-demo python=3.10
conda activate chatglm3-demo
pip install -r requirements.txt
```
Please note that this project requires Python 3.10 or higher.
Additionally, installing the Jupyter kernel is required for using the Code Interpreter:
```bash
ipython kernel install --name chatglm3-demo --user
```
## Execution
Run the following command to load the model locally and start the demo:
```bash
streamlit run main.py
```
Afterwards, the demo's address will be shown on the command line; click it to open the demo. The first visit requires downloading and loading the model, which may take some time.
If the model has already been downloaded locally, you can specify to load the model locally through `export MODEL_PATH=/path/to/model`. If you need to customize the Jupyter kernel, you can specify it through `export IPYKERNEL=<kernel_name>`.
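For instance (the model path below is only a placeholder; point it at wherever the checkpoint actually lives):

```bash
# Load the model from a local directory instead of downloading it on first use
export MODEL_PATH=/path/to/chatglm3-6b   # placeholder path, adjust to your environment
# Optionally select the Jupyter kernel registered above (defaults to chatglm3-demo)
export IPYKERNEL=chatglm3-demo
streamlit run main.py
```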
## Usage
ChatGLM3 Demo has three modes:
- Chat: Dialogue mode, where you can interact with the model.
- Tool: Tool mode, where the model, in addition to dialogue, can perform other operations through tools.
- Code Interpreter: Code interpreter mode, where the model can execute code in a Jupyter environment and obtain results to complete complex tasks.
### Dialogue Mode
In dialogue mode, users can adjust the model's behavior by modifying parameters such as top_p, temperature, and the System Prompt directly in the sidebar. For example:
![The model responds following the system prompt](assets/emojis.png)
### Tool Mode
You can enhance the model's capabilities by registering new tools in `tool_registry.py`. Just use the `@register_tool` decorator to complete the registration. For tool declarations, the function name is the name of the tool and the function docstring is the description of the tool; for tool parameters, use `Annotated[typ: type, description: str, required: bool]` to annotate each parameter's type, description, and whether it is required.
For example, the registration of the `get_weather` tool is as follows:
```python
@register_tool
def get_weather(
city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
"""
Get the weather for `city_name` in the following week
"""
...
```
![The model uses a tool to query the weather of Paris.](assets/tool.png)
Additionally, you can enter the manual mode through `Manual mode` on the page. In this mode, you can directly specify the tool list through YAML, but you need to manually feed back the tool's output to the model.
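For reference, a manually supplied tool list might look like the sketch below; it mirrors the `get_current_weather` definition (`EXAMPLE_TOOL`) that `demo_tool.py` uses to pre-fill the manual-mode text area:

```yaml
- name: get_current_weather
  description: Get the current weather in a given location
  parameters:
    type: object
    properties:
      location:
        type: string
        description: The city and state, e.g. San Francisco, CA
      unit:
        type: string
        enum:
        - celsius
        - fahrenheit
    required:
    - location
```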
### Code Interpreter Mode
Because it has a code execution environment, the model in this mode can perform more complex tasks, such as drawing charts or performing symbolic computation. The model will automatically execute multiple code blocks in succession, based on its understanding of the task's progress, until the task is complete. Therefore, in this mode you only need to specify the task you want the model to perform.
For example, we can ask ChatGLM3 to draw a heart:
![The code interpreter draws a heart according to the user's instructions.](assets/heart.png)
### Additional Tips
- While the model is generating text, it can be interrupted by the `Stop` button at the top right corner of the page.
- Refreshing the page will clear the dialogue history.
# Enjoy!
from __future__ import annotations
import os
import streamlit as st
import torch
from collections.abc import Iterable
from typing import Any, Protocol
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from conversation import Conversation
TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
PT_PATH = os.environ.get('PT_PATH', None)
PRE_SEQ_LEN = int(os.environ.get("PRE_SEQ_LEN", 128))
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
@st.cache_resource
def get_client() -> Client:
client = HFClient(MODEL_PATH, TOKENIZER_PATH, PT_PATH)
return client
class Client(Protocol):
def generate_stream(self,
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
**parameters: Any
) -> Iterable[TextGenerationStreamResponse]:
...
def stream_chat(
self, tokenizer, query: str,
history: list[tuple[str, str]] = None,
role: str = "user",
past_key_values=None,
max_new_tokens: int = 256,
do_sample=True, top_p=0.8,
temperature=0.8,
repetition_penalty=1.0,
length_penalty=1.0, num_beams=1,
logits_processor=None,
return_past_key_values=False,
**kwargs
):
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 5] = 5e4
return scores
if history is None:
history = []
print("\n== Input ==\n", query)
print("\n==History==\n", history)
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|observation|>")]
gen_kwargs = {"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
"repetition_penalty": repetition_penalty,
"length_penalty": length_penalty,
"num_beams": num_beams,
**kwargs
}
if past_key_values is None:
inputs = tokenizer.build_chat_input(query, history=history, role=role)
else:
inputs = tokenizer.build_chat_input(query, role=role)
inputs = inputs.to(self.device)
if past_key_values is not None:
past_length = past_key_values[0][0].shape[0]
if self.transformer.pre_seq_len is not None:
past_length -= self.transformer.pre_seq_len
inputs.position_ids += past_length
attention_mask = inputs.attention_mask
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
inputs['attention_mask'] = attention_mask
history.append({"role": role, "content": query})
input_sequence_length = inputs['input_ids'].shape[1]
if input_sequence_length + max_new_tokens >= self.config.seq_length:
yield "Current input sequence length {} plus max_new_tokens {} is too long. The maximum model sequence length is {}. You may adjust the generation parameter to enable longer chat history.".format(
input_sequence_length, max_new_tokens, self.config.seq_length
), history
return
if input_sequence_length > self.config.seq_length:
yield "Current input sequence length {} exceeds maximum model sequence length {}. Unable to generate tokens.".format(
input_sequence_length, self.config.seq_length
), history
return
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
**gen_kwargs):
if return_past_key_values:
outputs, past_key_values = outputs
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs)
if response and response[-1] != "�":
new_history = history
if return_past_key_values:
yield response, new_history, past_key_values
else:
yield response, new_history
class HFClient(Client):
def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str = None):
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
if pt_checkpoint is not None and os.path.exists(pt_checkpoint):
config = AutoConfig.from_pretrained(
model_path,
trust_remote_code=True,
pre_seq_len=PRE_SEQ_LEN
)
self.model = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
config=config,
device_map="auto").eval()
# add .quantize(bits=4, device="cuda").cuda() before .eval() and remove device_map="auto" to use int4 model
# must use cuda to load int4 model
prefix_state_dict = torch.load(os.path.join(pt_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
self.model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
# add .quantize(bits=4, device="cuda").cuda() before .eval() and remove device_map="auto" to use int4 model
# must use cuda to load int4 model
def generate_stream(
self,
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
**parameters: Any
) -> Iterable[TextGenerationStreamResponse]:
chat_history = [{
'role': 'system',
'content': system if not tools else TOOL_PROMPT,
}]
if tools:
chat_history[0]['tools'] = tools
for conversation in history[:-1]:
chat_history.append({
'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),
'content': conversation.content,
})
query = history[-1].content
role = str(history[-1].role).removeprefix('<|').removesuffix('|>')
text = ''
for new_text, _ in stream_chat(
self.model,
self.tokenizer,
query,
chat_history,
role,
**parameters,
):
word = new_text.removeprefix(text)
word_stripped = word.strip()
text = new_text
yield TextGenerationStreamResponse(
generated_text=text,
token=Token(
id=0,
logprob=0,
text=word,
special=word_stripped.startswith('<|') and word_stripped.endswith('|>'),
)
)
from dataclasses import dataclass
from enum import auto, Enum
import json
from PIL.Image import Image
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:\n'
class Role(Enum):
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()
INTERPRETER = auto()
OBSERVATION = auto()
def __str__(self):
match self:
case Role.SYSTEM:
return "<|system|>"
case Role.USER:
return "<|user|>"
case Role.ASSISTANT | Role.TOOL | Role.INTERPRETER:
return "<|assistant|>"
case Role.OBSERVATION:
return "<|observation|>"
# Get the message block for the given role
def get_message(self):
# Compare by value here, because the enum object in the session state
# is not the same as the enum cases here, due to streamlit's rerunning
# behavior.
match self.value:
case Role.SYSTEM.value:
return
case Role.USER.value:
return st.chat_message(name="user", avatar="user")
case Role.ASSISTANT.value:
return st.chat_message(name="assistant", avatar="assistant")
case Role.TOOL.value:
return st.chat_message(name="tool", avatar="assistant")
case Role.INTERPRETER.value:
return st.chat_message(name="interpreter", avatar="assistant")
case Role.OBSERVATION.value:
return st.chat_message(name="observation", avatar="user")
case _:
st.error(f'Unexpected role: {self}')
@dataclass
class Conversation:
role: Role
content: str
tool: str | None = None
image: Image | None = None
def __str__(self) -> str:
print(self.role, self.content, self.tool)
match self.role:
case Role.SYSTEM | Role.USER | Role.ASSISTANT | Role.OBSERVATION:
return f'{self.role}\n{self.content}'
case Role.TOOL:
return f'{self.role}{self.tool}\n{self.content}'
case Role.INTERPRETER:
return f'{self.role}interpreter\n{self.content}'
# Human readable format
def get_text(self) -> str:
text = postprocess_text(self.content)
match self.role.value:
case Role.TOOL.value:
text = f'Calling tool `{self.tool}`:\n\n{text}'
case Role.INTERPRETER.value:
text = f'{text}'
case Role.OBSERVATION.value:
text = f'Observation:\n```\n{text}\n```'
return text
# Display as a markdown block
def show(self, placeholder: DeltaGenerator | None=None) -> str:
if placeholder:
message = placeholder
else:
message = self.role.get_message()
if self.image:
message.image(self.image)
else:
text = self.get_text()
message.markdown(text)
def preprocess_text(
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
) -> str:
if tools:
tools = json.dumps(tools, indent=4, ensure_ascii=False)
prompt = f"{Role.SYSTEM}\n"
prompt += system if not tools else TOOL_PROMPT
if tools:
tools = json.loads(tools)
prompt += json.dumps(tools, ensure_ascii=False)
for conversation in history:
prompt += f'{conversation}'
prompt += f'{Role.ASSISTANT}\n'
return prompt
def postprocess_text(text: str) -> str:
    text = text.replace("\\(", "$")
    text = text.replace("\\)", "$")
    text = text.replace("\\[", "$$")
    text = text.replace("\\]", "$$")
text = text.replace("<|assistant|>", "")
text = text.replace("<|observation|>", "")
text = text.replace("<|system|>", "")
text = text.replace("<|user|>", "")
return text.strip()
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
client = get_client()
# Append a conversation into history, while show it in a new markdown block
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None = None,
) -> None:
history.append(conversation)
conversation.show(placeholder)
def main(
prompt_text: str,
system_prompt: str,
top_p: float = 0.8,
temperature: float = 0.95,
repetition_penalty: float = 1.0,
max_new_tokens: int = 1024,
retry: bool = False
):
placeholder = st.empty()
with placeholder.container():
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
if prompt_text == "" and retry == False:
print("\n== Clean ==\n")
st.session_state.chat_history = []
return
history: list[Conversation] = st.session_state.chat_history
for conversation in history:
conversation.show()
if retry:
print("\n== Retry ==\n")
last_user_conversation_idx = None
for idx, conversation in enumerate(history):
if conversation.role == Role.USER:
last_user_conversation_idx = idx
if last_user_conversation_idx is not None:
prompt_text = history[last_user_conversation_idx].content
del history[last_user_conversation_idx:]
if prompt_text:
prompt_text = prompt_text.strip()
append_conversation(Conversation(Role.USER, prompt_text), history)
placeholder = st.empty()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
output_text = ''
for response in client.generate_stream(
system_prompt,
tools=None,
history=history,
do_sample=True,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(Role.USER)],
repetition_penalty=repetition_penalty,
):
token = response.token
if response.token.special:
print("\n==Output:==\n", output_text)
match token.text.strip():
case '<|user|>':
break
case _:
st.error(f'Unexpected special token: {token.text.strip()}')
break
output_text += response.token.text
markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
import base64
from io import BytesIO
import os
from pprint import pprint
import queue
import re
from subprocess import PIPE
import jupyter_client
from PIL import Image
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
IPYKERNEL = os.environ.get('IPYKERNEL', 'chatglm3-demo')
SYSTEM_PROMPT = '你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。'
client = get_client()
class CodeKernel(object):
def __init__(self,
kernel_name='kernel',
kernel_id=None,
kernel_config_path="",
python_path=None,
ipython_path=None,
init_file_path="./startup.py",
verbose=1):
self.kernel_name = kernel_name
self.kernel_id = kernel_id
self.kernel_config_path = kernel_config_path
self.python_path = python_path
self.ipython_path = ipython_path
self.init_file_path = init_file_path
self.verbose = verbose
if python_path is None and ipython_path is None:
env = None
else:
env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}
# Initialize the backend kernel
self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
connection_file=self.kernel_config_path,
exec_files=[self.init_file_path],
env=env)
if self.kernel_config_path:
self.kernel_manager.load_connection_file()
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_config_path))
else:
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_manager.connection_file))
if verbose:
pprint(self.kernel_manager.get_connection_info())
# Initialize the code kernel
self.kernel = self.kernel_manager.blocking_client()
# self.kernel.load_connection_file()
self.kernel.start_channels()
print("Code kernel started.")
def execute(self, code):
self.kernel.execute(code)
try:
shell_msg = self.kernel.get_shell_msg(timeout=30)
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
while True:
msg_out = io_msg_content
### Poll the message
try:
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
break
except queue.Empty:
break
return shell_msg, msg_out
except Exception as e:
print(e)
return None
def execute_interactive(self, code, verbose=False):
shell_msg = self.kernel.execute_interactive(code)
if shell_msg is queue.Empty:
if verbose:
print("Timeout waiting for shell message.")
self.check_msg(shell_msg, verbose=verbose)
return shell_msg
def inspect(self, code, verbose=False):
msg_id = self.kernel.inspect(code)
shell_msg = self.kernel.get_shell_msg(timeout=30)
if shell_msg is queue.Empty:
if verbose:
print("Timeout waiting for shell message.")
self.check_msg(shell_msg, verbose=verbose)
return shell_msg
def get_error_msg(self, msg, verbose=False) -> str | None:
if msg['content']['status'] == 'error':
try:
error_msg = msg['content']['traceback']
except:
try:
error_msg = msg['content']['traceback'][-1].strip()
except:
error_msg = "Traceback Error"
if verbose:
print("Error: ", error_msg)
return error_msg
return None
def check_msg(self, msg, verbose=False):
status = msg['content']['status']
if status == 'ok':
if verbose:
print("Execution succeeded.")
elif status == 'error':
for line in msg['content']['traceback']:
if verbose:
print(line)
def shutdown(self):
# Shutdown the backend kernel
self.kernel_manager.shutdown_kernel()
print("Backend kernel shutdown.")
# Shutdown the code kernel
self.kernel.shutdown()
print("Code kernel shutdown.")
def restart(self):
# Restart the backend kernel
self.kernel_manager.restart_kernel()
# print("Backend kernel restarted.")
def interrupt(self):
# Interrupt the backend kernel
self.kernel_manager.interrupt_kernel()
# print("Backend kernel interrupted.")
def is_alive(self):
return self.kernel.is_alive()
def b64_2_img(data):
buff = BytesIO(base64.b64decode(data))
return Image.open(buff)
def clean_ansi_codes(input_string):
ansi_escape = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
return ansi_escape.sub('', input_string)
def execute(code, kernel: CodeKernel) -> tuple[str, str | Image.Image]:
res = ""
res_type = None
code = code.replace("<|observation|>", "")
code = code.replace("<|assistant|>interpreter", "")
code = code.replace("<|assistant|>", "")
code = code.replace("<|user|>", "")
code = code.replace("<|system|>", "")
msg, output = kernel.execute(code)
if msg['metadata']['status'] == "timeout":
return res_type, 'Timed out'
elif msg['metadata']['status'] == 'error':
return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))
if 'text' in output:
res_type = "text"
res = output['text']
elif 'data' in output:
for key in output['data']:
if 'text/plain' in key:
res_type = "text"
res = output['data'][key]
elif 'image/png' in key:
res_type = "image"
res = output['data'][key]
break
if res_type == "image":
return res_type, b64_2_img(res)
elif res_type == "text" or res_type == "traceback":
res = res
return res_type, res
@st.cache_resource
def get_kernel():
kernel = CodeKernel()
return kernel
def extract_code(text: str) -> str:
pattern = r'```([^\n]*)\n(.*?)```'
matches = re.findall(pattern, text, re.DOTALL)
return matches[-1][1]
# Append a conversation into history, while show it in a new markdown block
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None = None,
) -> None:
history.append(conversation)
conversation.show(placeholder)
def main(
prompt_text: str,
top_p: float = 0.2,
temperature: float = 0.1,
repetition_penalty: float = 1.1,
max_new_tokens: int = 1024,
truncate_length: int = 1024,
retry: bool = False
):
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
if prompt_text == "" and retry == False:
print("\n== Clean ==\n")
st.session_state.chat_history = []
return
history: list[Conversation] = st.session_state.chat_history
for conversation in history:
conversation.show()
if retry:
print("\n== Retry ==\n")
last_user_conversation_idx = None
for idx, conversation in enumerate(history):
if conversation.role == Role.USER:
last_user_conversation_idx = idx
if last_user_conversation_idx is not None:
prompt_text = history[last_user_conversation_idx].content
del history[last_user_conversation_idx:]
if prompt_text:
prompt_text = prompt_text.strip()
role = Role.USER
append_conversation(Conversation(role, prompt_text), history)
placeholder = st.container()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
for _ in range(5):
output_text = ''
for response in client.generate_stream(
system=SYSTEM_PROMPT,
tools=None,
history=history,
do_sample=True,
                max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
repetition_penalty=repetition_penalty,
):
token = response.token
if response.token.special:
print("\n==Output:==\n", output_text)
match token.text.strip():
case '<|user|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
# Initiate tool call
case '<|assistant|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="interpreter", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
output_text = ''
continue
case '<|observation|>':
code = extract_code(output_text)
display_text = output_text.split('interpreter')[-1].strip()
append_conversation(Conversation(
Role.INTERPRETER,
postprocess_text(display_text),
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="observation", avatar="user")
markdown_placeholder = message_placeholder.empty()
output_text = ''
with markdown_placeholder:
with st.spinner('Executing code...'):
try:
res_type, res = execute(code, get_kernel())
except Exception as e:
st.error(f'Error when executing code: {e}')
return
print("Received:", res_type, res)
if truncate_length:
if res_type == 'text' and len(res) > truncate_length:
res = res[:truncate_length] + ' [TRUNCATED]'
append_conversation(Conversation(
Role.OBSERVATION,
'[Image]' if res_type == 'image' else postprocess_text(res),
tool=None,
image=res if res_type == 'image' else None,
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
output_text = ''
break
case _:
st.error(f'Unexpected special token: {token.text.strip()}')
break
output_text += response.token.text
display_text = output_text.split('interpreter')[-1].strip()
markdown_placeholder.markdown(postprocess_text(display_text + '▌'))
else:
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
else:
st.session_state.chat_history = []
import re
import yaml
from yaml import YAMLError
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
from tool_registry import dispatch_tool, get_tools
EXAMPLE_TOOL = {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
}
}
client = get_client()
def tool_call(*args, **kwargs) -> dict:
print("=== Tool call===")
print(args)
print(kwargs)
st.session_state.calling_tool = True
return kwargs
def yaml_to_dict(tools: str) -> list[dict] | None:
try:
return yaml.safe_load(tools)
except YAMLError:
return None
def extract_code(text: str) -> str:
pattern = r'```([^\n]*)\n(.*?)```'
matches = re.findall(pattern, text, re.DOTALL)
print(matches)
return matches[-1][1]
# Append a conversation into history, while show it in a new markdown block
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None = None,
) -> None:
history.append(conversation)
conversation.show(placeholder)
def main(
prompt_text: str,
top_p: float = 0.2,
temperature: float = 0.1,
repetition_penalty: float = 1.1,
max_new_tokens: int = 1024,
truncate_length: int = 1024,
retry: bool = False
):
manual_mode = st.toggle('Manual mode',
help='Define your tools in YAML format. You need to supply tool call results manually.'
)
if manual_mode:
with st.expander('Tools'):
tools = st.text_area(
'Define your tools in YAML format here:',
yaml.safe_dump([EXAMPLE_TOOL], sort_keys=False),
height=400,
)
tools = yaml_to_dict(tools)
if not tools:
st.error('YAML format error in tools definition')
else:
tools = get_tools()
if 'tool_history' not in st.session_state:
st.session_state.tool_history = []
if 'calling_tool' not in st.session_state:
st.session_state.calling_tool = False
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
if prompt_text == "" and retry == False:
print("\n== Clean ==\n")
st.session_state.chat_history = []
return
history: list[Conversation] = st.session_state.chat_history
for conversation in history:
conversation.show()
if retry:
print("\n== Retry ==\n")
last_user_conversation_idx = None
for idx, conversation in enumerate(history):
if conversation.role == Role.USER:
last_user_conversation_idx = idx
if last_user_conversation_idx is not None:
prompt_text = history[last_user_conversation_idx].content
del history[last_user_conversation_idx:]
if prompt_text:
prompt_text = prompt_text.strip()
        role = Role.OBSERVATION if st.session_state.calling_tool else Role.USER
append_conversation(Conversation(role, prompt_text), history)
st.session_state.calling_tool = False
placeholder = st.container()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
for _ in range(5):
output_text = ''
for response in client.generate_stream(
system=None,
tools=tools,
history=history,
do_sample=True,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
repetition_penalty=repetition_penalty,
):
token = response.token
if response.token.special:
print("\n==Output:==\n", output_text)
match token.text.strip():
case '<|user|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
# Initiate tool call
case '<|assistant|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
output_text = ''
message_placeholder = placeholder.chat_message(name="tool", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
continue
case '<|observation|>':
tool, *call_args_text = output_text.strip().split('\n')
call_args_text = '\n'.join(call_args_text)
append_conversation(Conversation(
Role.TOOL,
postprocess_text(output_text),
tool,
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="observation", avatar="user")
markdown_placeholder = message_placeholder.empty()
try:
code = extract_code(call_args_text)
args = eval(code, {'tool_call': tool_call}, {})
except:
st.error('Failed to parse tool call')
return
output_text = ''
if manual_mode:
st.info('Please provide tool call results below:')
return
else:
with markdown_placeholder:
with st.spinner(f'Calling tool {tool}...'):
observation = dispatch_tool(tool, args)
if len(observation) > truncate_length:
observation = observation[:truncate_length] + ' [TRUNCATED]'
append_conversation(Conversation(
Role.OBSERVATION, observation
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
st.session_state.calling_tool = False
break
case _:
st.error(f'Unexpected special token: {token.text.strip()}')
return
output_text += response.token.text
markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
else:
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
import streamlit as st
st.set_page_config(
page_title="ChatGLM3 Demo",
page_icon=":robot:",
layout='centered',
initial_sidebar_state='expanded',
)
import demo_chat, demo_ci, demo_tool
from enum import Enum
DEFAULT_SYSTEM_PROMPT = '''
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
'''.strip()
# Set the title of the demo
st.title("ChatGLM3 Demo")
# Add your custom text here, with smaller font size
st.markdown(
"<sub>智谱AI 公开在线技术文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof </sub> \n\n <sub> 更多 ChatGLM3-6B 的使用方法请参考文档。</sub>",
unsafe_allow_html=True)
class Mode(str, Enum):
CHAT, TOOL, CI = '💬 Chat', '🛠️ Tool', '🧑‍💻 Code Interpreter'
with st.sidebar:
top_p = st.slider(
'top_p', 0.0, 1.0, 0.8, step=0.01
)
temperature = st.slider(
'temperature', 0.0, 1.5, 0.95, step=0.01
)
repetition_penalty = st.slider(
'repetition_penalty', 0.0, 2.0, 1.1, step=0.01
)
max_new_token = st.slider(
'Output length', 5, 32000, 256, step=1
)
cols = st.columns(2)
export_btn = cols[0]
clear_history = cols[1].button("Clear History", use_container_width=True)
retry = export_btn.button("Retry", use_container_width=True)
system_prompt = st.text_area(
label="System Prompt (Only for chat mode)",
height=300,
value=DEFAULT_SYSTEM_PROMPT,
)
prompt_text = st.chat_input(
'Chat with ChatGLM3!',
key='chat_input',
)
tab = st.radio(
'Mode',
[mode.value for mode in Mode],
horizontal=True,
label_visibility='hidden',
)
if clear_history or retry:
prompt_text = ""
match tab:
case Mode.CHAT:
demo_chat.main(
retry=retry,
top_p=top_p,
temperature=temperature,
prompt_text=prompt_text,
system_prompt=system_prompt,
repetition_penalty=repetition_penalty,
max_new_tokens=max_new_token
)
case Mode.TOOL:
demo_tool.main(
retry=retry,
top_p=top_p,
temperature=temperature,
prompt_text=prompt_text,
repetition_penalty=repetition_penalty,
max_new_tokens=max_new_token,
truncate_length=1024)
case Mode.CI:
demo_ci.main(
retry=retry,
top_p=top_p,
temperature=temperature,
prompt_text=prompt_text,
repetition_penalty=repetition_penalty,
max_new_tokens=max_new_token,
truncate_length=1024)
case _:
st.error(f'Unexpected tab: {tab}')
huggingface_hub>=0.19.4
pillow>=10.1.0
pyyaml>=6.0.1
requests>=2.31.0
ipykernel>=6.26.0
ipython>=8.18.1
jupyter_client>=8.6.0