Commit d0572507 authored by lvzhen

Deleted basic_demo/cli_demo.py, basic_demo/cli_demo_bad_word_ids.py,...

Deleted basic_demo/cli_demo.py, basic_demo/cli_demo_bad_word_ids.py, basic_demo/infer_test.py, basic_demo/utils.py, basic_demo/vocab.txt, basic_demo/web_demo.py, basic_demo/web_demo2.py, composite_demo/.streamlit/config.toml, composite_demo/assets/demo.png, composite_demo/assets/emojis.png, composite_demo/assets/heart.png, composite_demo/assets/tool.png, composite_demo/README.md, composite_demo/README_en.md, composite_demo/client.py, composite_demo/conversation.py, composite_demo/demo_chat.py, composite_demo/demo_ci.py, composite_demo/demo_tool.py, composite_demo/main.py, composite_demo/requirements.txt, composite_demo/tool_registry.py, cookbook/data/toutiao_cat_data_example.txt, cookbook/accurate_prompt.ipynb, cookbook/finetune_muti_classfication.ipynb, finetune_basemodel_demo/scripts/finetune_lora.sh, finetune_basemodel_demo/scripts/formate_alpaca2jsonl.py, finetune_basemodel_demo/README.md, finetune_basemodel_demo/arguments.py, finetune_basemodel_demo/finetune.py, finetune_basemodel_demo/inference.py, finetune_basemodel_demo/preprocess_utils.py, finetune_basemodel_demo/requirements.txt, finetune_basemodel_demo/trainer.py, finetune_chatmodel_demo/AdvertiseGen/dev.json, finetune_chatmodel_demo/AdvertiseGen/train.json, finetune_chatmodel_demo/configs/deepspeed.json, finetune_chatmodel_demo/formatted_data/advertise_gen.jsonl, finetune_chatmodel_demo/formatted_data/tool_alpaca.jsonl, finetune_chatmodel_demo/scripts/finetune_ds.sh, finetune_chatmodel_demo/scripts/finetune_ds_multiturn.sh, finetune_chatmodel_demo/scripts/finetune_pt.sh, finetune_chatmodel_demo/scripts/finetune_pt_multiturn.sh, finetune_chatmodel_demo/scripts/format_advertise_gen.py, finetune_chatmodel_demo/scripts/format_tool_alpaca.py, finetune_chatmodel_demo/README.md, finetune_chatmodel_demo/arguments.py, finetune_chatmodel_demo/finetune.py, finetune_chatmodel_demo/inference.py, finetune_chatmodel_demo/preprocess_utils.py, finetune_chatmodel_demo/requirements.txt, finetune_chatmodel_demo/train_data.json, finetune_chatmodel_demo/trainer.py, langchain_demo/Tool/Calculator.py, langchain_demo/Tool/Calculator.yaml, langchain_demo/Tool/Weather.py, langchain_demo/Tool/arxiv_example.yaml, langchain_demo/Tool/weather.yaml, langchain_demo/ChatGLM3.py, langchain_demo/README.md, langchain_demo/main.py, langchain_demo/requirements.txt, langchain_demo/utils.py, media/GLM.png, media/cli.png, media/transformers.jpg, openai_api_demo/openai_api.py, openai_api_demo/openai_api_request.py, openai_api_demo/requirements.txt, openai_api_demo/utils.py, resources/WECHAT.md, resources/cli-demo.png, resources/code_en.gif, resources/heart.png, resources/tool.png, resources/tool_en.png, resources/web-demo.gif, resources/web-demo2.gif, resources/web-demo2.png, resources/wechat.jpg, tool_using/README.md, tool_using/README_en.md, tool_using/cli_demo_tool.py, tool_using/openai_api_demo.py, tool_using/requirements.txt, tool_using/test.py, tool_using/tool_register.py, DEPLOYMENT.md, DEPLOYMENT_en.md, Dockerfile, MODEL_LICENSE, PROMPT.md, PROMPT_en.md, README.md, README_en.md, README_old.md, lvzhen.log, model.properties, requirements.txt files
parent d7be7b1c
from __future__ import annotations
from collections.abc import Iterable
import os
from typing import Any, Protocol
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
from conversation import Conversation
TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
PT_PATH = os.environ.get('PT_PATH', None)
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# For Apple Silicon Macs (e.g. M1):
# use a PyTorch build with Metal (MPS) support
# DEVICE = 'mps'
# For AMD GPUs such as the MI100 (not officially supported yet):
# use a PyTorch build with ROCm support
# DEVICE = 'cuda'
# For Intel GPUs such as the A770 (not officially supported yet):
# use a PyTorch build with oneDNN and install intel-extension-for-pytorch
# import intel_extension_for_pytorch as ipex
# DEVICE = 'xpu'
# For Moore Threads GPUs such as the MTT S80 (not officially supported yet):
# use a PyTorch build with MUSA support
# DEVICE = 'musa'
@st.cache_resource
def get_client() -> Client:
client = HFClient(MODEL_PATH, TOKENIZER_PATH, PT_PATH, DEVICE)
return client
class Client(Protocol):
def generate_stream(self,
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
**parameters: Any
) -> Iterable[TextGenerationStreamResponse]:
...
def stream_chat(self, tokenizer, query: str, history: list[tuple[str, str]] = None, role: str = "user",
past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
repetition_penalty=1.0, length_penalty=1.0, num_beams=1,
logits_processor=None, return_past_key_values=False, **kwargs):
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 5] = 5e4
return scores
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|observation|>")]
gen_kwargs = {"max_length": max_length,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
"repetition_penalty": repetition_penalty,
"length_penalty": length_penalty,
"num_beams": num_beams,
**kwargs
}
print(gen_kwargs)
if past_key_values is None:
inputs = tokenizer.build_chat_input(query, history=history, role=role)
else:
inputs = tokenizer.build_chat_input(query, role=role)
inputs = inputs.to(self.device)
if past_key_values is not None:
past_length = past_key_values[0][0].shape[0]
if self.transformer.pre_seq_len is not None:
past_length -= self.transformer.pre_seq_len
inputs.position_ids += past_length
attention_mask = inputs.attention_mask
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
inputs['attention_mask'] = attention_mask
history.append({"role": role, "content": query})
print("input_shape>", inputs['input_ids'].shape)
input_sequence_length = inputs['input_ids'].shape[1]
if max_length < input_sequence_length <= self.config.seq_length:
yield "Current input sequence length {} exceeds sequence length set in generation parameters {}. The maximum model sequence length is {}. You may adjust the generation parameter to enable longer chat history.".format(
input_sequence_length, max_length, self.config.seq_length
), history
return
if input_sequence_length > self.config.seq_length:
yield "Current input sequence length {} exceeds maximum model sequence length {}. Unable to generate tokens.".format(
input_sequence_length, self.config.seq_length
), history
return
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
**gen_kwargs):
if return_past_key_values:
outputs, past_key_values = outputs
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs)
if response and response[-1] != "�":
new_history = history
if return_past_key_values:
yield response, new_history, past_key_values
else:
yield response, new_history
class HFClient(Client):
def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str | None = None, DEVICE = 'cpu'):
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
if pt_checkpoint is not None:
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, config=config)
prefix_state_dict = torch.load(os.path.join(pt_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
self.model = self.model.to(DEVICE).eval() if 'cuda' in DEVICE else self.model.float().to(DEVICE).eval()
def generate_stream(self,
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
**parameters: Any
) -> Iterable[TextGenerationStreamResponse]:
chat_history = [{
'role': 'system',
'content': system if not tools else TOOL_PROMPT,
}]
if tools:
chat_history[0]['tools'] = tools
for conversation in history[:-1]:
chat_history.append({
'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),
'content': conversation.content,
})
query = history[-1].content
role = str(history[-1].role).removeprefix('<|').removesuffix('|>')
text = ''
for new_text, _ in stream_chat(self.model,
self.tokenizer,
query,
chat_history,
role,
**parameters,
):
word = new_text.removeprefix(text)
word_stripped = word.strip()
text = new_text
yield TextGenerationStreamResponse(
generated_text=text,
token=Token(
id=0,
logprob=0,
text=word,
special=word_stripped.startswith('<|') and word_stripped.endswith('|>'),
)
)
from dataclasses import dataclass
from enum import auto, Enum
import json
from PIL.Image import Image
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:\n'
class Role(Enum):
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()
INTERPRETER = auto()
OBSERVATION = auto()
def __str__(self):
match self:
case Role.SYSTEM:
return "<|system|>"
case Role.USER:
return "<|user|>"
case Role.ASSISTANT | Role.TOOL | Role.INTERPRETER:
return "<|assistant|>"
case Role.OBSERVATION:
return "<|observation|>"
# Get the message block for the given role
def get_message(self):
# Compare by value here, because the enum object in the session state
# is not the same as the enum cases here, due to streamlit's rerunning
# behavior.
match self.value:
case Role.SYSTEM.value:
return
case Role.USER.value:
return st.chat_message(name="user", avatar="user")
case Role.ASSISTANT.value:
return st.chat_message(name="assistant", avatar="assistant")
case Role.TOOL.value:
return st.chat_message(name="tool", avatar="assistant")
case Role.INTERPRETER.value:
return st.chat_message(name="interpreter", avatar="assistant")
case Role.OBSERVATION.value:
return st.chat_message(name="observation", avatar="user")
case _:
st.error(f'Unexpected role: {self}')
@dataclass
class Conversation:
role: Role
content: str
tool: str | None = None
image: Image | None = None
def __str__(self) -> str:
print(self.role, self.content, self.tool)
match self.role:
case Role.SYSTEM | Role.USER | Role.ASSISTANT | Role.OBSERVATION:
return f'{self.role}\n{self.content}'
case Role.TOOL:
return f'{self.role}{self.tool}\n{self.content}'
case Role.INTERPRETER:
return f'{self.role}interpreter\n{self.content}'
# Human readable format
def get_text(self) -> str:
text = postprocess_text(self.content)
match self.role.value:
case Role.TOOL.value:
text = f'Calling tool `{self.tool}`:\n{text}'
case Role.INTERPRETER.value:
text = f'{text}'
case Role.OBSERVATION.value:
text = f'Observation:\n```\n{text}\n```'
return text
# Display as a markdown block
def show(self, placeholder: DeltaGenerator | None=None) -> str:
if placeholder:
message = placeholder
else:
message = self.role.get_message()
if self.image:
message.image(self.image)
else:
text = self.get_text()
message.markdown(text)
def preprocess_text(
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
) -> str:
if tools:
tools = json.dumps(tools, indent=4, ensure_ascii=False)
prompt = f"{Role.SYSTEM}\n"
prompt += system if not tools else TOOL_PROMPT
if tools:
tools = json.loads(tools)
prompt += json.dumps(tools, ensure_ascii=False)
for conversation in history:
prompt += f'{conversation}'
prompt += f'{Role.ASSISTANT}\n'
return prompt
def postprocess_text(text: str) -> str:
# Convert LaTeX-style delimiters to Markdown math and strip special role tokens.
text = text.replace(r"\(", "$")
text = text.replace(r"\)", "$")
text = text.replace(r"\[", "$$")
text = text.replace(r"\]", "$$")
text = text.replace("<|assistant|>", "")
text = text.replace("<|observation|>", "")
text = text.replace("<|system|>", "")
text = text.replace("<|user|>", "")
return text.strip()
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
MAX_LENGTH = 8192
client = get_client()
# Append a conversation to the history and show it in a new markdown block
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None = None,
) -> None:
history.append(conversation)
conversation.show(placeholder)
def main(top_p: float, temperature: float, system_prompt: str, prompt_text: str, repetition_penalty: float):
placeholder = st.empty()
with placeholder.container():
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
history: list[Conversation] = st.session_state.chat_history
for conversation in history:
conversation.show()
if prompt_text:
prompt_text = prompt_text.strip()
append_conversation(Conversation(Role.USER, prompt_text), history)
input_text = preprocess_text(
system_prompt,
tools=None,
history=history,
)
print("=== Input:")
print(input_text)
print("=== History:")
print(history)
placeholder = st.empty()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
output_text = ''
for response in client.generate_stream(
system_prompt,
tools=None,
history=history,
do_sample=True,
max_length=MAX_LENGTH,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(Role.USER)],
repetition_penalty=repetition_penalty,
):
token = response.token
if response.token.special:
print("=== Output:")
print(output_text)
match token.text.strip():
case '<|user|>':
break
case _:
st.error(f'Unexpected special token: {token.text.strip()}')
break
output_text += response.token.text
markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
import base64
from io import BytesIO
import os
from pprint import pprint
import queue
import re
from subprocess import PIPE
import jupyter_client
from PIL import Image
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
IPYKERNEL = os.environ.get('IPYKERNEL', 'chatglm3-demo')
SYSTEM_PROMPT = '你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。'
MAX_LENGTH = 8192
TRUNCATE_LENGTH = 1024
client = get_client()
class CodeKernel(object):
def __init__(self,
kernel_name='kernel',
kernel_id=None,
kernel_config_path="",
python_path=None,
ipython_path=None,
init_file_path="./startup.py",
verbose=1):
self.kernel_name = kernel_name
self.kernel_id = kernel_id
self.kernel_config_path = kernel_config_path
self.python_path = python_path
self.ipython_path = ipython_path
self.init_file_path = init_file_path
self.verbose = verbose
if python_path is None and ipython_path is None:
env = None
else:
env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}
# Initialize the backend kernel
self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
connection_file=self.kernel_config_path,
exec_files=[self.init_file_path],
env=env)
if self.kernel_config_path:
self.kernel_manager.load_connection_file()
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_config_path))
else:
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_manager.connection_file))
if verbose:
pprint(self.kernel_manager.get_connection_info())
# Initialize the code kernel
self.kernel = self.kernel_manager.blocking_client()
# self.kernel.load_connection_file()
self.kernel.start_channels()
print("Code kernel started.")
def execute(self, code):
self.kernel.execute(code)
try:
shell_msg = self.kernel.get_shell_msg(timeout=30)
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
while True:
msg_out = io_msg_content
### Poll the message
try:
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
break
except queue.Empty:
break
return shell_msg, msg_out
except Exception as e:
print(e)
return None
def execute_interactive(self, code, verbose=False):
shell_msg = self.kernel.execute_interactive(code)
if shell_msg is queue.Empty:
if verbose:
print("Timeout waiting for shell message.")
self.check_msg(shell_msg, verbose=verbose)
return shell_msg
def inspect(self, code, verbose=False):
msg_id = self.kernel.inspect(code)
shell_msg = self.kernel.get_shell_msg(timeout=30)
if shell_msg is queue.Empty:
if verbose:
print("Timeout waiting for shell message.")
self.check_msg(shell_msg, verbose=verbose)
return shell_msg
def get_error_msg(self, msg, verbose=False) -> str | None:
if msg['content']['status'] == 'error':
try:
error_msg = msg['content']['traceback']
except:
try:
error_msg = msg['content']['traceback'][-1].strip()
except:
error_msg = "Traceback Error"
if verbose:
print("Error: ", error_msg)
return error_msg
return None
def check_msg(self, msg, verbose=False):
status = msg['content']['status']
if status == 'ok':
if verbose:
print("Execution succeeded.")
elif status == 'error':
for line in msg['content']['traceback']:
if verbose:
print(line)
def shutdown(self):
# Shutdown the backend kernel
self.kernel_manager.shutdown_kernel()
print("Backend kernel shutdown.")
# Shutdown the code kernel
self.kernel.shutdown()
print("Code kernel shutdown.")
def restart(self):
# Restart the backend kernel
self.kernel_manager.restart_kernel()
# print("Backend kernel restarted.")
def interrupt(self):
# Interrupt the backend kernel
self.kernel_manager.interrupt_kernel()
# print("Backend kernel interrupted.")
def is_alive(self):
return self.kernel.is_alive()
def b64_2_img(data):
buff = BytesIO(base64.b64decode(data))
return Image.open(buff)
def clean_ansi_codes(input_string):
ansi_escape = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
return ansi_escape.sub('', input_string)
def execute(code, kernel: CodeKernel) -> tuple[str, str | Image.Image]:
res = ""
res_type = None
code = code.replace("<|observation|>", "")
code = code.replace("<|assistant|>interpreter", "")
code = code.replace("<|assistant|>", "")
code = code.replace("<|user|>", "")
code = code.replace("<|system|>", "")
msg, output = kernel.execute(code)
if msg['metadata']['status'] == "timeout":
return res_type, 'Timed out'
elif msg['metadata']['status'] == 'error':
return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))
if 'text' in output:
res_type = "text"
res = output['text']
elif 'data' in output:
for key in output['data']:
if 'text/plain' in key:
res_type = "text"
res = output['data'][key]
elif 'image/png' in key:
res_type = "image"
res = output['data'][key]
break
if res_type == "image":
return res_type, b64_2_img(res)
elif res_type == "text" or res_type == "traceback":
res = res
return res_type, res
@st.cache_resource
def get_kernel():
kernel = CodeKernel()
return kernel
def extract_code(text: str) -> str:
pattern = r'```([^\n]*)\n(.*?)```'
matches = re.findall(pattern, text, re.DOTALL)
return matches[-1][1]
# Append a conversation to the history and show it in a new markdown block
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None=None,
) -> None:
history.append(conversation)
conversation.show(placeholder)
def main(top_p: float, temperature: float, prompt_text: str, repetition_penalty: float):
if 'ci_history' not in st.session_state:
st.session_state.ci_history = []
history: list[Conversation] = st.session_state.ci_history
for conversation in history:
conversation.show()
if prompt_text:
prompt_text = prompt_text.strip()
role = Role.USER
append_conversation(Conversation(role, prompt_text), history)
input_text = preprocess_text(
SYSTEM_PROMPT,
None,
history,
)
print("=== Input:")
print(input_text)
print("=== History:")
print(history)
placeholder = st.container()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
for _ in range(5):
output_text = ''
for response in client.generate_stream(
system=SYSTEM_PROMPT,
tools=None,
history=history,
do_sample=True,
max_length=MAX_LENGTH,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
repetition_penalty=repetition_penalty,
):
token = response.token
if response.token.special:
print("=== Output:")
print(output_text)
match token.text.strip():
case '<|user|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
# Initiate tool call
case '<|assistant|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="interpreter", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
output_text = ''
continue
case '<|observation|>':
code = extract_code(output_text)
print("Code:", code)
display_text = output_text.split('interpreter')[-1].strip()
append_conversation(Conversation(
Role.INTERPRETER,
postprocess_text(display_text),
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="observation", avatar="user")
markdown_placeholder = message_placeholder.empty()
output_text = ''
with markdown_placeholder:
with st.spinner('Executing code...'):
try:
res_type, res = execute(code, get_kernel())
except Exception as e:
st.error(f'Error when executing code: {e}')
return
print("Received:", res_type, res)
if res_type == 'text' and len(res) > TRUNCATE_LENGTH:
res = res[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
append_conversation(Conversation(
Role.OBSERVATION,
'[Image]' if res_type == 'image' else postprocess_text(res),
tool=None,
image=res if res_type == 'image' else None,
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
output_text = ''
break
case _:
st.error(f'Unexpected special token: {token.text.strip()}')
break
output_text += response.token.text
display_text = output_text.split('interpreter')[-1].strip()
markdown_placeholder.markdown(postprocess_text(display_text + '▌'))
else:
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
import re
import yaml
from yaml import YAMLError
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
from tool_registry import dispatch_tool, get_tools
MAX_LENGTH = 8192
TRUNCATE_LENGTH = 1024
EXAMPLE_TOOL = {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
}
}
client = get_client()
def tool_call(*args, **kwargs) -> dict:
print("=== Tool call:")
print(args)
print(kwargs)
st.session_state.calling_tool = True
return kwargs
def yaml_to_dict(tools: str) -> list[dict] | None:
try:
return yaml.safe_load(tools)
except YAMLError:
return None
def extract_code(text: str) -> str:
pattern = r'```([^\n]*)\n(.*?)```'
matches = re.findall(pattern, text, re.DOTALL)
return matches[-1][1]
# Append a conversation to the history and show it in a new markdown block
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None=None,
) -> None:
history.append(conversation)
conversation.show(placeholder)
def main(top_p: float, temperature: float, prompt_text: str, repetition_penalty: float):
manual_mode = st.toggle('Manual mode',
help='Define your tools in YAML format. You need to supply tool call results manually.'
)
if manual_mode:
with st.expander('Tools'):
tools = st.text_area(
'Define your tools in YAML format here:',
yaml.safe_dump([EXAMPLE_TOOL], sort_keys=False),
height=400,
)
tools = yaml_to_dict(tools)
if not tools:
st.error('YAML format error in tools definition')
else:
tools = get_tools()
if 'tool_history' not in st.session_state:
st.session_state.tool_history = []
if 'calling_tool' not in st.session_state:
st.session_state.calling_tool = False
history: list[Conversation] = st.session_state.tool_history
for conversation in history:
conversation.show()
if prompt_text:
prompt_text = prompt_text.strip()
role = Role.OBSERVATION if st.session_state.calling_tool else Role.USER
append_conversation(Conversation(role, prompt_text), history)
st.session_state.calling_tool = False
input_text = preprocess_text(
None,
tools,
history,
)
print("=== Input:")
print(input_text)
print("=== History:")
print(history)
placeholder = st.container()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
for _ in range(5):
output_text = ''
for response in client.generate_stream(
system=None,
tools=tools,
history=history,
do_sample=True,
max_length=MAX_LENGTH,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
repetition_penalty=repetition_penalty,
):
token = response.token
if response.token.special:
print("=== Output:")
print(output_text)
match token.text.strip():
case '<|user|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
# Initiate tool call
case '<|assistant|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
output_text = ''
message_placeholder = placeholder.chat_message(name="tool", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
continue
case '<|observation|>':
tool, *call_args_text = output_text.strip().split('\n')
call_args_text = '\n'.join(call_args_text)
append_conversation(Conversation(
Role.TOOL,
postprocess_text(output_text),
tool,
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="observation", avatar="user")
markdown_placeholder = message_placeholder.empty()
try:
code = extract_code(call_args_text)
args = eval(code, {'tool_call': tool_call}, {})
except:
st.error('Failed to parse tool call')
return
output_text = ''
if manual_mode:
st.info('Please provide tool call results below:')
return
else:
with markdown_placeholder:
with st.spinner(f'Calling tool {tool}...'):
observation = dispatch_tool(tool, args)
if len(observation) > TRUNCATE_LENGTH:
observation = observation[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
append_conversation(Conversation(
Role.OBSERVATION, observation
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
st.session_state.calling_tool = False
break
case _:
st.error(f'Unexpected special token: {token.text.strip()}')
return
output_text += response.token.text
markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
else:
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
from enum import Enum
import streamlit as st
st.set_page_config(
page_title="ChatGLM3 Demo",
page_icon=":robot:",
layout='centered',
initial_sidebar_state='expanded',
)
import demo_chat, demo_ci, demo_tool
DEFAULT_SYSTEM_PROMPT = '''
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
'''.strip()
# Set the title of the demo
st.title("ChatGLM3 Demo")
# Add your custom text here, with smaller font size
st.markdown("<sub>智谱AI 公开在线技术文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof </sub> \n\n <sub> 更多 ChatGLM3-6B 的使用方法请参考文档。</sub>", unsafe_allow_html=True)
class Mode(str, Enum):
CHAT, TOOL, CI = '💬 Chat', '🛠️ Tool', '🧑‍💻 Code Interpreter'
with st.sidebar:
top_p = st.slider(
'top_p', 0.0, 1.0, 0.8, step=0.01
)
temperature = st.slider(
'temperature', 0.0, 1.5, 0.95, step=0.01
)
repetition_penalty = st.slider(
'repetition_penalty', 0.0, 2.0, 1.2, step=0.01
)
system_prompt = st.text_area(
label="System Prompt (Only for chat mode)",
height=300,
value=DEFAULT_SYSTEM_PROMPT,
)
prompt_text = st.chat_input(
'Chat with ChatGLM3!',
key='chat_input',
)
tab = st.radio(
'Mode',
[mode.value for mode in Mode],
horizontal=True,
label_visibility='hidden',
)
match tab:
case Mode.CHAT:
demo_chat.main(top_p, temperature, system_prompt, prompt_text, repetition_penalty)
case Mode.TOOL:
demo_tool.main(top_p, temperature, prompt_text, repetition_penalty)
case Mode.CI:
demo_ci.main(top_p, temperature, prompt_text, repetition_penalty)
case _:
st.error(f'Unexpected tab: {tab}')
huggingface_hub
ipykernel
ipython
jupyter_client
pillow
sentencepiece
streamlit
tokenizers
pyyaml
requests
import copy
import inspect
from pprint import pformat
import traceback
from types import GenericAlias
from typing import get_origin, Annotated
_TOOL_HOOKS = {}
_TOOL_DESCRIPTIONS = {}
def register_tool(func: callable):
tool_name = func.__name__
tool_description = inspect.getdoc(func).strip()
python_params = inspect.signature(func).parameters
tool_params = []
for name, param in python_params.items():
annotation = param.annotation
if annotation is inspect.Parameter.empty:
raise TypeError(f"Parameter `{name}` missing type annotation")
if get_origin(annotation) != Annotated:
raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
typ, (description, required) = annotation.__origin__, annotation.__metadata__
typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
if not isinstance(description, str):
raise TypeError(f"Description for `{name}` must be a string")
if not isinstance(required, bool):
raise TypeError(f"Required for `{name}` must be a bool")
tool_params.append({
"name": name,
"description": description,
"type": typ,
"required": required
})
tool_def = {
"name": tool_name,
"description": tool_description,
"params": tool_params
}
print("[registered tool] " + pformat(tool_def))
_TOOL_HOOKS[tool_name] = func
_TOOL_DESCRIPTIONS[tool_name] = tool_def
return func
def dispatch_tool(tool_name: str, tool_params: dict) -> str:
if tool_name not in _TOOL_HOOKS:
return f"Tool `{tool_name}` not found. Please use a provided tool."
tool_call = _TOOL_HOOKS[tool_name]
try:
ret = tool_call(**tool_params)
except:
ret = traceback.format_exc()
return str(ret)
def get_tools() -> dict:
return copy.deepcopy(_TOOL_DESCRIPTIONS)
# Tool Definitions
@register_tool
def random_number_generator(
seed: Annotated[int, 'The random seed used by the generator', True],
range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
) -> int:
"""
Generates a random number x, s.t. range[0] <= x < range[1]
"""
if not isinstance(seed, int):
raise TypeError("Seed must be an integer")
if not isinstance(range, tuple):
raise TypeError("Range must be a tuple")
if not isinstance(range[0], int) or not isinstance(range[1], int):
raise TypeError("Range must be a tuple of integers")
import random
return random.Random(seed).randint(*range)
@register_tool
def get_weather(
city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
"""
Get the current weather for `city_name`
"""
if not isinstance(city_name, str):
raise TypeError("City name must be a string")
key_selection = {
"current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
}
import requests
try:
resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
resp.raise_for_status()
resp = resp.json()
ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
except:
import traceback
ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
return str(ret)
if __name__ == "__main__":
print(dispatch_tool("get_weather", {"city_name": "beijing"}))
print(get_tools())
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 使用精确的提示词完成符合格式输出的文本分类任务\n",
"\n",
"在本章节中,我们将带领开发者体验在 `文本分类` 任务中,在不进行任何模型微调的情况下,使用具有微小差异的提示词,并使用 `ChatGLM3-6B` 模型来完成符合格式输出的文本分类任务。\n",
"\n",
"我们使用 [新闻标题分类](https://github.com/fateleak/toutiao-multilevel-text-classfication-dataset) 任务来体验模型的表现。这是一个经典的文本分类任务,我们将使用 `新闻信息` 作为输入,模型需要预测出这个标题属于哪个类别。\n",
"\n",
"由于 `ChatGLM3-6B` 强大的能力,我们可以直接使用 `新闻标题` 作为输入,而不需要额外的信息,也不需要进行任何模型微调,就能完成这个任务。我们的目标是,让模型能够成功输出原始数据集中的15种类别中的一种作为分类结果,而且不能输入任何冗余的对话。\n",
"\n",
"在本章节中,用户将直观的对比两种不同细粒度的提示词下对模型分类造成的影响。\n",
"\n",
"## 硬件要求\n",
"本实践手册需要使用 FP16 精度的模型进行推理,因此,我们推荐使用至少 16GB 显存的 英伟达 GPU 来完成本实践手册。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"引入必要的库"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 19,
"outputs": [],
"source": [
"from transformers import AutoModel, AutoTokenizer\n",
"import torch"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:03.002351Z",
"end_time": "2023-11-23T19:23:03.016701Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"设置好对应的参数,以保证模型推理的公平性。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"max_new_tokens = 1024\n",
"temperature = 0.1\n",
"top_p = 0.9\n",
"device = \"cuda\"\n",
"model_path_chat = \"/Models/chatglm3-6b\""
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:03.004850Z",
"end_time": "2023-11-23T19:23:03.047154Z"
}
}
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [
{
"data": {
"text/plain": "Loading checkpoint shards: 0%| | 0/7 [00:00<?, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "d6468f7889994e638ac754fde95b6e58"
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenizer_path_chat = model_path_chat\n",
"tokenizer = AutoTokenizer.from_pretrained(tokenizer_path_chat, trust_remote_code=True, encode_special_tokens=True)\n",
"model = AutoModel.from_pretrained(model_path_chat, load_in_8bit=False, trust_remote_code=True).to(device)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:03.047154Z",
"end_time": "2023-11-23T19:23:13.086148Z"
}
}
},
{
"cell_type": "code",
"execution_count": 22,
"outputs": [],
"source": [
"def answer(prompt):\n",
" inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
" response = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=max_new_tokens, history=None,\n",
" temperature=temperature,\n",
" top_p=top_p, do_sample=True)\n",
" response = response[0, inputs[\"input_ids\"].shape[-1]:]\n",
" answer = tokenizer.decode(response, skip_special_tokens=True)\n",
" return answer"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:13.086148Z",
"end_time": "2023-11-23T19:23:13.086148Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"在本样例中,模型应该输出的标准答案应该为:\n",
"'news_sports'"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 23,
"outputs": [],
"source": [
"PROMPT1 = \"\"\"\n",
"<|system|>\n",
"你是一个专业的新闻专家,请根据我提供的新闻信息,包括新闻标题,新闻关键词等信息,你需要在这些类中选择其中一个,它们分别是:\n",
"news_story\n",
"news_culture\n",
"news_sports\n",
"news_finance\n",
"news_house\n",
"news_car\n",
"news_edu\n",
"news_tech\n",
"news_military\n",
"news_travel\n",
"news_world\n",
"stock\n",
"news_agriculture\n",
"news_game\n",
"这是我的信息:\n",
"<|user|>\n",
"新闻标题: 女乒今天排兵布阵不合理,丁宁昨天刚打硬仗,今天应该打第一单打,你认为呢?\n",
"新闻关键词: 无\n",
"<|assistant|>\n",
"\"\"\""
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:13.086148Z",
"end_time": "2023-11-23T19:23:13.095864Z"
}
}
},
{
"cell_type": "code",
"execution_count": 24,
"outputs": [
{
"data": {
"text/plain": "' 新闻类型:体育新闻\\n 新闻关键词:女乒,排兵布阵,丁宁,硬仗,第一单打'"
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"answer(PROMPT1)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:13.095864Z",
"end_time": "2023-11-23T19:23:14.034154Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"很明显,使用这个提示词没有完成要求。我们更换一个提示词,经过优化后,看是否能达标。\n",
"\n",
"我们为提示词中设定了更多的限定词汇,并用Markdown语法规范化了输出的格式。修改后的提示词如下:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 25,
"outputs": [],
"source": [
"PROMPT2 = \"\"\"\n",
"<|system|>\n",
"请根据我提供的新闻信息,格式为\n",
"```\n",
"新闻标题: xxx\n",
"新闻关键词: xxx\n",
"```\n",
"你要对每一行新闻类别进行分类并告诉我结果,不要返回其他信息和多于的文字,这些类别是:\n",
"news_story\n",
"news_culture\n",
"news_sports\n",
"news_finance\n",
"news_house\n",
"news_car\n",
"news_edu\n",
"news_tech\n",
"news_military\n",
"news_travel\n",
"news_world\n",
"stock\n",
"news_agriculture\n",
"news_game\n",
"我将为你提供一些新闻标题和关键词,你需要根据这些信息,对这些新闻进行分类,每条一行,格式为\n",
"```\n",
"类别中的其中一个\n",
"```\n",
"不要返回其他内容,例如新闻标题,新闻关键词等等,只需要返回分类结果即可\n",
"<|user|>\n",
"新闻标题: 女乒今天排兵布阵不合理,丁宁昨天刚打硬仗,今天应该打第一单打,你认为呢?\n",
"新闻关键词: 无\n",
"<|assistant|>\n",
"\"\"\""
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:14.044595Z",
"end_time": "2023-11-23T19:23:14.044595Z"
}
}
},
{
"cell_type": "code",
"execution_count": 26,
"outputs": [
{
"data": {
"text/plain": "'news_sports'"
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"answer(PROMPT2)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:14.044595Z",
"end_time": "2023-11-23T19:23:14.217801Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"这一次,模型成功给出了理想的答案。我们结束实操训练,删除模型并释放显存。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 27,
"outputs": [],
"source": [
"del model\n",
"torch.cuda.empty_cache()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-11-23T19:23:14.217801Z",
"end_time": "2023-11-23T19:23:14.373137Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 总结\n",
"在本实践手册中,我们让开发者体验了使用不同提示词下,`ChatGLM3-6B` 模型在 `新闻标题分类` 任务中的表现。\n",
"1. 模型在两次提示中都能正确的分类,验证了模型底座的能力。\n",
"2. 在使用第二个提示词时,模型的输出符合格式要求,而且没有冗余的对话。符合我们的要求。\n",
"3. 第二个提示词使用了一定的表示限定的词汇和更准确的任务要求,包括规定了返回格式,返回的行书等的。\n",
"\n",
"因此,通过上述内容,我们可以发现:\n",
"对于有格式要求的任务,可以使用更具有格式化的提示词来完成任务。同时,使用更低的 `temperature` 和更高的 `top_p` 可以提高模型的输出质量。\n",
"\n",
"模型在训练中大量使用 Markdown格式。因此,用 ``` 符号来规范提示词将能提升模型输出格式化数据的准确率。"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
# ChatGLM3-6B-base Fine-tuning Example
This directory provides fine-tuning examples for the ChatGLM3-6B-base model; at the moment only LoRA fine-tuning is included.
If you have downloaded the model locally, replace every `THUDM/chatglm3-6b-base` reference in this document and in the code with the local path so that the model is loaded from disk (see the loading sketch after the install command below).
Running the examples requires `python>=3.10`. Besides the basic `torch` dependency, the example code also depends on
```bash
pip install -r requirements.txt
```
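For reference, the sketch below shows what loading a local copy of the base model looks like with `transformers`; the local directory path is a placeholder, not a path from this repository.
```python
# Minimal sketch, assuming the model has been downloaded to a local directory.
# "/path/to/chatglm3-6b-base" is a placeholder path.
from transformers import AutoModel, AutoTokenizer

MODEL_DIR = "/path/to/chatglm3-6b-base"

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_DIR, trust_remote_code=True).half().cuda().eval()
```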
## Multi-turn Dialogue Format
The `base` model has no dialogue capability and can only generate single-turn completions. If you want a multi-turn dialogue model, fine-tune the `Chat` model instead.
## Dataset Requirements
For the data format, please use the `alpaca`-style dataset shown below.
```json
{"context": "hello", "target": "hi,I am ChatGLM3"}
```
Here, `context` is the preceding text of the conversation, i.e. the model input, and `target` is the reply, i.e. the model output.
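As a reference, here is a minimal sketch (the file name `train.jsonl` is only an example) of writing training data in this format as JSON Lines, matching how `finetune.py` reads a `.jsonl` file passed via `--train_file` with `--train_format input-output`:
```python
# Minimal sketch: write {"context", "target"} records as JSON Lines (one object per line).
# "train.jsonl" is an example file name; pass your own path via --train_file.
import json

samples = [
    {"context": "hello", "target": "hi,I am ChatGLM3"},
]

with open("train.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```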
## Fine-tuning the Model
The following script provides a reference way to fine-tune the model.
```bash
./scripts/finetune_lora.sh  # LoRA fine-tuning
```
### Tips
1. Before training starts, the fine-tuning code prints the preprocessing of the first training sample, which looks like
```log
Sanity Check >>>>>>>>>>>>>
'[gMASK]': 64790 -> -100
'sop': 64792 -> -100
'<|system|>': 64794 -> -100
'': 30910 -> -100
'\n': 13 -> -100
'Answer': 20115 -> -100
'the': 267 -> -100
'following': 1762 -> -100
...
'know': 683 -> -100
'the': 267 -> -100
'response': 3010 -> -100
'details': 3296 -> -100
'.': 30930 -> -100
'<|assistant|>': 64796 -> -100
'': 30910 -> 30910
'\n': 13 -> 13
'I': 307 -> 307
'need': 720 -> 720
'to': 289 -> 289
'use': 792 -> 792
...
'': 0 -> -100
'': 0 -> -100 (repeated several times)
<<<<<<<<<<<<< Sanity Check
```
Each line shows a detokenized string, its token_id, and its target_id. Check in the log whether this `loss_mask` matches your expectation; if it does not, you may need to adjust the code or data (a short sketch of how the `-100` targets act as the loss mask follows this list).
2. Reference GPU memory usage
   - With the default parameters of the official script, each GPU uses about `23GB` of memory.
3. If you find that GPU memory is insufficient, consider
   - lowering `DEV_BATCH_SIZE` and increasing `GRAD_ACCUMULARION_STEPS`
   - lowering `MAX_SEQ_LEN`, although this may affect model quality
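The sketch below (plain PyTorch, not code from this repository) illustrates why prompt tokens are shown with a target of `-100` in the sanity check: labels equal to `-100` are ignored by cross-entropy, so only the response tokens contribute to the loss.
```python
# Minimal sketch: labels of -100 are skipped by cross-entropy (ignore_index defaults to -100),
# which is exactly the loss mask printed by the sanity check above.
import torch
import torch.nn.functional as F

logits = torch.randn(5, 32)                       # 5 positions, vocabulary of 32
labels = torch.tensor([-100, -100, -100, 7, 12])  # prompt positions masked, response kept

loss = F.cross_entropy(logits, labels, ignore_index=-100)
print(loss)  # averaged over the two unmasked positions only
```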
## Notes
+ The base model has no dialogue capability and can only generate single-turn completions. If you want a multi-turn dialogue model, fine-tune the Chat model instead.
+ Note that to run this script you also need to install everything listed in `requirements.txt` in this directory.
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
lora_checkpoint: str = field(
default=None, metadata={"help": "Path to lora checkpoints"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
)
},
)
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
"help": (
"Whether to automatically resize the position embeddings if `max_source_length` exceeds "
"the model's position embeddings."
)
},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={
"help": (
"An optional parameter specifying the number of bits used for quantization. "
"Quantization is a process that reduces the model size by limiting the number of "
"bits that represent each weight in the model. A lower number of bits can reduce "
"the model size and speed up inference, but might also decrease model accuracy. "
"If not set (None), quantization is not applied."
)
},
)
lora_rank: Optional[int] = field(
default=8,
metadata={
"help": (
"The rank of the LoRA update matrices, balancing complexity and model flexibility. A higher rank allows more "
"complex adaptations but increases the number of parameters and computational cost."
)
},
)
lora_alpha: Optional[float] = field(
default=32,
metadata={
"help": (
"The scaling factor for the LoRA update. A higher value results in more significant adjustments, potentially improving adaptation to new tasks or data, "
"but might also risk overfitting. A lower value makes smaller adjustments, possibly maintaining better generalization."
)
},
)
lora_dropout: Optional[float] = field(
default=0.1,
metadata={
"help": (
"The dropout probability applied to the LoRA layers during training to prevent the model from overly relying on specific patterns in the training data. "
"Higher dropout rates can improve model generalization but may reduce learning efficiency."
)
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": (
"The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
train_format: str = field(
default=None, metadata={"help": "The format of the training data file (mulit-turn or input-output)"},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_seq_length: Optional[int] = field(
default=1024,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": (
"Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
)
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
def __post_init__(self):
extension = self.train_file.split(".")[-1]
assert extension in {"jsonl", "json"}, "`train_file` should be a jsonl or a json file."
assert self.train_format in {"multi-turn", "input-output"}
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
# Adapted from
import logging
import os
import sys
import torch
import json
import transformers
from transformers import (
AutoModel,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Seq2SeqTrainingArguments,
set_seed,
)
from trainer import LoRATrainer
from arguments import ModelArguments, DataTrainingArguments
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from preprocess_utils import sanity_check, InputOutputDataset
logger = logging.getLogger(__name__)
class CastOutputToFloat(torch.nn.Module):
def __init__(self, layer):
super().__init__()
self.layer = layer
def forward(self, *args, **kwargs):
return self.layer(*args, **kwargs).float()
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
# datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
# Load pretrained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_args.model_name_or_path, trust_remote_code=True).half().cuda()
if model_args.quantization_bit is not None:
print(f"Quantized to {model_args.quantization_bit} bit")
model = model.quantize(model_args.quantization_bit)
with open(data_args.train_file, "r", encoding="utf-8") as f:
if data_args.train_file.endswith(".json"):
train_data = json.load(f)
elif data_args.train_file.endswith(".jsonl"):
train_data = [json.loads(line) for line in f]
if data_args.train_format == "input-output":
train_dataset = InputOutputDataset(
train_data,
tokenizer,
data_args.max_source_length,
data_args.max_target_length,
)
else:
raise ValueError(f"Unknown train format: {data_args.train_format}")
print(f"Train dataset size: {len(train_dataset)}")
#if training_args.local_rank < 1:
sanity_check(train_dataset[0]['input_ids'], train_dataset[0]['labels'], tokenizer)
# Apply PEFT configuration
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=model_args.lora_rank,
target_modules=['query_key_value'],
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
)
model = get_peft_model(model, peft_config).to("cuda")
# Make sure gradient checkpointing and model parallelism are configured correctly
#model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.is_parallelizable = True
model.model_parallel = True # You can try temporarily disabling model parallelism to see whether it resolves the issue
model.lm_head = CastOutputToFloat(model.transformer.output_layer)
model.config.use_cache = False
# Data collator
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=-100,
pad_to_multiple_of=None,
padding=False
)
# Initialize our Trainer
trainer = LoRATrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
)
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
# model.gradient_checkpointing_enable()
model.enable_input_require_grads()
trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
trainer.save_state()
if __name__ == "__main__":
main()
import argparse
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
import os
from peft import get_peft_model, LoraConfig, TaskType
# Argument Parser Setup
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default=None,
help="The directory of the model")
parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path")
parser.add_argument("--lora-path", type=str, default=None,
help="Path to the LoRA model checkpoint")
parser.add_argument("--device", type=str, default="cuda", help="Device to use for computation")
parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum new tokens for generation")
parser.add_argument("--lora-alpha", type=float, default=32, help="LoRA alpha")
parser.add_argument("--lora-rank", type=int, default=8, help="LoRA r")
parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
# Model and Tokenizer Configuration
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
model = AutoModel.from_pretrained(args.model, load_in_8bit=False, trust_remote_code=True, device_map="auto").to(
args.device)
# LoRA Model Configuration
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, inference_mode=True,
target_modules=['query_key_value'],
r=args.lora_rank, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout
)
model = get_peft_model(model, peft_config)
if args.lora_path is not None and os.path.exists(args.lora_path):
model.load_state_dict(torch.load(args.lora_path), strict=False)
# Interactive Prompt
while True:
prompt = input("Prompt: ")
inputs = tokenizer(prompt, return_tensors="pt").to(args.device)
response = model.generate(input_ids=inputs["input_ids"],
max_length=inputs["input_ids"].shape[-1] + args.max_new_tokens)
response = response[0, inputs["input_ids"].shape[-1]:]
print("Response:", tokenizer.decode(response, skip_special_tokens=True))
from transformers import PreTrainedTokenizer
from torch.utils.data import Dataset
from typing import Dict, List
def sanity_check(tokens: List[int], target: List[int], tokenizer: PreTrainedTokenizer):
print("Sanity Check >>>>>>>>>>>>>")
for t, m in zip(tokens, target):
decoded = tokenizer.tokenizer.index_special_tokens[t] \
if t in tokenizer.tokenizer.index_special_tokens \
else tokenizer.decode([t])
print("%20s: %6d -> %6d" % (repr(decoded), t, m))
print("<<<<<<<<<<<<< Sanity Check")
assert len(tokens) == len(target), f"length mismatch: {len(tokens)} vs {len(target)}"
class InputOutputDataset(Dataset):
def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_source_length: int, max_target_length: int):
super(InputOutputDataset, self).__init__()
self.tokenizer = tokenizer
self.max_source_length = max_source_length
self.max_target_length = max_target_length
self.max_seq_length = max_source_length + max_target_length + 1
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, i) -> dict:
data_item = self.data[i]
a_ids = self.tokenizer.encode(text=data_item['context'], add_special_tokens=True, truncation=True,
max_length=self.max_source_length)
b_ids = self.tokenizer.encode(text=data_item['target'], add_special_tokens=False, truncation=True,
max_length=self.max_target_length)
context_length = len(a_ids)
input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]
labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]
pad_len = self.max_seq_length - len(input_ids)
input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
labels = labels + [self.tokenizer.pad_token_id] * pad_len
labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]
assert len(input_ids) == len(labels), f"length mismatch: {len(input_ids)} vs {len(labels)}"
return {
"input_ids": input_ids,
"labels": labels
}
tqdm
datasets
fsspec
astunparse
peft
accelerate
sentencepiece