Commit 49314d92 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
miniCPM-bf16 @ 10f760ed
Subproject commit 10f760ed0246a8bc6bbc3742f428306a5afabe07
# Convert the AdvertiseGen dataset to ChatML format.
import os
import shutil
import json

input_dir = "data/AdvertiseGen"
output_dir = "data/AdvertiseGenChatML"

# Start from an empty output directory so files from an earlier run never survive.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)

for fn in ["train.json", "dev.json"]:
    data_out_list = []
    # Open both files as UTF-8 explicitly: the output is written with
    # ensure_ascii=False, so non-ASCII text must not depend on the platform's
    # default encoding.
    with open(os.path.join(input_dir, fn), "r", encoding="utf-8") as f, \
            open(os.path.join(output_dir, fn), "w", encoding="utf-8") as fo:
        for line in f:
            # The input holds one JSON object per line; skip blank lines.
            if not line.strip():
                continue
            data = json.loads(line)
            # Each record becomes a single user -> assistant exchange.
            data_out = {
                "messages": [
                    {
                        "role": "user",
                        "content": data["content"],
                    },
                    {
                        "role": "assistant",
                        "content": data["summary"],
                    },
                ]
            }
            data_out_list.append(data_out)
        # The whole split is written as one JSON array.
        json.dump(data_out_list, fo, ensure_ascii=False, indent=4)
This diff is collapsed.
This diff is collapsed.
from typing import Dict
from typing import List
from typing import Tuple
import argparse
import gradio as gr
import torch
from threading import Thread
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer
)
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
# Command-line options of the Hugging Face based demo.
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument(
    "--torch_dtype",
    type=str,
    default="bfloat16",
    choices=["float32", "bfloat16"],
)
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
args = parser.parse_args()

# Map the dtype name given on the command line to the torch dtype object.
torch_dtype = args.torch_dtype
if torch_dtype in ("", "bfloat16"):
    torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
    torch_dtype = torch.float32
else:
    raise ValueError(f"Invalid torch dtype: {torch_dtype}")

# Load the tokenizer and the model from the same path.
path = args.model_path
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(
    path,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True,
)

# Host and port the Gradio server will listen on.
server_name = args.server_name
server_port = args.server_port
def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
    """Stream a sampled reply for ``dialog`` from the Hugging Face model.

    Args:
        dialog (List): chat messages as ``{"role": ..., "content": ...}`` dicts; rendered into the prompt with the tokenizer's chat template.
        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): strictly positive float value used to modulate the logits distribution.
        repetition_penalty (float): forwarded to ``model.generate`` as ``repetition_penalty``.
        max_dec_len (int): the maximum number of tokens to generate.
    Yields:
        str: the reply text generated so far (cumulative), without the prompt.
    """
    # NOTE(review): add_generation_prompt=False assumes the model's chat template
    # already ends the prompt with the assistant tag - confirm against the template.
    inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
    # Follow the model's device instead of hard-coding "cuda": the model is
    # loaded with device_map="auto".
    enc = tokenizer(inputs, return_tensors="pt").to(model.device)
    # skip_prompt=True makes the streamer drop the echoed prompt itself, so the
    # reply no longer has to be cut out by character offset (which breaks as
    # soon as the decoded prompt differs in length from the rendered template).
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        enc,
        do_sample=True,
        top_k=0,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_dec_len,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # Generation runs in a worker thread; this generator consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    answer = ""
    for new_text in streamer:
        answer += new_text
        yield answer
    # The streamer is exhausted once generation has finished; reap the worker.
    thread.join()
def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
    """Answer ``query`` after the "Submit" button is pressed, streaming the reply.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records; the new round is appended to it in place.
        query (str): query of current round
        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): strictly positive float value used to modulate the logits distribution.
        repetition_penalty (float): forwarded unchanged to ``hf_gen``.
        max_dec_len (int): The maximum numbers of tokens to generate.
    Yields:
        tuple: ``(gr.update(value=""), chat_history)`` where chat_history is
        [[q_1, a_1], ..., [q_n, a_n], [q_n+1, a_n+1]] with the answer of the current round filled in so far.
    """
    assert query != "", "Input must not be empty!!!"
    eos_marker = "</s>"
    # Rebuild the dialog as alternating user/assistant messages.
    model_input = []
    for q, a in chat_history:
        model_input.append({"role": "user", "content": q})
        model_input.append({"role": "assistant", "content": a})
    model_input.append({"role": "user", "content": query})
    # Add the new round with an empty answer, then fill it in while streaming.
    chat_history.append([query, ""])
    for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
        # Remove only a trailing end-of-sequence marker. str.strip("</s>") would
        # treat the argument as a character set and eat any leading or trailing
        # '<', '/', 's' or '>' of the reply itself (e.g. "cats" -> "cat").
        if answer.endswith(eos_marker):
            answer = answer[:-len(eos_marker)]
        chat_history[-1][1] = answer
        yield gr.update(value=""), chat_history
def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
    """Re-generate the answer to the last round's query, streaming the reply.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records; the last answer is overwritten in place.
        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): strictly positive float value used to modulate the logits distribution.
        repetition_penalty (float): forwarded unchanged to ``hf_gen``.
        max_dec_len (int): The maximum numbers of tokens to generate.
    Yields:
        tuple: ``(gr.update(value=""), chat_history)`` where chat_history is
        [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]] with a_n replaced by the new answer so far.
    """
    assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
    eos_marker = "</s>"
    # Rebuild the dialog from every round but the last, then re-ask the last query.
    model_input = []
    for q, a in chat_history[:-1]:
        model_input.append({"role": "user", "content": q})
        model_input.append({"role": "assistant", "content": a})
    model_input.append({"role": "user", "content": chat_history[-1][0]})
    # Overwrite the last answer while the new one streams in.
    for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
        # Remove only a trailing end-of-sequence marker. str.strip("</s>") would
        # treat the argument as a character set and eat any leading or trailing
        # '<', '/', 's' or '>' of the reply itself (e.g. "cats" -> "cat").
        if answer.endswith(eos_marker):
            answer = answer[:-len(eos_marker)]
        chat_history[-1][1] = answer
        yield gr.update(value=""), chat_history
def clear_history():
    """Drop every stored QA round.

    Returns:
        List: a new, empty chat history
    """
    return list()
def reverse_last_round(chat_history):
    """Return the chat history without its most recent QA round.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records
    Returns:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. A new list; the input is not modified.
    """
    rounds = len(chat_history)
    assert rounds >= 1, "History is empty. Nothing to reverse!!"
    return chat_history[:rounds - 1]
# Build the Gradio UI, wire the buttons to the handlers and start the server.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("""# MiniCPM Gradio Demo""")

    with gr.Row():
        # Narrow column: sampling controls.
        with gr.Column(scale=1):
            top_p = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.1, label="top_p")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1, label="temperature")
            repetition_penalty = gr.Slider(minimum=0.1, maximum=2.0, value=1.1, step=0.1, label="repetition_penalty")
            max_dec_len = gr.Slider(minimum=1, maximum=1024, value=1024, step=1, label="max_dec_len")
        # Wide column: conversation, input box and action buttons.
        with gr.Column(scale=5):
            chatbot = gr.Chatbot(bubble_full_width=False, height=400)
            user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
            with gr.Row():
                submit = gr.Button("Submit")
                clear = gr.Button("Clear")
                regen = gr.Button("Regenerate")
                reverse = gr.Button("Reverse")

    # Submit and Regenerate update both the input box and the conversation;
    # Clear and Reverse only replace the conversation.
    submit.click(
        generate,
        inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len],
        outputs=[user_input, chatbot],
    )
    regen.click(
        regenerate,
        inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len],
        outputs=[user_input, chatbot],
    )
    clear.click(clear_history, inputs=[], outputs=[chatbot])
    reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])

demo.queue()
demo.launch(server_name=server_name, server_port=server_port, show_error=True)
from typing import Dict
from typing import List
from typing import Tuple
import argparse
import gradio as gr
from vllm import LLM, SamplingParams
# Command-line options of the vLLM based demo.
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument(
    "--torch_dtype",
    type=str,
    default="bfloat16",
    choices=["float32", "bfloat16"],
)
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
args = parser.parse_args()

# Normalize the dtype name; it is handed to vLLM as a string.
torch_dtype = args.torch_dtype
if torch_dtype in ("", "bfloat16"):
    torch_dtype = "bfloat16"
elif torch_dtype == "float32":
    torch_dtype = "float32"
else:
    raise ValueError(f"Invalid torch dtype: {torch_dtype}")

# Create the vLLM engine for the model at the given path.
path = args.model_path
llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype)

# Host and port the Gradio server will listen on.
server_name = args.server_name
server_port = args.server_port
def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
    """Generate one reply for ``dialog`` with the vLLM engine.

    Args:
        dialog (List): chat messages as ``{"role": ..., "content": ...}`` dicts; must contain an odd number of entries.
        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): Strictly positive float value used to modulate the logits distribution.
        max_dec_len (int): The maximum numbers of tokens to generate.
    Returns:
        str: text of the first completion returned by vLLM
    """
    assert len(dialog) % 2 == 1
    # "user" turns get the user tag, every other role gets the AI tag; the
    # prompt then ends with an open AI tag for the model to continue.
    segments = [
        ("<用户>" if turn["role"] == "user" else "<AI>") + turn["content"]
        for turn in dialog
    ]
    segments.append("<AI>")
    prompt = "".join(segments)

    sampling_params = SamplingParams(
        n=1,
        best_of=1,
        presence_penalty=1.0,
        frequency_penalty=0.0,
        temperature=temperature,
        top_p=top_p,
        top_k=-1,
        use_beam_search=False,
        length_penalty=1,
        early_stopping=False,
        stop=None,
        stop_token_ids=None,
        ignore_eos=False,
        max_tokens=max_dec_len,
        logprobs=None,
        prompt_logprobs=None,
        skip_special_tokens=True,
    )
    # A single prompt is submitted, so only the first request output is used.
    request_output = llm.generate(prompts=prompt, sampling_params=sampling_params)[0]
    return request_output.outputs[0].text
def generate(chat_history: List, query: str, top_p: float, temperature: float, max_dec_len: int):
    """Answer ``query`` after the "Submit" button is pressed.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records; the new round is appended to it in place.
        query (str): query of current round
        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): strictly positive float value used to modulate the logits distribution.
        max_dec_len (int): The maximum numbers of tokens to generate.
    Returns:
        tuple: ``(gr.update(value=""), chat_history)`` where chat_history is
        [[q_1, a_1], ..., [q_n, a_n], [q_n+1, a_n+1]], i.e. the history plus the QA of the current round.
    """
    assert query != "", "Input must not be empty!!!"
    # Rebuild the dialog as alternating user/assistant messages.
    model_input = []
    for past_query, past_answer in chat_history:
        model_input.extend(
            (
                {"role": "user", "content": past_query},
                {"role": "assistant", "content": past_answer},
            )
        )
    model_input.append({"role": "user", "content": query})
    # Generation is blocking here: the full reply is produced before returning.
    reply = vllm_gen(model_input, top_p, temperature, max_dec_len)
    chat_history.append([query, reply])
    return gr.update(value=""), chat_history
def regenerate(chat_history: List, top_p: float, temperature: float, max_dec_len: int):
    """Re-generate the answer to the last round's query.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records; the last answer is overwritten in place.
        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): strictly positive float value used to modulate the logits distribution.
        max_dec_len (int): The maximum numbers of tokens to generate.
    Returns:
        tuple: ``(gr.update(value=""), chat_history)`` where chat_history is
        [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]] with a_n replaced by the new answer.
    """
    assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
    # Rebuild the dialog from every round but the last, then re-ask the last query.
    model_input = []
    for past_query, past_answer in chat_history[:-1]:
        model_input.extend(
            (
                {"role": "user", "content": past_query},
                {"role": "assistant", "content": past_answer},
            )
        )
    last_round = chat_history[-1]
    model_input.append({"role": "user", "content": last_round[0]})
    # Generation is blocking here: the full reply is produced before returning.
    last_round[1] = vllm_gen(model_input, top_p, temperature, max_dec_len)
    return gr.update(value=""), chat_history
def clear_history():
    """Drop every stored QA round.

    Returns:
        List: a new, empty chat history
    """
    empty_history = []
    return empty_history
def reverse_last_round(chat_history):
    """Return the chat history without its most recent QA round.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records
    Returns:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. A new list; the input is not modified.
    """
    round_count = len(chat_history)
    assert round_count >= 1, "History is empty. Nothing to reverse!!"
    earlier_rounds = chat_history[:round_count - 1]
    return earlier_rounds
# Build the Gradio UI, wire the buttons to the handlers and start the server.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("""# MiniCPM Gradio Demo""")

    with gr.Row():
        # Narrow column: sampling controls.
        with gr.Column(scale=1):
            # The minimum is 0.1 rather than 0: vLLM's SamplingParams only
            # accepts top_p in (0, 1], so a slider value of 0 would make
            # every request fail.
            top_p = gr.Slider(0.1, 1, value=0.8, step=0.1, label="top_p")
            temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
            max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
        # Wide column: conversation, input box and action buttons.
        with gr.Column(scale=5):
            chatbot = gr.Chatbot(bubble_full_width=False, height=400)
            user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
            with gr.Row():
                submit = gr.Button("Submit")
                clear = gr.Button("Clear")
                regen = gr.Button("Regenerate")
                reverse = gr.Button("Reverse")

    # Submit and Regenerate update both the input box and the conversation;
    # Clear and Reverse only replace the conversation.
    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, max_dec_len], outputs=[user_input, chatbot])
    regen.click(regenerate, inputs=[chatbot, top_p, temperature, max_dec_len], outputs=[user_input, chatbot])
    clear.click(clear_history, inputs=[], outputs=[chatbot])
    reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])

demo.queue()
demo.launch(server_name=server_name, server_port=server_port, show_error=True)
# Base image; per its tag: PyTorch 2.1.0, CentOS 7.6, DTK 23.10, Python 3.8.
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk23.10-py38
# NOTE(review): DEBIAN_FRONTEND is a Debian/apt setting; the base image tag says CentOS - confirm it is needed.
ENV DEBIAN_FRONTEND=noninteractive
# RUN yum update && yum install -y git cmake wget build-essential
# NOTE(review): each RUN executes in its own shell, so variables exported by
# env.sh presumably do not reach later instructions - confirm this is intended.
RUN source /opt/dtk-23.10/env.sh
# Install the Python dependencies with pip.
COPY requirements.txt requirements.txt
# NOTE(review): the package index is reached over plain HTTP with --trusted-host - confirm HTTPS cannot be used.
RUN pip3 install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -r requirements.txt
# for finetune
jieba>=0.42.1
ruamel_yaml>=0.18.5
rouge_chinese>=1.0.3
jupyter>=1.0.0
datasets>=2.16.1
peft>=0.7.1
# deepspeed>=0.13.1
# flash_attn>=2.5.1
# MiniCPM 微调
[English Version](https://github.com/OpenBMB/MiniCPM/blob/main/finetune/README_en.md)
本目录提供 MiniCPM-2B 模型的微调示例,包括全量微调和 PEFT。格式上,提供多轮对话微调样例和输入输出格式微调样例。
如果将模型下载到了本地,本文和代码中的 `OpenBMB/MiniCPM-2B` 字段均应替换为相应地址以从本地加载模型。
运行示例需要 `python>=3.10`。除基础的 `torch` 依赖外,运行示例代码还需要安装 `requirements.txt` 中列出的依赖。
**我们提供了 [示例notebook](lora_finetune.ipynb) 用于演示如何以 AdvertiseGen 为例处理数据和使用微调脚本。**
```bash
pip install -r requirements.txt
```
## 测试硬件标准
我们仅提供了单机多卡/多机多卡的运行示例,因此您需要至少一台具有多个 GPU 的机器。本仓库中的**默认配置文件**中,我们记录了显存的占用情况:
+ SFT 全量微调: 4张显卡平均分配,每张显卡占用 `30245MiB` 显存。
+ LORA 微调: 1张显卡,占用 `10619MiB` 显存。
> 请注意,该结果仅供参考,对于不同的参数,显存占用可能会有所不同。请结合你的硬件情况进行调整。
## 多轮对话格式
多轮对话微调示例采用 ChatGLM3 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`
对于数据文件,样例采用如下格式
```json
[
{
"messages": [
{
"role": "system",
"content": "<system prompt text>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
// ... Multi Turn
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
// ...
]
```
## 数据集格式示例
这里以 AdvertiseGen 数据集为例,
您可以从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)
或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载 AdvertiseGen 数据集。
将解压后的 AdvertiseGen 目录放到 `data` 目录下并自行转换为如下格式数据集。
> 请注意,现在的微调代码中加入了验证集,因此,对于一组完整的微调数据集,必须包含训练数据集和验证数据集,测试数据集可以不填写。或者直接用验证数据集代替。
```
{"messages": [{"role": "user", "content": "类型#裙*裙长#半身裙"}, {"role": "assistant", "content": "这款百搭时尚的仙女半身裙,整体设计非常的飘逸随性,穿上之后每个女孩子都能瞬间变成小仙女啦。料子非常的轻盈,透气性也很好,穿到夏天也很舒适。"}]}
```
## 开始微调
通过以下代码执行 **单机多卡/多机多卡** 运行。
```bash
cd finetune
bash sft_finetune.sh
```
通过以下代码执行 **单机单卡** 运行。
```bash
cd finetune
bash lora_finetune.sh
```
# MiniCPM Fine-tuning
[中文版](https://github.com/OpenBMB/MiniCPM/blob/main/finetune/README.md)
This directory provides examples of fine-tuning the MiniCPM-2B model, including full model fine-tuning and PEFT. In terms of format, we offer examples for multi-turn dialogue fine-tuning and input-output format fine-tuning.
If you have downloaded the model to your local system, the `OpenBMB/MiniCPM-2B` field mentioned in this document and in the code should be replaced with the corresponding address to load the model from your local system.
Running the example requires `python>=3.10`. Besides the basic `torch` dependency, additional dependencies are needed to run the example code.
**We have provided an [example notebook](lora_finetune.ipynb) to demonstrate how to process data and use the fine-tuning script with AdvertiseGen as an example.**
```bash
pip install -r requirements.txt
```
## Testing Hardware Standard
We only provide examples for single-node multi-GPU/multi-node multi-GPU setups, so you will need at least one machine with multiple GPUs. In the **default configuration file** in this repository, we have documented the memory usage:
+ SFT full parameters fine-tuning: Evenly distributed across 4 GPUs, each GPU consumes `30245MiB` of memory.
+ LORA fine-tuning: One GPU, consuming `10619MiB` of memory.
> Please note that these results are for reference only, and memory consumption may vary with different parameters. Please adjust according to your hardware situation.
## Multi-Turn Dialogue Format
The multi-turn dialogue fine-tuning example adopts the ChatGLM3 dialogue format convention, adding different `loss_mask` for different roles, thus calculating `loss` for multiple replies in one computation.
For the data file, the example uses the following format
```json
[
{
"messages": [
{
"role": "system",
"content": "<system prompt text>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
// ... Multi Turn
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
// ...
]
```
## Dataset Format Example
Here, taking the AdvertiseGen dataset as an example,
you can download the AdvertiseGen dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)
or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1). After extracting the AdvertiseGen directory, place it in the `data` directory and convert it into the following format dataset.
> Please note, the fine-tuning code now includes a validation set, so for a complete set of fine-tuning datasets, it must contain training and validation datasets, while the test dataset is optional. Or, you can use the validation dataset in place of it.
```
{"messages": [{"role": "user", "content": "类型#裙*裙长#半身裙"}, {"role": "assistant", "content": "这款百搭时尚的仙女半身裙,整体设计非常的飘逸随性,穿上之后每个女孩子都能瞬间变成小仙女啦。料子非常的轻盈,透气性也很好,穿到夏天也很舒适。"}]}
```
## Start Fine-tuning
Execute **single-node multi-GPU/multi-node multi-GPU** runs with the following code.
```bash
cd finetune
bash sft_finetune.sh
```
Execute **single-node single-GPU** runs with the following code.
```bash
cd finetune
bash lora_finetune.sh
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment