Commit 9eb7f37f authored by Rayyyyy's avatar Rayyyyy
Browse files

First add

parents
"""
This is a client part of composite_demo.
We provide two clients, HFClient and VLLMClient, which are used to interact with the model.
The HFClient is used to interact with the transformers backend, and the VLLMClient is used to interact with the VLLM model.
"""
import json
from collections.abc import Generator
from copy import deepcopy
from enum import Enum, auto
from typing import Protocol
import streamlit as st
from conversation import Conversation, build_system_prompt
from tools.tool_registry import ALL_TOOLS
class ClientType(Enum):
HF = auto()
VLLM = auto()
class Client(Protocol):
def __init__(self, model_path: str): ...
def generate_stream(
self,
tools: list[dict],
history: list[Conversation],
**parameters,
) -> Generator[tuple[str | dict, list[dict]]]: ...
def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
chat_history = []
if len(tools) > 0:
chat_history.append(
{"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)}
)
for conversation in history:
role = str(conversation.role).removeprefix("<|").removesuffix("|>")
item = {
"role": role,
"content": conversation.content,
}
if conversation.metadata:
item["metadata"] = conversation.metadata
# Only append image for user
if role == "user" and conversation.image:
item["image"] = conversation.image
chat_history.append(item)
return chat_history
def process_response(output, history):
content = ""
history = deepcopy(history)
for response in output.split("<|assistant|>"):
if "\n" in response:
metadata, content = response.split("\n", maxsplit=1)
else:
metadata, content = "", response
if not metadata.strip():
content = content.strip()
history.append({"role": "assistant", "metadata": metadata, "content": content})
content = content.replace("[[训练时间]]", "2023年")
else:
history.append({"role": "assistant", "metadata": metadata, "content": content})
if history[0]["role"] == "system" and "tools" in history[0]:
parameters = json.loads(content)
content = {"name": metadata.strip(), "parameters": parameters}
else:
content = {"name": metadata.strip(), "content": content}
return content, history
# glm-4v-9b is not available in vLLM backend, use HFClient instead.
@st.cache_resource(max_entries=1, show_spinner="Loading model...")
def get_client(model_path, typ: ClientType) -> Client:
match typ:
case ClientType.HF:
from clients.hf import HFClient
return HFClient(model_path)
case ClientType.VLLM:
try:
from clients.vllm import VLLMClient
except ImportError as e:
e.msg += "; did you forget to install vLLM?"
raise
return VLLMClient(model_path)
raise NotImplementedError(f"Client type {typ} is not supported.")
"""
HuggingFace client.
"""
import threading
from collections.abc import Generator
from threading import Thread
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from client import Client, process_input, process_response
from conversation import Conversation
class HFClient(Client):
def __init__(self, model_path: str):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True,
)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="cuda",
).eval()
def generate_stream(
self,
tools: list[dict],
history: list[Conversation],
**parameters,
) -> Generator[tuple[str | dict, list[dict]]]:
chat_history = process_input(history, tools)
model_inputs = self.tokenizer.apply_chat_template(
chat_history,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True,
).to(self.model.device)
streamer = TextIteratorStreamer(
tokenizer=self.tokenizer,
timeout=5,
skip_prompt=True,
)
generate_kwargs = {
**model_inputs,
"streamer": streamer,
"eos_token_id": [151329, 151336, 151338],
"do_sample": True,
}
generate_kwargs.update(parameters)
t = Thread(target=self.model.generate, kwargs=generate_kwargs)
t.start()
total_text = ""
for token_text in streamer:
total_text += token_text
yield process_response(total_text, chat_history)
"""
vLLM client.
Please install [vLLM](https://github.com/vllm-project/vllm) according to its
installation guide before running this client.
"""
import time
from collections.abc import Generator
from transformers import AutoTokenizer
from vllm import SamplingParams, LLMEngine, EngineArgs
from client import Client, process_input, process_response
from conversation import Conversation
class VLLMClient(Client):
def __init__(self, model_path: str):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True
)
self.engine_args = EngineArgs(
model=model_path,
tensor_parallel_size=1,
dtype="bfloat16", # torch.bfloat16 is needed.
trust_remote_code=True,
gpu_memory_utilization=0.6,
enforce_eager=True,
worker_use_ray=False,
)
self.engine = LLMEngine.from_engine_args(self.engine_args)
def generate_stream(
self, tools: list[dict], history: list[Conversation], **parameters
) -> Generator[tuple[str | dict, list[dict]]]:
chat_history = process_input(history, tools)
model_inputs = self.tokenizer.apply_chat_template(
chat_history, add_generation_prompt=True, tokenize=False
)
parameters["max_tokens"] = parameters.pop("max_new_tokens")
params_dict = {
"n": 1,
"best_of": 1,
"top_p": 1,
"top_k": -1,
"use_beam_search": False,
"length_penalty": 1,
"early_stopping": False,
"stop_token_ids": [151329, 151336, 151338],
"ignore_eos": False,
"logprobs": None,
"prompt_logprobs": None,
}
params_dict.update(parameters)
sampling_params = SamplingParams(**params_dict)
self.engine.add_request(
request_id=str(time.time()), inputs=model_inputs, params=sampling_params
)
while self.engine.has_unfinished_requests():
request_outputs = self.engine.step()
for request_output in request_outputs:
yield process_response(request_output.outputs[0].text, chat_history)
import json
import re
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from PIL.Image import Image
from tools.browser import Quote, quotes
QUOTE_REGEX = re.compile(r"【(\d+)†(.+?)】")
SELFCOG_PROMPT = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
DATE_PROMPT = "当前日期: %Y-%m-%d"
TOOL_SYSTEM_PROMPTS = {
"python": "当你向 `python` 发送包含 Python 代码的消息时,该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出,或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中,`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用,这些在线内容的访问将不会成功。",
"simple_browser": "你可以使用 `simple_browser` 工具。该工具支持以下函数:\n`search(query: str, recency_days: int)`:使用搜索引擎进行查询并显示结果,可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`:获取一系列指定 id 的页面内容。每次调用时,须选择3-10个页面。选择多个角度的页面,同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的,你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`:打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤:1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL,也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息,也可以再次使用 `search` 进行搜索。",
"cogview": "如果用户的请求中包含了对图像的描述,你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述,规则:\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求,需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明,所在地为中国,持有中国立场并遵循中国社会主义价值观。",
}
FILE_TEMPLATE = "[File Name]\n{file_name}\n[File Content]\n{file_content}"
def build_system_prompt(
enabled_tools: list[str],
functions: list[dict],
):
value = SELFCOG_PROMPT
value += "\n\n" + datetime.now().strftime(DATE_PROMPT)
value += "\n\n# 可用工具"
contents = []
for tool in enabled_tools:
contents.append(f"\n\n## {tool}\n\n{TOOL_SYSTEM_PROMPTS[tool]}")
for function in functions:
content = f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
content += "\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
contents.append(content)
value += "".join(contents)
return value
def response_to_str(response: str | dict[str, str]) -> str:
"""
Convert response to string.
"""
if isinstance(response, dict):
return response.get("name", "") + response.get("content", "")
return response
class Role(Enum):
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()
OBSERVATION = auto()
def __str__(self):
match self:
case Role.SYSTEM:
return "<|system|>"
case Role.USER:
return "<|user|>"
case Role.ASSISTANT | Role.TOOL:
return "<|assistant|>"
case Role.OBSERVATION:
return "<|observation|>"
# Get the message block for the given role
def get_message(self):
# Compare by value here, because the enum object in the session state
# is not the same as the enum cases here, due to streamlit's rerunning
# behavior.
match self.value:
case Role.SYSTEM.value:
return
case Role.USER.value:
return st.chat_message(name="user", avatar="user")
case Role.ASSISTANT.value:
return st.chat_message(name="assistant", avatar="assistant")
case Role.TOOL.value:
return st.chat_message(name="tool", avatar="assistant")
case Role.OBSERVATION.value:
return st.chat_message(name="observation", avatar="assistant")
case _:
st.error(f"Unexpected role: {self}")
@dataclass
class Conversation:
role: Role
content: str | dict
# Processed content
saved_content: str | None = None
metadata: str | None = None
image: str | Image | None = None
def __str__(self) -> str:
metadata_str = self.metadata if self.metadata else ""
return f"{self.role}{metadata_str}\n{self.content}"
# Human readable format
def get_text(self) -> str:
text = self.saved_content or self.content
match self.role.value:
case Role.TOOL.value:
text = f"Calling tool `{self.metadata}`:\n\n```python\n{text}\n```"
case Role.OBSERVATION.value:
text = f"```python\n{text}\n```"
return text
# Display as a markdown block
def show(self, placeholder: DeltaGenerator | None = None) -> str:
if placeholder:
message = placeholder
else:
message = self.role.get_message()
if self.image:
message.image(self.image, width=512)
if self.role == Role.OBSERVATION:
metadata_str = f"from {self.metadata}" if self.metadata else ""
message = message.expander(f"Observation {metadata_str}")
text = self.get_text()
if self.role != Role.USER:
show_text = text
else:
splitted = text.split('files uploaded.\n')
if len(splitted) == 1:
show_text = text
else:
# Show expander for document content
doc = splitted[0]
show_text = splitted[-1]
expander = message.expander(f'File Content')
expander.markdown(doc)
message.markdown(show_text)
def postprocess_text(text: str, replace_quote: bool) -> str:
text = text.replace("\(", "$")
text = text.replace("\)", "$")
text = text.replace("\[", "$$")
text = text.replace("\]", "$$")
text = text.replace("<|assistant|>", "")
text = text.replace("<|observation|>", "")
text = text.replace("<|system|>", "")
text = text.replace("<|user|>", "")
text = text.replace("<|endoftext|>", "")
# Replace quotes
if replace_quote:
for match in QUOTE_REGEX.finditer(text):
quote_id = match.group(1)
quote = quotes.get(quote_id, Quote("未找到引用内容", ""))
text = text.replace(
match.group(0), f" (来源:[{quote.title}]({quote.url})) "
)
return text.strip()
"""
This demo show the All tools and Long Context chat Capabilities of GLM-4.
Please follow the Readme.md to run the demo.
"""
import os
import traceback
from enum import Enum
from io import BytesIO
from uuid import uuid4
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from PIL import Image
from client import Client, ClientType, get_client
from conversation import (
FILE_TEMPLATE,
Conversation,
Role,
postprocess_text,
response_to_str,
)
from tools.tool_registry import dispatch_tool, get_tools
from utils import extract_pdf, extract_docx, extract_pptx, extract_text
CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b")
USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
class Mode(str, Enum):
ALL_TOOLS = "🛠️ All Tools"
LONG_CTX = "📝 文档解读"
VLM = "🖼️ 多模态"
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None = None,
) -> None:
"""
Append a conversation piece into history, meanwhile show it in a new markdown block
"""
history.append(conversation)
conversation.show(placeholder)
st.set_page_config(
page_title="GLM-4 Demo",
page_icon=":robot:",
layout="centered",
initial_sidebar_state="expanded",
)
st.title("GLM-4 Demo")
st.markdown(
"<sub>智谱AI 公开在线技术文档: https://zhipu-ai.feishu.cn/wiki/RuMswanpkiRh3Ok4z5acOABBnjf </sub> \n\n <sub> 更多 GLM-4 开源模型的使用方法请参考文档。</sub>",
unsafe_allow_html=True,
)
with st.sidebar:
top_p = st.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
top_k = st.slider("top_k", 1, 20, 10, step=1, key="top_k")
temperature = st.slider("temperature", 0.0, 1.5, 0.95, step=0.01)
repetition_penalty = st.slider("repetition_penalty", 0.0, 2.0, 1.0, step=0.01)
max_new_tokens = st.slider("max_new_tokens", 1, 4096, 2048, step=1)
cols = st.columns(2)
export_btn = cols[0]
clear_history = cols[1].button("Clear", use_container_width=True)
retry = export_btn.button("Retry", use_container_width=True)
if clear_history:
page = st.session_state.page
client = st.session_state.client
st.session_state.clear()
st.session_state.page = page
st.session_state.client = client
st.session_state.files_uploaded = False
st.session_state.uploaded_texts = ""
st.session_state.uploaded_file_nums = 0
st.session_state.history = []
if "files_uploaded" not in st.session_state:
st.session_state.files_uploaded = False
if "session_id" not in st.session_state:
st.session_state.session_id = uuid4()
if "history" not in st.session_state:
st.session_state.history = []
first_round = len(st.session_state.history) == 0
def build_client(mode: Mode) -> Client:
match mode:
case Mode.ALL_TOOLS:
st.session_state.top_k = 10
typ = ClientType.VLLM if USE_VLLM else ClientType.HF
return get_client(CHAT_MODEL_PATH, typ)
case Mode.LONG_CTX:
st.session_state.top_k = 10
typ = ClientType.VLLM if USE_VLLM else ClientType.HF
return get_client(CHAT_MODEL_PATH, typ)
case Mode.VLM:
st.session_state.top_k = 1
# vLLM is not available for VLM mode
return get_client(VLM_MODEL_PATH, ClientType.HF)
# Callback function for page change
def page_changed() -> None:
global client
new_page: str = st.session_state.page
st.session_state.history.clear()
st.session_state.client = build_client(Mode(new_page))
page = st.radio(
"选择功能",
[mode.value for mode in Mode],
key="page",
horizontal=True,
index=None,
label_visibility="hidden",
on_change=page_changed,
)
HELP = """
### 🎉 欢迎使用 GLM-4!
请在上方选取一个功能。每次切换功能时,将会重新加载模型并清空对话历史。
文档解读模式与 VLM 模式仅支持在第一轮传入文档或图像。
""".strip()
if page is None:
st.markdown(HELP)
exit()
if page == Mode.LONG_CTX:
if first_round:
uploaded_files = st.file_uploader(
"上传文件",
type=["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"],
accept_multiple_files=True,
)
if uploaded_files and not st.session_state.files_uploaded:
uploaded_texts = []
for uploaded_file in uploaded_files:
file_name: str = uploaded_file.name
random_file_name = str(uuid4())
file_extension = os.path.splitext(file_name)[1]
file_path = os.path.join("/tmp", random_file_name + file_extension)
with open(file_path, "wb") as f:
f.write(uploaded_file.getbuffer())
if file_name.endswith(".pdf"):
content = extract_pdf(file_path)
elif file_name.endswith(".docx"):
content = extract_docx(file_path)
elif file_name.endswith(".pptx"):
content = extract_pptx(file_path)
else:
content = extract_text(file_path)
uploaded_texts.append(
FILE_TEMPLATE.format(file_name=file_name, file_content=content)
)
os.remove(file_path)
st.session_state.uploaded_texts = "\n\n".join(uploaded_texts)
st.session_state.uploaded_file_nums = len(uploaded_files)
else:
st.session_state.uploaded_texts = ""
st.session_state.uploaded_file_nums = 0
elif page == Mode.VLM:
if first_round:
uploaded_image = st.file_uploader(
"上传图片",
type=["png", "jpg", "jpeg", "bmp", "tiff", "webp"],
accept_multiple_files=False,
)
if uploaded_image:
data: bytes = uploaded_image.read()
image = Image.open(BytesIO(data)).convert("RGB")
st.session_state.uploaded_image = image
else:
st.session_state.uploaded_image = None
prompt_text = st.chat_input("Chat with GLM-4!", key="chat_input")
if prompt_text == "" and retry == False:
print("\n== Clean ==\n")
st.session_state.history = []
exit()
history: list[Conversation] = st.session_state.history
if retry:
print("\n== Retry ==\n")
last_user_conversation_idx = None
for idx, conversation in enumerate(history):
if conversation.role.value == Role.USER.value:
last_user_conversation_idx = idx
if last_user_conversation_idx is not None:
prompt_text = history[last_user_conversation_idx].content
print(f"New prompt: {prompt_text}, idx = {last_user_conversation_idx}")
del history[last_user_conversation_idx:]
for conversation in history:
conversation.show()
tools = get_tools() if page == Mode.ALL_TOOLS else []
client: Client = st.session_state.client
def main(prompt_text: str):
global client
assert client is not None
if prompt_text:
prompt_text = prompt_text.strip()
# Append uploaded files
uploaded_texts = st.session_state.get("uploaded_texts")
if page == Mode.LONG_CTX and uploaded_texts and first_round:
meta_msg = "{} files uploaded.\n".format(
st.session_state.uploaded_file_nums
)
prompt_text = uploaded_texts + "\n\n\n" + meta_msg + prompt_text
# Clear after first use
st.session_state.files_uploaded = True
st.session_state.uploaded_texts = ""
st.session_state.uploaded_file_nums = 0
image = st.session_state.get("uploaded_image")
if page == Mode.VLM and image and first_round:
st.session_state.uploaded_image = None
role = Role.USER
append_conversation(Conversation(role, prompt_text, image=image), history)
placeholder = st.container()
message_placeholder = placeholder.chat_message(
name="assistant", avatar="assistant"
)
markdown_placeholder = message_placeholder.empty()
def add_new_block():
nonlocal message_placeholder, markdown_placeholder
message_placeholder = placeholder.chat_message(
name="assistant", avatar="assistant"
)
markdown_placeholder = message_placeholder.empty()
def commit_conversation(
role: Role,
text: str,
metadata: str | None = None,
image: str | None = None,
new: bool = False,
):
processed_text = postprocess_text(text, role.value == Role.ASSISTANT.value)
conversation = Conversation(role, text, processed_text, metadata, image)
# Use different placeholder for new block
placeholder = message_placeholder if new else markdown_placeholder
append_conversation(
conversation,
history,
placeholder,
)
response = ""
for _ in range(10):
last_response = None
history_len = None
try:
for response, chat_history in client.generate_stream(
tools=tools,
history=history,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
max_new_tokens=max_new_tokens,
):
if history_len is None:
history_len = len(chat_history)
elif history_len != len(chat_history):
commit_conversation(Role.ASSISTANT, last_response)
add_new_block()
history_len = len(chat_history)
last_response = response
replace_quote = chat_history[-1]["role"] == "assistant"
markdown_placeholder.markdown(
postprocess_text(
str(response) + "●", replace_quote=replace_quote
)
)
else:
metadata = (
page == Mode.ALL_TOOLS
and isinstance(response, dict)
and response.get("name")
or None
)
role = Role.TOOL if metadata else Role.ASSISTANT
text = (
response.get("content")
if metadata
else response_to_str(response)
)
commit_conversation(role, text, metadata)
if metadata:
add_new_block()
try:
with markdown_placeholder:
with st.spinner(f"Calling tool {metadata}..."):
observations = dispatch_tool(
metadata, text, str(st.session_state.session_id)
)
except Exception as e:
traceback.print_exc()
st.error(f'Uncaught exception in `"{metadata}"`: {e}')
break
for observation in observations:
observation.text = observation.text
commit_conversation(
Role.OBSERVATION,
observation.text,
observation.role_metadata,
observation.image_url,
new=True,
)
add_new_block()
continue
else:
break
except Exception as e:
traceback.print_exc()
st.error(f"Uncaught exception: {traceback.format_exc()}")
else:
st.error("Too many chaining function calls!")
main(prompt_text)
"""
Simple browser tool.
# Usage
Please start the backend browser server according to the instructions in the README.
"""
from pprint import pprint
import re
import requests
import streamlit as st
from dataclasses import dataclass
from .config import BROWSER_SERVER_URL
from .interface import ToolObservation
QUOTE_REGEX = re.compile(r"\[(\d+)†(.+?)\]")
@dataclass
class Quote:
title: str
url: str
# Quotes for displaying reference
if "quotes" not in st.session_state:
st.session_state.quotes = {}
quotes: dict[str, Quote] = st.session_state.quotes
def map_response(response: dict) -> ToolObservation:
# Save quotes for reference
print('===BROWSER_RESPONSE===')
pprint(response)
role_metadata = response.get("roleMetadata")
metadata = response.get("metadata")
if role_metadata.split()[0] == 'quote_result' and metadata:
quote_id = QUOTE_REGEX.search(role_metadata.split()[1]).group(1)
quote: dict[str, str] = metadata['metadata_list'][0]
quotes[quote_id] = Quote(quote['title'], quote['url'])
elif role_metadata == 'browser_result' and metadata:
for i, quote in enumerate(metadata['metadata_list']):
quotes[str(i)] = Quote(quote['title'], quote['url'])
return ToolObservation(
content_type=response.get("contentType"),
text=response.get("result"),
role_metadata=role_metadata,
metadata=metadata,
)
def tool_call(code: str, session_id: str) -> list[ToolObservation]:
request = {
"session_id": session_id,
"action": code,
}
response = requests.post(BROWSER_SERVER_URL, json=request).json()
return list(map(map_response, response))
import streamlit as st
from zhipuai import ZhipuAI
from zhipuai.types.image import GeneratedImage
from .config import COGVIEW_MODEL, ZHIPU_AI_KEY
from .interface import ToolObservation
@st.cache_resource
def get_zhipu_client():
return ZhipuAI(api_key=ZHIPU_AI_KEY)
def map_response(img: GeneratedImage):
return ToolObservation(
content_type='image',
text='CogView 已经生成并向用户展示了生成的图片。',
image_url=img.url,
role_metadata='cogview_result'
)
def tool_call(prompt: str, session_id: str) -> list[ToolObservation]:
client = get_zhipu_client()
response = client.images.generations(model=COGVIEW_MODEL, prompt=prompt).data
return list(map(map_response, response))
BROWSER_SERVER_URL = 'http://localhost:3000'
IPYKERNEL = 'glm-4-demo'
ZHIPU_AI_KEY = ''
COGVIEW_MODEL = 'cogview-3'
from dataclasses import dataclass
from typing import Any
@dataclass
class ToolObservation:
content_type: str
text: str
image_url: str | None = None
role_metadata: str | None = None
metadata: Any = None
from pprint import pprint
import queue
import re
from subprocess import PIPE
from typing import Literal
import jupyter_client
import streamlit as st
from .config import IPYKERNEL
from .interface import ToolObservation
ANSI_ESCAPE = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
CODE = re.compile(r'```([^\n]*)\n(.*?)```')
class CodeKernel:
def __init__(self,
kernel_name='kernel',
kernel_id=None,
kernel_config_path="",
python_path=None,
ipython_path=None,
init_file_path="./startup.py",
verbose=1):
self.kernel_name = kernel_name
self.kernel_id = kernel_id
self.kernel_config_path = kernel_config_path
self.python_path = python_path
self.ipython_path = ipython_path
self.init_file_path = init_file_path
self.verbose = verbose
if python_path is None and ipython_path is None:
env = None
else:
env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}
# Initialize the backend kernel
self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
connection_file=self.kernel_config_path,
exec_files=[self.init_file_path],
env=env)
if self.kernel_config_path:
self.kernel_manager.load_connection_file()
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_config_path))
else:
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_manager.connection_file))
if verbose:
pprint(self.kernel_manager.get_connection_info())
# Initialize the code kernel
self.kernel = self.kernel_manager.blocking_client()
# self.kernel.load_connection_file()
self.kernel.start_channels()
print("Code kernel started.")
def execute(self, code):
self.kernel.execute(code)
try:
shell_msg = self.kernel.get_shell_msg(timeout=30)
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
while True:
msg_out = io_msg_content
### Poll the message
try:
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
break
except queue.Empty:
break
return shell_msg, msg_out
except Exception as e:
print(e)
return None
def execute_interactive(self, code, verbose=False):
shell_msg = self.kernel.execute_interactive(code)
if shell_msg is queue.Empty:
if verbose:
print("Timeout waiting for shell message.")
self.check_msg(shell_msg, verbose=verbose)
return shell_msg
def inspect(self, code, verbose=False):
msg_id = self.kernel.inspect(code)
shell_msg = self.kernel.get_shell_msg(timeout=30)
if shell_msg is queue.Empty:
if verbose:
print("Timeout waiting for shell message.")
self.check_msg(shell_msg, verbose=verbose)
return shell_msg
def get_error_msg(self, msg, verbose=False) -> str | None:
if msg['content']['status'] == 'error':
try:
error_msg = msg['content']['traceback']
except:
try:
error_msg = msg['content']['traceback'][-1].strip()
except:
error_msg = "Traceback Error"
if verbose:
print("Error: ", error_msg)
return error_msg
return None
def check_msg(self, msg, verbose=False):
status = msg['content']['status']
if status == 'ok':
if verbose:
print("Execution succeeded.")
elif status == 'error':
for line in msg['content']['traceback']:
if verbose:
print(line)
def shutdown(self):
# Shutdown the backend kernel
self.kernel_manager.shutdown_kernel()
print("Backend kernel shutdown.")
# Shutdown the code kernel
self.kernel.shutdown()
print("Code kernel shutdown.")
def restart(self):
# Restart the backend kernel
self.kernel_manager.restart_kernel()
# print("Backend kernel restarted.")
def interrupt(self):
# Interrupt the backend kernel
self.kernel_manager.interrupt_kernel()
# print("Backend kernel interrupted.")
def is_alive(self):
return self.kernel.is_alive()
def clean_ansi_codes(input_string):
return ANSI_ESCAPE.sub('', input_string)
def extract_code(text: str) -> str:
matches = CODE.findall(text, re.DOTALL)
return matches[-1][1]
def execute(
code: str,
kernel: CodeKernel
) -> tuple[Literal['text', 'image'] | None, str]:
res = ""
res_type = None
code = code.replace("<|observation|>", "")
code = code.replace("<|assistant|>python", "")
code = code.replace("<|assistant|>", "")
code = code.replace("<|user|>", "")
code = code.replace("<|system|>", "")
msg, output = kernel.execute(code)
if msg['metadata']['status'] == "timeout":
return res_type, 'Timed out'
elif msg['metadata']['status'] == 'error':
return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))
if 'text' in output:
res_type = "text"
res = output['text']
elif 'data' in output:
for key in output['data']:
if 'text/plain' in key:
res_type = "text"
res = output['data'][key]
elif 'image/png' in key:
res_type = "image"
res = output['data'][key]
break
return res_type, res
@st.cache_resource
def get_kernel() -> CodeKernel:
return CodeKernel()
def tool_call(code: str, session_id: str) -> list[ToolObservation]:
kernel = get_kernel()
res_type, res = execute(code, kernel)
# Convert base64 to data uri
text = '[Image]' if res_type == 'image' else res
image = f'data:image/png;base64,{res}' if res_type == 'image' else None
return [ToolObservation(res_type, text, image)]
"""
This code is the tool registration part. By registering the tool, the model can call the tool.
This code provides extended functionality to the model, enabling it to call and interact with a variety of utilities
through defined interfaces.
"""
from collections.abc import Callable
import copy
import inspect
import json
from pprint import pformat
import traceback
from types import GenericAlias
from typing import get_origin, Annotated
import subprocess
from .interface import ToolObservation
from .browser import tool_call as browser
from .cogview import tool_call as cogview
from .python import tool_call as python
ALL_TOOLS = {
"simple_browser": browser,
"python": python,
"cogview": cogview,
}
_TOOL_HOOKS = {}
_TOOL_DESCRIPTIONS = []
def register_tool(func: Callable):
tool_name = func.__name__
tool_description = inspect.getdoc(func).strip()
python_params = inspect.signature(func).parameters
tool_params = []
for name, param in python_params.items():
annotation = param.annotation
if annotation is inspect.Parameter.empty:
raise TypeError(f"Parameter `{name}` missing type annotation")
if get_origin(annotation) != Annotated:
raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
typ, (description, required) = annotation.__origin__, annotation.__metadata__
typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
if not isinstance(description, str):
raise TypeError(f"Description for `{name}` must be a string")
if not isinstance(required, bool):
raise TypeError(f"Required for `{name}` must be a bool")
tool_params.append(
{
"name": name,
"description": description,
"type": typ,
"required": required,
}
)
tool_def = {
"name": tool_name,
"description": tool_description,
"params": tool_params,
}
# print("[registered tool] " + pformat(tool_def))
_TOOL_HOOKS[tool_name] = func
_TOOL_DESCRIPTIONS.append(tool_def)
return func
def dispatch_tool(tool_name: str, code: str, session_id: str) -> list[ToolObservation]:
# Dispatch predefined tools
if tool_name in ALL_TOOLS:
return ALL_TOOLS[tool_name](code, session_id)
code = code.strip().rstrip('<|observation|>').strip()
# Dispatch custom tools
try:
tool_params = json.loads(code)
except json.JSONDecodeError as e:
err = f"Error decoding JSON: {e}"
return [ToolObservation("system_error", err)]
if tool_name not in _TOOL_HOOKS:
err = f"Tool `{tool_name}` not found. Please use a provided tool."
return [ToolObservation("system_error", err)]
tool_hook = _TOOL_HOOKS[tool_name]
try:
ret: str = tool_hook(**tool_params)
return [ToolObservation(tool_name, str(ret))]
except:
err = traceback.format_exc()
return [ToolObservation("system_error", err)]
def get_tools() -> list[dict]:
return copy.deepcopy(_TOOL_DESCRIPTIONS)
# Tool Definitions
@register_tool
def random_number_generator(
seed: Annotated[int, "The random seed used by the generator", True],
range: Annotated[tuple[int, int], "The range of the generated numbers", True],
) -> int:
"""
Generates a random number x, s.t. range[0] <= x < range[1]
"""
if not isinstance(seed, int):
raise TypeError("Seed must be an integer")
if not isinstance(range, tuple):
raise TypeError("Range must be a tuple")
if not isinstance(range[0], int) or not isinstance(range[1], int):
raise TypeError("Range must be a tuple of integers")
import random
return random.Random(seed).randint(*range)
@register_tool
def get_weather(
city_name: Annotated[str, "The name of the city to be queried", True],
) -> str:
"""
Get the current weather for `city_name`
"""
if not isinstance(city_name, str):
raise TypeError("City name must be a string")
key_selection = {
"current_condition": [
"temp_C",
"FeelsLikeC",
"humidity",
"weatherDesc",
"observation_time",
],
}
import requests
try:
resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
resp.raise_for_status()
resp = resp.json()
ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
except:
import traceback
ret = (
"Error encountered while fetching weather data!\n" + traceback.format_exc()
)
return str(ret)
@register_tool
def get_shell(
query: Annotated[str, "The command should run in Linux shell", True],
) -> str:
"""
Use shell to run command
"""
if not isinstance(query, str):
raise TypeError("Command must be a string")
try:
result = subprocess.run(
query,
shell=True,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
return result.stdout
except subprocess.CalledProcessError as e:
return e.stderr
if __name__ == "__main__":
# print(dispatch_tool("get_shell", {"query": "pwd"}))
print(get_tools())
from langchain_community.document_loaders import PyMuPDFLoader
import docx
from pptx import Presentation
def extract_text(path):
return open(path, 'r').read()
def extract_pdf(path):
loader = PyMuPDFLoader(path)
data = loader.load()
data = [x.page_content for x in data]
content = '\n\n'.join(data)
return content
def extract_docx(path):
doc = docx.Document(path)
data = []
for paragraph in doc.paragraphs:
data.append(paragraph.text)
content = '\n\n'.join(data)
def extract_pptx(path):
prs = Presentation(path)
text = ""
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
text += shape.text + "\n"
return text
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk24.04-py310
\ No newline at end of file
# GLM-4-9B Chat 对话模型微调
Read this in [English](README_en.md)
本 demo 中,你将体验到如何微调 GLM-4-9B-Chat 对话开源模型(不支持视觉理解模型)。 请严格按照文档的步骤进行操作,以避免不必要的错误。
## 硬件检查
**本文档的数据均在以下硬件环境测试,实际运行环境需求和运行占用的显存略有不同,请以实际运行环境为准。**
测试硬件信息:
+ OS: Ubuntu 22.04
+ Memory: 512GB
+ Python: 3.12.3
+ CUDA Version: 12.3
+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8
| 微调方案 | 显存占用 | 权重保存点大小 |
|--------------------|-----------------------------------|---------|
| lora (PEFT) | 21531MiB | 17M |
| p-tuning v2 (PEFT) | 21381MiB | 121M |
| SFT (Zero3 method) | 80935MiB<br/>(Each GPU,需要使用8张GPU) | 20G |
在开始微调之前,请你先安装`basic_demo`中的依赖,同时您需要安装本目录下的依赖项:
```bash
pip install -r requirements.txt
```
## 多轮对话格式
多轮对话微调示例采用 GLM-4 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`
对于数据文件,样例采用如下格式
如果您仅希望微调模型的对话能力,而非工具能力,您应该按照以下格式整理数据。
```json
[
{
"messages": [
{
"role": "system",
"content": "<system prompt text>",
"tools": [
{
"name": "<tool name>",
"args": {
"<arg name>": "<arg value>"
}
}
// Add more tools if needed
]
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
// If Tool Using
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
{
"role": "observation",
"content": "<observation prompt text>"
},
{
"role": "assistant",
"content": "<assistant response observation>"
},
// Multi_turns
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
]
```
这里是一个不带有工具的例子:
```
{"messages": [{"role": "user", "content": "类型#裤*材质#牛仔布*风格#性感"}, {"role": "assistant", "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。"}]}
```
这是一个带有工具调用的例子:
```
{"messages": [{"role": "system", "content": "", "tools": [{"type": "function", "function": {"name": "get_recommended_books", "description": "Get recommended books based on user's interests", "parameters": {"type": "object", "properties": {"interests": {"type": "array", "items": {"type": "string"}, "description": "The interests to recommend books for"}}, "required": ["interests"]}}}]}, {"role": "user", "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction."}, {"role": "assistant", "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}"}, {"role": "observation", "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}"}, {"role": "assistant", "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir."}]}
```
- `system` 角色为可选角色,但若存在 `system` 角色,其必须出现在 `user`
角色之前,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 `system` 角色。
- `tools` 字段为可选字段,若存在 `tools` 字段,其必须出现在 `system`
角色之后,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 `tools` 字段。当 `tools` 字段存在时,`system`
角色必须存在并且 `content` 字段为空。
## 配置文件
微调配置文件位于 `config` 目录下,包括以下文件:
1. `ds_zereo_2 / ds_zereo_3.json`: deepspeed 配置文件。
2. `lora.yaml / ptuning_v2.yaml / sft.yaml`: 模型不同方式的配置文件,包括模型参数、优化器参数、训练参数等。 部分重要参数解释如下:
+ data_config 部分
+ train_file: 训练数据集的文件路径。
+ val_file: 验证数据集的文件路径。
+ test_file: 测试数据集的文件路径。
+ num_proc: 在加载数据时使用的进程数量。
+ max_input_length: 输入序列的最大长度。
+ max_output_length: 输出序列的最大长度。
+ training_args 部分
+ output_dir: 用于保存模型和其他输出的目录。
+ max_steps: 训练的最大步数。
+ per_device_train_batch_size: 每个设备(如 GPU)的训练批次大小。
+ dataloader_num_workers: 加载数据时使用的工作线程数量。
+ remove_unused_columns: 是否移除数据中未使用的列。
+ save_strategy: 模型保存策略(例如,每隔多少步保存一次)。
+ save_steps: 每隔多少步保存一次模型。
+ log_level: 日志级别(如 info)。
+ logging_strategy: 日志记录策略。
+ logging_steps: 每隔多少步记录一次日志。
+ per_device_eval_batch_size: 每个设备的评估批次大小。
+ evaluation_strategy: 评估策略(例如,每隔多少步进行一次评估)。
+ eval_steps: 每隔多少步进行一次评估。
+ predict_with_generate: 是否使用生成模式进行预测。
+ generation_config 部分
+ max_new_tokens: 生成的最大新 token 数量。
+ peft_config 部分
+ peft_type: 使用的参数有效调整类型 (支持 LORA 和 PREFIX_TUNING)。
+ task_type: 任务类型,这里是因果语言模型 (不要改动)。
+ Lora 参数:
+ r: LoRA 的秩。
+ lora_alpha: LoRA 的缩放因子。
+ lora_dropout: 在 LoRA 层使用的 dropout 概率。
+ P-TuningV2 参数:
+ num_virtual_tokens: 虚拟 token 的数量。
+ num_attention_heads: 2: P-TuningV2 的注意力头数(不要改动)。
+ token_dim: 256: P-TuningV2 的 token 维度(不要改动)。
## 开始微调
通过以下代码执行 **单机多卡/多机多卡** 运行,这是使用 `deepspeed` 作为加速方案的,您需要安装 `deepspeed`
```shell
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b configs/lora.yaml
```
通过以下代码执行 **单机单卡** 运行。
```shell
python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
```
## 从保存点进行微调
如果按照上述方式进行训练,每次微调都会从头开始,如果你想从训练一半的模型开始微调,你可以加入第四个参数,这个参数有两种传入方式:
1. `yes`, 自动从最后一个保存的 Checkpoint开始训练
2. `XX`, 断点号数字 例 `600` 则从序号600 Checkpoint开始训练
例如,这就是一个从最后一个保存点继续微调的示例代码
```shell
python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
```
## 使用微调后的模型
### 在 inference.py 中验证微调后的模型
您可以在 `finetune_demo/inference.py` 中使用我们的微调后的模型,仅需要一行代码就能简单的进行测试。
```shell
python inference.py your_finetune_path
```
这样,得到的回答就微调后的回答了。
### 在本仓库的其他 demo 或者外部仓库使用微调后的模型
您可以在任何一个 demo 内使用我们的 `LORA` 和 全参微调的模型。这需要你自己按照以下教程进行修改代码。
1. 使用`finetune_demo/inference.py`中读入模型的方式替换 demo 中读入模型的方式。
> 请注意,对于 LORA 和 P-TuningV2 我们没有合并训练后的模型,而是在`adapter_config.json`
> 中记录了微调型的路径,如果你的原始模型位置发生更改,则你应该修改`adapter_config.json`中`base_model_name_or_path`的路径。
```python
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code
)
return model, tokenizer
```
2. 读取微调的模型,请注意,你应该使用微调模型的位置,例如,若你的模型位置为`/path/to/finetune_adapter_model`
,原始模型地址为`path/to/base_model`,则你应该使用`/path/to/finetune_adapter_model`作为`model_dir`
3. 完成上述操作后,就能正常使用微调的模型了,其他的调用方式没有变化。
## 参考文献
```
@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short
Papers)},
pages={61--68},
year={2022}
}
@misc{tang2023toolalpaca,
title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},
author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},
year={2023},
eprint={2306.05301},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
\ No newline at end of file
# GLM-4-9B Chat dialogue model fine-tuning
In this demo, you will experience how to fine-tune the GLM-4-9B-Chat open source model (visual understanding model is
not supported). Please strictly follow the steps in the document to avoid unnecessary errors.
## Hardware check
**The data in this document are tested in the following hardware environment. The actual operating environment
requirements and the video memory occupied by the operation are slightly different. Please refer to the actual operating
environment. **
Test hardware information:
+ OS: Ubuntu 22.04
+ Memory: 512GB
+ Python: 3.12.3
+ CUDA Version: 12.3
+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8
| Fine-tuning solution | Video memory usage | Weight save point size |
|----------------------|----------------------------------------------|------------------------|
| lora (PEFT) | 21531MiB | 17M |
| p-tuning v2 (PEFT) | 21381MiB | 121M |
| SFT (Zero3 method) | 80935MiB<br/>(Each GPU, 8 GPUs are required) | 20G |
Before starting fine-tuning, please install the dependencies in `basic_demo` first. You also need to install the
dependencies in this directory:
```bash
pip install -r requirements.txt
```
## Multi-round dialogue format
The multi-round dialogue fine-tuning example uses the GLM-4 dialogue format convention, adding different `loss_mask` to
different roles to calculate `loss` for multiple rounds of replies in one calculation.
For data files, the sample uses the following format:
```json
[
{
"messages": [
{
"role": "system",
"content": "<system prompt text>",
"tools": [
{
"name": "<tool name>",
"args": {
"<arg name>": "<arg value>"
}
}
// Add more tools if needed
]
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
// If Tool Using
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
{
"role": "observation",
"content": "<observation prompt text>"
},
{
"role": "assistant",
"content": "<assistant response observation>"
},
// Multi_turns
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
]
```
This is a sample without tools:
```
{"messages": [{"role": "user", "content": "类型#裤*材质#牛仔布*风格#性感"}, {"role": "assistant", "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。"}]}
```
This is a sample with tools:
```
{"messages": [{"role": "system", "content": "", "tools": [{"type": "function", "function": {"name": "get_recommended_books", "description": "Get recommended books based on user's interests", "parameters": {"type": "object", "properties": {"interests": {"type": "array", "items": {"type": "string"}, "description": "The interests to recommend books for"}}, "required": ["interests"]}}}]}, {"role": "user", "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction."}, {"role": "assistant", "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}"}, {"role": "observation", "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}"}, {"role": "assistant", "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir."}]}
```
- The `system` role is optional, but if it exists, it must appear before the `user` role, and a complete conversation
data (whether single-round or multi-round conversation) can only have one `system` role.
- The `tools` field is optional. If it exists, it must appear after the `system` role, and a complete conversation
data (whether single-round or multi-round conversation) can only have one `tools` field. When the `tools` field
exists, the `system` role must exist and the `content` field is empty.
## Configuration file
The fine-tuning configuration file is located in the `config` directory, including the following files:
1. `ds_zereo_2 / ds_zereo_3.json`: deepspeed configuration file.
2. `lora.yaml / ptuning_v2
3. .yaml / sft.yaml`: Configuration files for different modes of models, including model parameters, optimizer
parameters, training parameters, etc. Some important parameters are explained as follows:
+ data_config section
+ train_file: File path of training dataset.
+ val_file: File path of validation dataset.
+ test_file: File path of test dataset.
+ num_proc: Number of processes to use when loading data.
+ max_input_length: Maximum length of input sequence.
+ max_output_length: Maximum length of output sequence.
+ training_args section
+ output_dir: Directory for saving model and other outputs.
+ max_steps: Maximum number of training steps.
+ per_device_train_batch_size: Training batch size per device (such as GPU).
+ dataloader_num_workers: Number of worker threads to use when loading data.
+ remove_unused_columns: Whether to remove unused columns in data.
+ save_strategy: Model saving strategy (for example, how many steps to save).
+ save_steps: How many steps to save the model.
+ log_level: Log level (such as info).
+ logging_strategy: logging strategy.
+ logging_steps: how many steps to log at.
+ per_device_eval_batch_size: per-device evaluation batch size.
+ evaluation_strategy: evaluation strategy (e.g. how many steps to evaluate at).
+ eval_steps: how many steps to evaluate at.
+ predict_with_generate: whether to use generation mode for prediction.
+ generation_config section
+ max_new_tokens: maximum number of new tokens to generate.
+ peft_config section
+ peft_type: type of parameter tuning to use (supports LORA and PREFIX_TUNING).
+ task_type: task type, here is causal language model (don't change).
+ Lora parameters:
+ r: rank of LoRA.
+ lora_alpha: scaling factor of LoRA.
+ lora_dropout: dropout probability to use in LoRA layer.
+ P-TuningV2 parameters:
+ num_virtual_tokens: the number of virtual tokens.
+ num_attention_heads: 2: the number of attention heads of P-TuningV2 (do not change).
+ token_dim: 256: the token dimension of P-TuningV2 (do not change).
## Start fine-tuning
Execute **single machine multi-card/multi-machine multi-card** run through the following code, which uses `deepspeed` as
the acceleration solution, and you need to install `deepspeed`.
```shell
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b configs/lora.yaml
```
Execute **single machine single card** run through the following code.
```shell
python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
```
## Fine-tune from a saved point
If you train as described above, each fine-tuning will start from the beginning. If you want to fine-tune from a
half-trained model, you can add a fourth parameter, which can be passed in two ways:
1. `yes`, automatically start training from the last saved Checkpoint
2. `XX`, breakpoint number, for example `600`, start training from Checkpoint 600
For example, this is an example code to continue fine-tuning from the last saved point
```shell
python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
```
## Use the fine-tuned model
### Verify the fine-tuned model in inference.py
You can Use our fine-tuned model in `finetune_demo/inference.py`, and you can easily test it with just one line of code.
```shell
python inference.py your_finetune_path
```
In this way, the answer you get is the fine-tuned answer.
### Use the fine-tuned model in other demos in this repository or external repositories
You can use our `LORA` and fully fine-tuned models in any demo. This requires you to modify the code yourself according
to the following tutorial.
1. Replace the way to read the model in the demo with the way to read the model in `finetune_demo/inference.py`.
> Please note that for LORA and P-TuningV2, we did not merge the trained models, but recorded the fine-tuned path
> in `adapter_config.json`
> If the location of your original model changes, you should modify the path of `base_model_name_or_path`
> in `adapter_config.json`.
```python
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code
)
return model, tokenizer
```
2. Read the fine-tuned model. Please note that you should use the location of the fine-tuned model. For example, if your
model location is `/path/to/finetune_adapter_model`
and the original model address is `path/to/base_model`, you should use `/path/to/finetune_adapter_model`
as `model_dir`.
3. After completing the above operations, you can use the fine-tuned model normally. Other calling methods remain
unchanged.
## Reference
```
@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short
Papers)},
pages={61--68},
year={2022}
}
@misc{tang2023toolalpaca,
title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},
author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},
year={2023},
eprint={2306.05301},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
\ No newline at end of file
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
\ No newline at end of file
{
"train_micro_batch_size_per_gpu": "auto",
"zero_allow_untested_optimizer": true,
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"zero_optimization": {
"stage": 3,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"contiguous_gradients": true,
"overlap_comm": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
}
}
\ No newline at end of file
data_config:
train_file: train.jsonl
val_file: dev.jsonl
test_file: dev.jsonl
num_proc: 1
max_input_length: 512
max_output_length: 512
training_args:
# see `transformers.Seq2SeqTrainingArguments`
output_dir: ./output
max_steps: 3000
# needed to be fit for the dataset
learning_rate: 5e-4
# settings for data loading
per_device_train_batch_size: 1
dataloader_num_workers: 16
remove_unused_columns: false
# settings for saving checkpoints
save_strategy: steps
save_steps: 500
# settings for logging
log_level: info
logging_strategy: steps
logging_steps: 10
# settings for evaluation
per_device_eval_batch_size: 4
evaluation_strategy: steps
eval_steps: 500
# settings for optimizer
# adam_epsilon: 1e-6
# uncomment the following line to detect nan or inf values
# debug: underflow_overflow
predict_with_generate: true
# see `transformers.GenerationConfig`
generation_config:
max_new_tokens: 512
# set your absolute deepspeed path here
#deepspeed: ds_zero_2.json
peft_config:
peft_type: LORA
task_type: CAUSAL_LM
r: 8
lora_alpha: 32
lora_dropout: 0.1
data_config:
train_file: train.jsonl
val_file: dev.jsonl
test_file: dev.jsonl
num_proc: 1
max_input_length: 128
max_output_length: 128
training_args:
# see `transformers.Seq2SeqTrainingArguments`
output_dir: ./output
max_steps: 3000
# needed to be fit for the dataset
learning_rate: 5e-4
# settings for data loading
per_device_train_batch_size: 4
dataloader_num_workers: 16
remove_unused_columns: false
# settings for saving checkpoints
save_strategy: steps
save_steps: 500
# settings for logging
log_level: info
logging_strategy: steps
logging_steps: 500
# settings for evaluation
per_device_eval_batch_size: 16
evaluation_strategy: steps
eval_steps: 500
# settings for optimizer
# adam_epsilon: 1e-6
# uncomment the following line to detect nan or inf values
# debug: underflow_overflow
predict_with_generate: true
# see `transformers.GenerationConfig`
generation_config:
max_new_tokens: 512
# set your absolute deepspeed path here
#deepspeed: ds_zero_3.json
peft_config:
peft_type: PREFIX_TUNING
task_type: CAUSAL_LM
num_virtual_tokens: 512
num_attention_heads: 2
token_dim: 256
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment