Commit 0e1045f0 authored by lvzhen

Revert "Merge branch 'master' into 'master'"

This reverts merge request !2
parent 467ec853
import os
import platform
from transformers import AutoTokenizer, AutoModel
import torch
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
MODEL_PATH = os.environ.get('MODEL_PATH', '../../chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# For Mac computers with Apple silicon (e.g. M1),
# use a PyTorch build compiled with Metal (MPS) support
# DEVICE = 'mps'
# For AMD GPUs such as the MI100 (not officially stable yet),
# use a PyTorch build compiled with ROCm
# DEVICE = 'cuda'
# For Intel GPUs such as the A770 (not officially stable yet),
# use a PyTorch build compiled with oneDNN and install intel-extension-for-pytorch
# import intel_extension_for_pytorch as ipex
# DEVICE = 'xpu'
# For Moore Threads GPUs such as the MTT S80 (not officially stable yet),
# use a PyTorch build compiled with MUSA
# DEVICE = 'musa'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
# To use the int4 model, add .quantize(bits=4, device="cuda").cuda() before .eval()
# CUDA must be used to load the int4 model
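# For example (a sketch based on the comment above; assumes a CUDA device is available):
# model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).quantize(bits=4, device="cuda").cuda().eval()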
if 'cuda' in DEVICE: # AMD and NVIDIA GPUs can use half precision
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()
else: # CPU, Intel GPUs and other devices are loaded in full float32 precision here
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
......
"""
This script demonstrates how to use the `bad_words_ids` argument in the context of a conversational AI model to filter out unwanted words or phrases from the model's responses. It's designed to showcase a fundamental method of content moderation within AI-generated text, particularly useful in scenarios where maintaining the decorum of the conversation is essential.
Usage:
- Interact with the model by typing queries. The model will generate responses while avoiding the specified bad words.
- Use 'clear' to clear the conversation history and 'stop' to exit the program.
Requirements:
- The script requires the Transformers library and an appropriate model checkpoint.
Note: The `bad_words_ids` feature is an essential tool for controlling the output of language models, particularly in user-facing applications where content moderation is crucial.
"""
import os
import platform
from transformers import AutoTokenizer, AutoModel
import torch
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
MODEL_PATH = os.environ.get('MODEL_PATH', '../../chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
if 'cuda' in DEVICE: # AMD and NVIDIA GPUs can use half precision
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()
else: # CPU, Intel GPUs and other devices are loaded in full float32 precision here
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
......@@ -28,9 +22,11 @@ stop_stream = False
welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
# Define the words that should not appear; you can customize this list. In this example, if the model's reply contains "你好" or "ChatGLM", generation fails with the error:
# probability tensor contains either `inf`, `nan` or element < 0
bad_words = ["你好", "ChatGLM"]
# Convert these words into lists of token IDs; each phrase becomes a sub-list
bad_word_ids = [tokenizer.encode(bad_word, add_special_tokens=False) for bad_word in bad_words]
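# A minimal sketch (not part of this script) of how such token-ID lists are typically consumed:
# Hugging Face `generate` accepts them via the `bad_words_ids` argument, so a direct call could
# look like the following (parameter values here are illustrative):
# inputs = tokenizer("你好,请介绍一下你自己", return_tensors="pt").to(model.device)
# outputs = model.generate(**inputs, max_new_tokens=128, bad_words_ids=bad_word_ids)
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))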
......@@ -70,7 +66,7 @@ def main():
response_generated = True
# Check if the response contains any bad words
if any(bad_word in response for bad_word in bad_words):
print("我的回答涉嫌了 bad word")
print("我的回答涉嫌了bad word")
break # Break the loop if a bad word is detected
# Otherwise, print the generated response
......
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("../../chatglm3-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("../../chatglm3-6b", trust_remote_code=True, device='cuda')
model = model.eval()
response, history = model.chat(tokenizer, "你好", history=[])
print(response)
response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
print(response)
# print(len(tokenizer))
# vocab_content = tokenizer.get_vocab()
# with open("vocab.txt", "w", encoding="utf-8") as f:
# for token, index in vocab_content.items():
# f.write(f"{token} {index}\n")
\ No newline at end of file
import os
from typing import Dict, Union, Optional
from torch.nn import Module
from transformers import AutoModel
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
# transformer.word_embeddings takes up 1 layer
# transformer.final_layernorm and lm_head take up 1 layer
# transformer.layers takes up 28 layers
# 30 layers in total, distributed across num_gpus cards
num_trans_layers = 28
per_gpu_layers = 30 / num_gpus
# bugfix: on Linux, the weight and input passed to torch.embedding could end up on different devices, causing a RuntimeError
# on Windows, model.device is set to transformer.word_embeddings.device
# on Linux, model.device is set to lm_head.device
# when chat or stream_chat is called, input_ids are placed on model.device
# if transformer.word_embeddings.device and model.device differ, a RuntimeError is raised
# therefore transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first card
# this file comes from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
# only minor modifications were made here to support ChatGLM3
device_map = {
'transformer.embedding.word_embeddings': 0,
'transformer.encoder.final_layernorm': 0,
'transformer.output_layer': 0,
'transformer.rotary_pos_emb': 0,
'lm_head': 0
}
used = 2
gpu_target = 0
for i in range(num_trans_layers):
if used >= per_gpu_layers:
gpu_target += 1
used = 0
assert gpu_target < num_gpus
device_map[f'transformer.encoder.layers.{i}'] = gpu_target
used += 1
return device_map
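# Illustration (not part of the original file): with num_gpus=2, per_gpu_layers is 15, so
# auto_configure_device_map(2) keeps the embedding, final layernorm, output layer, rotary
# embeddings and lm_head on GPU 0, assigns transformer.encoder.layers.0-12 to GPU 0, and
# transformer.encoder.layers.13-27 to GPU 1, e.g.:
# device_map = auto_configure_device_map(2)
# assert device_map['transformer.encoder.layers.12'] == 0
# assert device_map['transformer.encoder.layers.13'] == 1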
def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
if num_gpus < 2 and device_map is None:
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
else:
from accelerate import dispatch_model
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
if device_map is None:
device_map = auto_configure_device_map(num_gpus)
model = dispatch_model(model, device_map=device_map)
return model
\ No newline at end of file
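The web demos below reference this helper (commented out) as the multi-GPU alternative; a minimal usage sketch, with the checkpoint path and GPU count as placeholders to adapt to your setup:

from utils import load_model_on_gpus
# dispatches ChatGLM3-6B across two GPUs using the device map built above
model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)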
This diff is collapsed.
import os
from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html
from utils import load_model_on_gpus
import torch
MODEL_PATH = os.environ.get('MODEL_PATH', '../../chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
if 'cuda' in DEVICE: # AMD and NVIDIA GPUs can use half precision
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()
else: # CPU, Intel GPUs and other devices are loaded in full float32 precision here
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()
# Multi-GPU support: use the following two lines instead of the line above, and set num_gpus to the number of GPUs you actually have
# from utils import load_model_on_gpus
# model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
"""Override Chatbot.postprocess"""
def postprocess(self, y):
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (
None if message is None else mdtex2html.convert((message)),
None if response is None else mdtex2html.convert(response),
)
return y
gr.Chatbot.postprocess = postprocess
def parse_text(text):
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text
def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
chatbot.append((parse_text(input), ""))
for response, history, past_key_values in model.stream_chat(tokenizer, input, history,
past_key_values=past_key_values,
return_past_key_values=True,
max_length=max_length, top_p=top_p,
temperature=temperature):
chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot, history, past_key_values
def reset_user_input():
return gr.update(value='')
def reset_state():
return [], [], None
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">ChatGLM3-6B</h1>""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
history = gr.State([])
past_key_values = gr.State(None)
submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
[chatbot, history, past_key_values], show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
demo.queue().launch(share=False, server_name="127.0.0.1", server_port=8501, inbrowser=True)
"""
This script is a simple web demo based on Streamlit, showcasing the use of the ChatGLM3-6B model. For a more comprehensive web demo,
it is recommended to use 'composite_demo'.
Usage:
- Run the script using Streamlit: `streamlit run web_demo_streamlit.py`
- Adjust the model parameters from the sidebar.
- Enter questions in the chat input box and interact with the ChatGLM3-6B model.
Note: Ensure 'streamlit' and 'transformers' libraries are installed and the required model checkpoints are available.
"""
import os
import streamlit as st
import torch
......@@ -17,34 +5,41 @@ from transformers import AutoModel, AutoTokenizer
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Set the page title, icon and layout
st.set_page_config(
page_title="ChatGLM3-6B Streamlit Simple Demo",
page_title="ChatGLM3-6B 演示",
page_icon=":robot:",
layout="wide"
)
@st.cache_resource
def get_model():
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
if 'cuda' in DEVICE: # AMD and NVIDIA GPUs can use half precision
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()
else: # CPU, Intel GPUs and other devices are loaded in full float32 precision here
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()
# Multi-GPU support: use the following two lines instead of the line above, and set num_gpus to the number of GPUs you actually have
# from utils import load_model_on_gpus
# model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
return tokenizer, model
# Load the ChatGLM3 model and tokenizer
tokenizer, model = get_model()
# Initialize the chat history and past key values
if "history" not in st.session_state:
st.session_state.history = []
if "past_key_values" not in st.session_state:
st.session_state.past_key_values = None
# Configure max_length, top_p and temperature
max_length = st.sidebar.slider("max_length", 0, 32768, 8192, step=1)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.6, step=0.01)
# Button to clear the conversation history
buttonClean = st.sidebar.button("清理会话历史", key="clean")
if buttonClean:
st.session_state.history = []
......@@ -53,6 +48,7 @@ if buttonClean:
torch.cuda.empty_cache()
st.rerun()
# Render the chat history
for i, message in enumerate(st.session_state.history):
if message["role"] == "user":
with st.chat_message(name="user", avatar="user"):
......@@ -61,26 +57,33 @@ for i, message in enumerate(st.session_state.history):
with st.chat_message(name="assistant", avatar="assistant"):
st.markdown(message["content"])
# Input and output placeholders
with st.chat_message(name="user", avatar="user"):
input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
message_placeholder = st.empty()
# Get the user input
prompt_text = st.chat_input("请输入您的问题")
# If the user entered something, generate a reply
if prompt_text:
input_placeholder.markdown(prompt_text)
history = st.session_state.history
past_key_values = st.session_state.past_key_values
for response, history, past_key_values in model.stream_chat(
tokenizer,
prompt_text,
history,
past_key_values=past_key_values,
max_length=max_length,
top_p=top_p,
temperature=temperature,
return_past_key_values=True,
tokenizer,
prompt_text,
history,
past_key_values=past_key_values,
max_length=max_length,
top_p=top_p,
temperature=temperature,
return_past_key_values=True,
):
message_placeholder.markdown(response)
# Update the chat history and past key values
st.session_state.history = history
st.session_state.past_key_values = past_key_values
"""
This script creates an interactive web demo for the ChatGLM3-6B model using Gradio,
a Python library for building quick and easy UI components for machine learning models.
It's designed to showcase the capabilities of the ChatGLM3-6B model in a user-friendly interface,
allowing users to interact with the model through a chat-like interface.
Usage:
- Run the script to start the Gradio web server.
- Interact with the model by typing questions and receiving responses.
Requirements:
- Gradio 4.13.0 or later must be installed (3.x is no longer supported).
Note: The script includes a modification to the Chatbot's postprocess method to handle markdown to HTML conversion,
ensuring that the chat interface displays formatted text correctly.
"""
import os
import gradio as gr
import torch
from threading import Thread
from typing import Union, Annotated
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer
)
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code
)
return model, tokenizer
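# Usage sketch (paths are hypothetical): a directory containing adapter_config.json is treated
# as a PEFT adapter checkpoint and the tokenizer is taken from its base model path; anything
# else is loaded as a plain model directory or hub id.
# model, tokenizer = load_model_and_tokenizer('output/lora-checkpoint')   # fine-tuned adapter dir
# model, tokenizer = load_model_and_tokenizer('THUDM/chatglm3-6b')        # base model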
model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
stop_ids = [0, 2]
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
def parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text
def predict(history, max_length, top_p, temperature):
stop = StopOnTokens()
messages = []
for idx, (user_msg, model_msg) in enumerate(history):
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
break
if user_msg:
messages.append({"role": "user", "content": user_msg})
if model_msg:
messages.append({"role": "assistant", "content": model_msg})
print("\n\n====conversation====\n", messages)
model_inputs = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt").to(next(model.parameters()).device)
streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
"input_ids": model_inputs,
"streamer": streamer,
"max_new_tokens": max_length,
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
"stopping_criteria": StoppingCriteriaList([stop]),
"repetition_penalty": 1.2,
}
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
for new_token in streamer:
if new_token != '':
history[-1][1] += new_token
yield history
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
def user(query, history):
return "", history + [[parse_text(query), ""]]
submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
predict, [chatbot, max_length, top_p, temperature], chatbot
)
emptyBtn.click(lambda: None, None, chatbot, queue=False)
demo.queue()
demo.launch(server_name="127.0.0.1", server_port=7870, inbrowser=True, share=False)
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
from __future__ import annotations
import os
import streamlit as st
import torch
from collections.abc import Iterable
import os
from typing import Any, Protocol
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from conversation import Conversation
......@@ -17,13 +15,30 @@ TOOL_PROMPT = 'Answer the following questions as best as you can. You have acces
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
PT_PATH = os.environ.get('PT_PATH', None)
PRE_SEQ_LEN = int(os.environ.get("PRE_SEQ_LEN", 128))
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# For Mac computers with Apple silicon (e.g. M1),
# use a PyTorch build compiled with Metal (MPS) support
# DEVICE = 'mps'
# For AMD GPUs such as the MI100 (not officially stable yet),
# use a PyTorch build compiled with ROCm
# DEVICE = 'cuda'
# For Intel GPUs such as the A770 (not officially stable yet),
# use a PyTorch build compiled with oneDNN and install intel-extension-for-pytorch
# import intel_extension_for_pytorch as ipex
# DEVICE = 'xpu'
# For Moore Threads GPUs such as the MTT S80 (not officially stable yet),
# use a PyTorch build compiled with MUSA
# DEVICE = 'musa'
@st.cache_resource
def get_client() -> Client:
client = HFClient(MODEL_PATH, TOKENIZER_PATH, PT_PATH)
client = HFClient(MODEL_PATH, TOKENIZER_PATH, PT_PATH, DEVICE)
return client
......@@ -37,20 +52,13 @@ class Client(Protocol):
...
def stream_chat(
self, tokenizer, query: str,
history: list[tuple[str, str]] = None,
role: str = "user",
past_key_values=None,
max_new_tokens: int = 256,
do_sample=True, top_p=0.8,
temperature=0.8,
repetition_penalty=1.0,
length_penalty=1.0, num_beams=1,
logits_processor=None,
return_past_key_values=False,
**kwargs
):
def stream_chat(self, tokenizer, query: str, history: list[tuple[str, str]] = None, role: str = "user",
past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
repetition_penalty=1.0, length_penalty=1.0, num_beams=1,
logits_processor=None, return_past_key_values=False, **kwargs):
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
......@@ -60,16 +68,12 @@ def stream_chat(
if history is None:
history = []
print("\n== Input ==\n", query)
print("\n==History==\n", history)
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|observation|>")]
gen_kwargs = {"max_new_tokens": max_new_tokens,
gen_kwargs = {"max_length": max_length,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
......@@ -80,6 +84,7 @@ def stream_chat(
**kwargs
}
print(gen_kwargs)
if past_key_values is None:
inputs = tokenizer.build_chat_input(query, history=history, role=role)
else:
......@@ -94,10 +99,13 @@ def stream_chat(
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
inputs['attention_mask'] = attention_mask
history.append({"role": role, "content": query})
print("input_shape>", inputs['input_ids'].shape)
input_sequence_length = inputs['input_ids'].shape[1]
if input_sequence_length + max_new_tokens >= self.config.seq_length:
yield "Current input sequence length {} plus max_new_tokens {} is too long. The maximum model sequence length is {}. You may adjust the generation parameter to enable longer chat history.".format(
input_sequence_length, max_new_tokens, self.config.seq_length
if max_length < input_sequence_length <= self.config.seq_length:
yield "Current input sequence length {} exceeds sequence length set in generation parameters {}. The maximum model sequence length is {}. You may adjust the generation parameter to enable longer chat history.".format(
input_sequence_length, max_length, self.config.seq_length
), history
return
......@@ -123,23 +131,13 @@ def stream_chat(
class HFClient(Client):
def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str = None):
def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str | None = None, DEVICE = 'cpu'):
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
if pt_checkpoint is not None and os.path.exists(pt_checkpoint):
config = AutoConfig.from_pretrained(
model_path,
trust_remote_code=True,
pre_seq_len=PRE_SEQ_LEN
)
self.model = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
config=config,
device_map="auto").eval()
# add .quantize(bits=4, device="cuda").cuda() before .eval() and remove device_map="auto" to use int4 model
# must use cuda to load int4 model
if pt_checkpoint is not None:
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, config=config)
prefix_state_dict = torch.load(os.path.join(pt_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
......@@ -148,17 +146,17 @@ class HFClient(Client):
print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
self.model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
# add .quantize(bits=4, device="cuda").cuda() before .eval() and remove device_map="auto" to use int4 model
# must use cuda to load int4 model
def generate_stream(
self,
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
**parameters: Any
) -> Iterable[TextGenerationStreamResponse]:
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
self.model = self.model.to(DEVICE).eval() if 'cuda' in DEVICE else self.model.float().to(DEVICE).eval()
def generate_stream(self,
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
**parameters: Any
) -> Iterable[TextGenerationStreamResponse]:
chat_history = [{
'role': 'system',
'content': system if not tools else TOOL_PROMPT,
......@@ -175,15 +173,16 @@ class HFClient(Client):
query = history[-1].content
role = str(history[-1].role).removeprefix('<|').removesuffix('|>')
text = ''
for new_text, _ in stream_chat(
self.model,
self.tokenizer,
query,
chat_history,
role,
**parameters,
):
for new_text, _ in stream_chat(self.model,
self.tokenizer,
query,
chat_history,
role,
**parameters,
):
word = new_text.removeprefix(text)
word_stripped = word.strip()
text = new_text
......
......@@ -70,7 +70,7 @@ class Conversation:
text = postprocess_text(self.content)
match self.role.value:
case Role.TOOL.value:
text = f'Calling tool `{self.tool}`:\n\n{text}'
text = f'Calling tool `{self.tool}`:\n{text}'
case Role.INTERPRETER.value:
text = f'{text}'
case Role.OBSERVATION.value:
......
......@@ -4,6 +4,8 @@ from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
MAX_LENGTH = 8192
client = get_client()
......@@ -17,43 +19,31 @@ def append_conversation(
conversation.show(placeholder)
def main(
prompt_text: str,
system_prompt: str,
top_p: float = 0.8,
temperature: float = 0.95,
repetition_penalty: float = 1.0,
max_new_tokens: int = 1024,
retry: bool = False
):
def main(top_p: float, temperature: float, system_prompt: str, prompt_text: str, repetition_penalty: float):
placeholder = st.empty()
with placeholder.container():
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
if prompt_text == "" and retry == False:
print("\n== Clean ==\n")
st.session_state.chat_history = []
return
history: list[Conversation] = st.session_state.chat_history
for conversation in history:
conversation.show()
if retry:
print("\n== Retry ==\n")
last_user_conversation_idx = None
for idx, conversation in enumerate(history):
if conversation.role == Role.USER:
last_user_conversation_idx = idx
if last_user_conversation_idx is not None:
prompt_text = history[last_user_conversation_idx].content
del history[last_user_conversation_idx:]
history: list[Conversation] = st.session_state.chat_history
for conversation in history:
conversation.show()
if prompt_text:
prompt_text = prompt_text.strip()
append_conversation(Conversation(Role.USER, prompt_text), history)
input_text = preprocess_text(
system_prompt,
tools=None,
history=history,
)
print("=== Input:")
print(input_text)
print("=== History:")
print(history)
placeholder = st.empty()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
......@@ -64,7 +54,7 @@ def main(
tools=None,
history=history,
do_sample=True,
max_new_tokens=max_new_tokens,
max_length=MAX_LENGTH,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(Role.USER)],
......@@ -72,7 +62,9 @@ def main(
):
token = response.token
if response.token.special:
print("\n==Output:==\n", output_text)
print("=== Output:")
print(output_text)
match token.text.strip():
case '<|user|>':
break
......@@ -85,4 +77,4 @@ def main(
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
\ No newline at end of file
), history, markdown_placeholder)
......@@ -18,8 +18,10 @@ IPYKERNEL = os.environ.get('IPYKERNEL', 'chatglm3-demo')
SYSTEM_PROMPT = '你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。'
client = get_client()
MAX_LENGTH = 8192
TRUNCATE_LENGTH = 1024
client = get_client()
class CodeKernel(object):
def __init__(self,
......@@ -38,14 +40,14 @@ class CodeKernel(object):
self.ipython_path = ipython_path
self.init_file_path = init_file_path
self.verbose = verbose
if python_path is None and ipython_path is None:
env = None
else:
env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}
# Initialize the backend kernel
self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
connection_file=self.kernel_config_path,
exec_files=[self.init_file_path],
env=env)
......@@ -82,7 +84,7 @@ class CodeKernel(object):
break
except queue.Empty:
break
return shell_msg, msg_out
except Exception as e:
print(e)
......@@ -151,18 +153,15 @@ class CodeKernel(object):
def is_alive(self):
return self.kernel.is_alive()
def b64_2_img(data):
buff = BytesIO(base64.b64decode(data))
return Image.open(buff)
def clean_ansi_codes(input_string):
ansi_escape = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
return ansi_escape.sub('', input_string)
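# Example (illustrative): clean_ansi_codes('\x1b[31mTraceback\x1b[0m') returns 'Traceback',
# which keeps kernel error output readable when it is echoed back into the chat.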
def execute(code, kernel: CodeKernel) -> tuple[str, str | Image.Image]:
res = ""
res_type = None
......@@ -172,12 +171,12 @@ def execute(code, kernel: CodeKernel) -> tuple[str, str | Image.Image]:
code = code.replace("<|user|>", "")
code = code.replace("<|system|>", "")
msg, output = kernel.execute(code)
if msg['metadata']['status'] == "timeout":
return res_type, 'Timed out'
elif msg['metadata']['status'] == 'error':
return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))
if 'text' in output:
res_type = "text"
res = output['text']
......@@ -195,68 +194,52 @@ def execute(code, kernel: CodeKernel) -> tuple[str, str | Image.Image]:
return res_type, b64_2_img(res)
elif res_type == "text" or res_type == "traceback":
res = res
return res_type, res
@st.cache_resource
def get_kernel():
kernel = CodeKernel()
return kernel
def extract_code(text: str) -> str:
pattern = r'```([^\n]*)\n(.*?)```'
matches = re.findall(pattern, text, re.DOTALL)
return matches[-1][1]
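# Example (illustrative): for a reply ending in a fenced block such as
#   "Here is the code:\n```python\nprint('hi')\n```"
# extract_code returns "print('hi')\n", i.e. the body of the last fenced block in the text.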
# Append a conversation to the history and show it in a new markdown block
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None = None,
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None=None,
) -> None:
history.append(conversation)
conversation.show(placeholder)
def main(
prompt_text: str,
top_p: float = 0.2,
temperature: float = 0.1,
repetition_penalty: float = 1.1,
max_new_tokens: int = 1024,
truncate_length: int = 1024,
retry: bool = False
):
def main(top_p: float, temperature: float, prompt_text: str, repetition_penalty: float):
if 'ci_history' not in st.session_state:
st.session_state.ci_history = []
history: list[Conversation] = st.session_state.ci_history
if prompt_text == "" and retry == False:
print("\n== Clean ==\n")
st.session_state.chat_history = []
return
history: list[Conversation] = st.session_state.chat_history
for conversation in history:
conversation.show()
if retry:
print("\n== Retry ==\n")
last_user_conversation_idx = None
for idx, conversation in enumerate(history):
if conversation.role == Role.USER:
last_user_conversation_idx = idx
if last_user_conversation_idx is not None:
prompt_text = history[last_user_conversation_idx].content
del history[last_user_conversation_idx:]
if prompt_text:
prompt_text = prompt_text.strip()
role = Role.USER
append_conversation(Conversation(role, prompt_text), history)
input_text = preprocess_text(
SYSTEM_PROMPT,
None,
history,
)
print("=== Input:")
print(input_text)
print("=== History:")
print(history)
placeholder = st.container()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
......@@ -264,19 +247,21 @@ def main(
for _ in range(5):
output_text = ''
for response in client.generate_stream(
system=SYSTEM_PROMPT,
tools=None,
history=history,
do_sample=True,
max_new_token=max_new_tokens,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
repetition_penalty=repetition_penalty,
system=SYSTEM_PROMPT,
tools=None,
history=history,
do_sample=True,
max_length=MAX_LENGTH,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
repetition_penalty=repetition_penalty,
):
token = response.token
if response.token.special:
print("\n==Output:==\n", output_text)
print("=== Output:")
print(output_text)
match token.text.strip():
case '<|user|>':
append_conversation(Conversation(
......@@ -296,6 +281,7 @@ def main(
continue
case '<|observation|>':
code = extract_code(output_text)
print("Code:", code)
display_text = output_text.split('interpreter')[-1].strip()
append_conversation(Conversation(
......@@ -305,7 +291,7 @@ def main(
message_placeholder = placeholder.chat_message(name="observation", avatar="user")
markdown_placeholder = message_placeholder.empty()
output_text = ''
with markdown_placeholder:
with st.spinner('Executing code...'):
try:
......@@ -314,9 +300,9 @@ def main(
st.error(f'Error when executing code: {e}')
return
print("Received:", res_type, res)
if truncate_length:
if res_type == 'text' and len(res) > truncate_length:
res = res[:truncate_length] + ' [TRUNCATED]'
if res_type == 'text' and len(res) > TRUNCATE_LENGTH:
res = res[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
append_conversation(Conversation(
Role.OBSERVATION,
......@@ -340,5 +326,4 @@ def main(
postprocess_text(output_text),
), history, markdown_placeholder)
return
else:
st.session_state.chat_history = []
......@@ -9,6 +9,9 @@ from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
from tool_registry import dispatch_tool, get_tools
MAX_LENGTH = 8192
TRUNCATE_LENGTH = 1024
EXAMPLE_TOOL = {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
......@@ -27,51 +30,37 @@ EXAMPLE_TOOL = {
client = get_client()
def tool_call(*args, **kwargs) -> dict:
print("=== Tool call===")
print("=== Tool call:")
print(args)
print(kwargs)
st.session_state.calling_tool = True
return kwargs
def yaml_to_dict(tools: str) -> list[dict] | None:
try:
return yaml.safe_load(tools)
except YAMLError:
return None
def extract_code(text: str) -> str:
pattern = r'```([^\n]*)\n(.*?)```'
matches = re.findall(pattern, text, re.DOTALL)
print(matches)
return matches[-1][1]
# Append a conversation to the history and show it in a new markdown block
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None = None,
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None=None,
) -> None:
history.append(conversation)
conversation.show(placeholder)
def main(
prompt_text: str,
top_p: float = 0.2,
temperature: float = 0.1,
repetition_penalty: float = 1.1,
max_new_tokens: int = 1024,
truncate_length: int = 1024,
retry: bool = False
):
def main(top_p: float, temperature: float, prompt_text: str, repetition_penalty: float):
manual_mode = st.toggle('Manual mode',
help='Define your tools in YAML format. You need to supply tool call results manually.'
)
help='Define your tools in YAML format. You need to supply tool call results manually.'
)
if manual_mode:
with st.expander('Tools'):
......@@ -92,34 +81,27 @@ def main(
if 'calling_tool' not in st.session_state:
st.session_state.calling_tool = False
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
if prompt_text == "" and retry == False:
print("\n== Clean ==\n")
st.session_state.chat_history = []
return
history: list[Conversation] = st.session_state.tool_history
history: list[Conversation] = st.session_state.chat_history
for conversation in history:
conversation.show()
if retry:
print("\n== Retry ==\n")
last_user_conversation_idx = None
for idx, conversation in enumerate(history):
if conversation.role == Role.USER:
last_user_conversation_idx = idx
if last_user_conversation_idx is not None:
prompt_text = history[last_user_conversation_idx].content
del history[last_user_conversation_idx:]
if prompt_text:
prompt_text = prompt_text.strip()
role = st.session_state.calling_tool and Role.OBSERVATION or Role.USER
append_conversation(Conversation(role, prompt_text), history)
st.session_state.calling_tool = False
input_text = preprocess_text(
None,
tools,
history,
)
print("=== Input:")
print(input_text)
print("=== History:")
print(history)
placeholder = st.container()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
......@@ -127,19 +109,21 @@ def main(
for _ in range(5):
output_text = ''
for response in client.generate_stream(
system=None,
tools=tools,
history=history,
do_sample=True,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
repetition_penalty=repetition_penalty,
system=None,
tools=tools,
history=history,
do_sample=True,
max_length=MAX_LENGTH,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
repetition_penalty=repetition_penalty,
):
token = response.token
if response.token.special:
print("\n==Output:==\n", output_text)
print("=== Output:")
print(output_text)
match token.text.strip():
case '<|user|>':
append_conversation(Conversation(
......@@ -160,7 +144,7 @@ def main(
case '<|observation|>':
tool, *call_args_text = output_text.strip().split('\n')
call_args_text = '\n'.join(call_args_text)
append_conversation(Conversation(
Role.TOOL,
postprocess_text(output_text),
......@@ -168,16 +152,16 @@ def main(
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="observation", avatar="user")
markdown_placeholder = message_placeholder.empty()
try:
code = extract_code(call_args_text)
args = eval(code, {'tool_call': tool_call}, {})
except:
st.error('Failed to parse tool call')
return
output_text = ''
if manual_mode:
st.info('Please provide tool call results below:')
return
......@@ -186,8 +170,8 @@ def main(
with st.spinner(f'Calling tool {tool}...'):
observation = dispatch_tool(tool, args)
if len(observation) > truncate_length:
observation = observation[:truncate_length] + ' [TRUNCATED]'
if len(observation) > TRUNCATE_LENGTH:
observation = observation[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
append_conversation(Conversation(
Role.OBSERVATION, observation
), history, markdown_placeholder)
......
from enum import Enum
import streamlit as st
st.set_page_config(
page_title="ChatGLM3 Demo",
page_icon=":robot:",
......@@ -6,9 +8,7 @@ st.set_page_config(
initial_sidebar_state='expanded',
)
import demo_chat, demo_ci, demo_tool
from enum import Enum
DEFAULT_SYSTEM_PROMPT = '''
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
......@@ -18,10 +18,7 @@ You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's
st.title("ChatGLM3 Demo")
# Add your custom text here, with smaller font size
st.markdown(
"<sub>智谱AI 公开在线技术文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof </sub> \n\n <sub> 更多 ChatGLM3-6B 的使用方法请参考文档。</sub>",
unsafe_allow_html=True)
st.markdown("<sub>智谱AI 公开在线技术文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof </sub> \n\n <sub> 更多 ChatGLM3-6B 的使用方法请参考文档。</sub>", unsafe_allow_html=True)
class Mode(str, Enum):
CHAT, TOOL, CI = '💬 Chat', '🛠️ Tool', '🧑‍💻 Code Interpreter'
......@@ -35,17 +32,8 @@ with st.sidebar:
'temperature', 0.0, 1.5, 0.95, step=0.01
)
repetition_penalty = st.slider(
'repetition_penalty', 0.0, 2.0, 1.1, step=0.01
'repetition_penalty', 0.0, 2.0, 1.2, step=0.01
)
max_new_token = st.slider(
'Output length', 5, 32000, 256, step=1
)
cols = st.columns(2)
export_btn = cols[0]
clear_history = cols[1].button("Clear History", use_container_width=True)
retry = export_btn.button("Retry", use_container_width=True)
system_prompt = st.text_area(
label="System Prompt (Only for chat mode)",
height=300,
......@@ -64,37 +52,12 @@ tab = st.radio(
label_visibility='hidden',
)
if clear_history or retry:
prompt_text = ""
match tab:
case Mode.CHAT:
demo_chat.main(
retry=retry,
top_p=top_p,
temperature=temperature,
prompt_text=prompt_text,
system_prompt=system_prompt,
repetition_penalty=repetition_penalty,
max_new_tokens=max_new_token
)
demo_chat.main(top_p, temperature, system_prompt, prompt_text, repetition_penalty)
case Mode.TOOL:
demo_tool.main(
retry=retry,
top_p=top_p,
temperature=temperature,
prompt_text=prompt_text,
repetition_penalty=repetition_penalty,
max_new_tokens=max_new_token,
truncate_length=1024)
demo_tool.main(top_p, temperature, prompt_text, repetition_penalty)
case Mode.CI:
demo_ci.main(
retry=retry,
top_p=top_p,
temperature=temperature,
prompt_text=prompt_text,
repetition_penalty=repetition_penalty,
max_new_tokens=max_new_token,
truncate_length=1024)
demo_ci.main(top_p, temperature, prompt_text, repetition_penalty)
case _:
st.error(f'Unexpected tab: {tab}')
huggingface_hub>=0.19.4
pillow>=10.1.0
pyyaml>=6.0.1
requests>=2.31.0
ipykernel>=6.26.0
ipython>=8.18.1
jupyter_client>=8.6.0
huggingface_hub
ipykernel
ipython
jupyter_client
pillow
sentencepiece
streamlit
tokenizers
pyyaml
requests
\ No newline at end of file
"""
This module implements tool registration. By registering a tool, the model can call it.
It provides extended functionality to the model, enabling it to call and interact with a variety of utilities
through defined interfaces.
"""
import copy
import inspect
from pprint import pformat
import traceback
from types import GenericAlias
from typing import get_origin, Annotated
import subprocess
_TOOL_HOOKS = {}
_TOOL_DESCRIPTIONS = {}
def register_tool(func: callable):
tool_name = func.__name__
tool_description = inspect.getdoc(func).strip()
......@@ -27,7 +19,7 @@ def register_tool(func: callable):
raise TypeError(f"Parameter `{name}` missing type annotation")
if get_origin(annotation) != Annotated:
raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
typ, (description, required) = annotation.__origin__, annotation.__metadata__
typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
if not isinstance(description, str):
......@@ -46,34 +38,32 @@ def register_tool(func: callable):
"description": tool_description,
"params": tool_params
}
print("[registered tool] " + pformat(tool_def))
_TOOL_HOOKS[tool_name] = func
_TOOL_DESCRIPTIONS[tool_name] = tool_def
return func
def dispatch_tool(tool_name: str, tool_params: dict) -> str:
if tool_name not in _TOOL_HOOKS:
return f"Tool `{tool_name}` not found. Please use a provided tool."
tool_call = _TOOL_HOOKS[tool_name]
try:
ret = tool_call(**tool_params)
ret = tool_call(**tool_params)
except:
ret = traceback.format_exc()
return str(ret)
def get_tools() -> dict:
return copy.deepcopy(_TOOL_DESCRIPTIONS)
# Tool Definitions
@register_tool
def random_number_generator(
seed: Annotated[int, 'The random seed used by the generator', True],
range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
seed: Annotated[int, 'The random seed used by the generator', True],
range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
) -> int:
"""
Generates a random number x, s.t. range[0] <= x < range[1]
......@@ -88,10 +78,9 @@ def random_number_generator(
import random
return random.Random(seed).randint(*range)
@register_tool
def get_weather(
city_name: Annotated[str, 'The name of the city to be queried', True],
city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
"""
Get the current weather for `city_name`
......@@ -101,7 +90,7 @@ def get_weather(
raise TypeError("City name must be a string")
key_selection = {
"current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
"current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
}
import requests
try:
......@@ -111,28 +100,10 @@ def get_weather(
ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
except:
import traceback
ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
return str(ret)
@register_tool
def get_shell(
query: Annotated[str, 'The command should run in Linux shell', True],
) -> str:
"""
Use shell to run command
"""
if not isinstance(query, str):
raise TypeError("Command must be a string")
try:
result = subprocess.run(query, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
text=True)
return result.stdout
except subprocess.CalledProcessError as e:
return e.stderr
if __name__ == "__main__":
# print(dispatch_tool("get_shell", {"query": "pwd"}))
print(get_tools())
\ No newline at end of file
print(dispatch_tool("get_weather", {"city_name": "beijing"}))
print(get_tools())
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 使用精确的提示词完成符合格式输出的文本分类任务\n",
"\n",
"在本章节中,我们将带领开发者体验在 `文本分类` 任务中,在不进行任何模型微调的情况下,使用具有微小差异的提示词,并使用 `ChatGLM3-6B` 模型来完成符合格式输出的文本分类任务。\n",
"\n",
"我们使用 [新闻标题分类](https://github.com/fateleak/toutiao-multilevel-text-classfication-dataset) 任务来体验模型的表现。这是一个经典的文本分类任务,我们将使用 `新闻信息` 作为输入,模型需要预测出这个标题属于哪个类别。\n",
"\n",
"由于 `ChatGLM3-6B` 强大的能力,我们可以直接使用 `新闻标题` 作为输入,而不需要额外的信息,也不需要进行任何模型微调,就能完成这个任务。我们的目标是,让模型能够成功输出原始数据集中的15种类别中的一种作为分类结果,而且不能输入任何冗余的对话。\n",
"\n",
"在本章节中,用户将直观的对比两种不同细粒度的提示词下对模型分类造成的影响。\n",
"\n",
"## 硬件要求\n",
"本实践手册需要使用 FP16 精度的模型进行推理,因此,我们推荐使用至少 16GB 显存的 英伟达 GPU 来完成本实践手册。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"引入必要的库"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 19,
"outputs": [],
"source": [
"from transformers import AutoModel, AutoTokenizer\n",
"import torch"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:03.002351Z",
"end_time": "2023-11-23T19:23:03.016701Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"设置好对应的参数,以保证模型推理的公平性。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"max_new_tokens = 1024\n",
"temperature = 0.1\n",
"top_p = 0.9\n",
"device = \"cuda\"\n",
"model_path_chat = \"/Models/chatglm3-6b\""
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:03.004850Z",
"end_time": "2023-11-23T19:23:03.047154Z"
}
}
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [
{
"data": {
"text/plain": "Loading checkpoint shards: 0%| | 0/7 [00:00<?, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "d6468f7889994e638ac754fde95b6e58"
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenizer_path_chat = model_path_chat\n",
"tokenizer = AutoTokenizer.from_pretrained(tokenizer_path_chat, trust_remote_code=True, encode_special_tokens=True)\n",
"model = AutoModel.from_pretrained(model_path_chat, load_in_8bit=False, trust_remote_code=True).to(device)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:03.047154Z",
"end_time": "2023-11-23T19:23:13.086148Z"
}
}
},
{
"cell_type": "code",
"execution_count": 22,
"outputs": [],
"source": [
"def answer(prompt):\n",
" inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
" response = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=max_new_tokens, history=None,\n",
" temperature=temperature,\n",
" top_p=top_p, do_sample=True)\n",
" response = response[0, inputs[\"input_ids\"].shape[-1]:]\n",
" answer = tokenizer.decode(response, skip_special_tokens=True)\n",
" return answer"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:13.086148Z",
"end_time": "2023-11-23T19:23:13.086148Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"在本样例中,模型应该输出的标准答案应该为:\n",
"'news_sports'"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 23,
"outputs": [],
"source": [
"PROMPT1 = \"\"\"\n",
"<|system|>\n",
"你是一个专业的新闻专家,请根据我提供的新闻信息,包括新闻标题,新闻关键词等信息,你需要在这些类中选择其中一个,它们分别是:\n",
"news_story\n",
"news_culture\n",
"news_sports\n",
"news_finance\n",
"news_house\n",
"news_car\n",
"news_edu\n",
"news_tech\n",
"news_military\n",
"news_travel\n",
"news_world\n",
"stock\n",
"news_agriculture\n",
"news_game\n",
"这是我的信息:\n",
"<|user|>\n",
"新闻标题: 女乒今天排兵布阵不合理,丁宁昨天刚打硬仗,今天应该打第一单打,你认为呢?\n",
"新闻关键词: 无\n",
"<|assistant|>\n",
"\"\"\""
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:13.086148Z",
"end_time": "2023-11-23T19:23:13.095864Z"
}
}
},
{
"cell_type": "code",
"execution_count": 24,
"outputs": [
{
"data": {
"text/plain": "' 新闻类型:体育新闻\\n 新闻关键词:女乒,排兵布阵,丁宁,硬仗,第一单打'"
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"answer(PROMPT1)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:13.095864Z",
"end_time": "2023-11-23T19:23:14.034154Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"很明显,使用这个提示词没有完成要求。我们更换一个提示词,经过优化后,看是否能达标。\n",
"\n",
"我们为提示词中设定了更多的限定词汇,并用Markdown语法规范化了输出的格式。修改后的提示词如下:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 25,
"outputs": [],
"source": [
"PROMPT2 = \"\"\"\n",
"<|system|>\n",
"请根据我提供的新闻信息,格式为\n",
"```\n",
"新闻标题: xxx\n",
"新闻关键词: xxx\n",
"```\n",
"你要对每一行新闻类别进行分类并告诉我结果,不要返回其他信息和多于的文字,这些类别是:\n",
"news_story\n",
"news_culture\n",
"news_sports\n",
"news_finance\n",
"news_house\n",
"news_car\n",
"news_edu\n",
"news_tech\n",
"news_military\n",
"news_travel\n",
"news_world\n",
"stock\n",
"news_agriculture\n",
"news_game\n",
"我将为你提供一些新闻标题和关键词,你需要根据这些信息,对这些新闻进行分类,每条一行,格式为\n",
"```\n",
"类别中的其中一个\n",
"```\n",
"不要返回其他内容,例如新闻标题,新闻关键词等等,只需要返回分类结果即可\n",
"<|user|>\n",
"新闻标题: 女乒今天排兵布阵不合理,丁宁昨天刚打硬仗,今天应该打第一单打,你认为呢?\n",
"新闻关键词: 无\n",
"<|assistant|>\n",
"\"\"\""
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:14.044595Z",
"end_time": "2023-11-23T19:23:14.044595Z"
}
}
},
{
"cell_type": "code",
"execution_count": 26,
"outputs": [
{
"data": {
"text/plain": "'news_sports'"
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"answer(PROMPT2)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:14.044595Z",
"end_time": "2023-11-23T19:23:14.217801Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"这一次,模型成功给出了理想的答案。我们结束实操训练,删除模型并释放显存。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 27,
"outputs": [],
"source": [
"del model\n",
"torch.cuda.empty_cache()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:14.217801Z",
"end_time": "2023-11-23T19:23:14.373137Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 总结\n",
"在本实践手册中,我们让开发者体验了使用不同提示词下,`ChatGLM3-6B` 模型在 `新闻标题分类` 任务中的表现。\n",
"1. 模型在两次提示中都能正确的分类,验证了模型底座的能力。\n",
"2. 在使用第二个提示词时,模型的输出符合格式要求,而且没有冗余的对话。符合我们的要求。\n",
"3. 第二个提示词使用了一定的表示限定的词汇和更准确的任务要求,包括规定了返回格式,返回的行书等的。\n",
"\n",
"因此,通过上述内容,我们可以发现:\n",
"对于有格式要求的任务,可以使用更具有格式化的提示词来完成任务。同时,使用更低的 `temperature` 和更高的 `top_p` 可以提高模型的输出质量。\n",
"\n",
"模型在训练中大量使用 Markdown格式。因此,用 ``` 符号来规范提示词将能提升模型输出格式化数据的准确率。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
This diff is collapsed.