Commit 1bfbcff0 authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
Pipeline #1204 canceled with stages
# AnimateDiff的微调和推理
SWIFT已经支持了AnimateDiff的微调和推理,目前支持两种方式:全参数微调和LoRA微调。
首先需要clone并安装SWIFT:
```shell
git clone https://github.com/modelscope/swift.git
cd swift
pip install ".[aigc]"
```
## 全参数训练
### 训练效果
全参数微调可以复现[官方提供的模型animatediff-motion-adapter-v1-5-2](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/animatediff-motion-adapter-v1-5-2/summary)的效果,需要的短视频数量较多,魔搭官方复现使用了官方数据集的subset版本:[WebVid 2.5M](https://maxbain.com/webvid-dataset/)。训练效果如下:
```text
Prompt:masterpiece, bestquality, highlydetailed, ultradetailed, girl, walking, on the street, flowers
```
![image.png](../../resources/1.gif)
```text
Prompt: masterpiece, bestquality, highlydetailed, ultradetailed, beautiful house, mountain, snow top
```
![image.png](../../resources/2.gif)
2.5M子数据集训练的生成效果仍存在效果不稳定的情况,开发者使用10M数据集效果会更稳定。
### 运行命令
```shell
# 该文件在swift/examples/pytorch/animatediff/scripts/full中
# Experimental environment: A100 * 4
# 200GB GPU memory totally
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
torchrun --nproc_per_node=4 animatediff_sft.py \
--model_id_or_path wyj123456/Realistic_Vision_V5.1_noVAE \
--csv_path /mnt/workspace/yzhao/tastelikefeet/webvid/results_2M_train.csv \
--video_folder /mnt/workspace/yzhao/tastelikefeet/webvid/videos2 \
--sft_type full \
--lr_scheduler_type constant \
--trainable_modules .*motion_modules.* \
--batch_size 4 \
--eval_steps 100 \
--gradient_accumulation_steps 16
```
我们使用了A100 * 4进行训练,共需要200GB显存,训练时长约40小时。数据格式如下:
```text
--csv_path 传入一个csv文件,该csv文件应包含如下格式:
name,contentUrl
Travel blogger shoot a story on top of mountains. young man holds camera in forest.,stock-footage-travel-blogger-shoot-a-story-on-top-of-mountains-young-man-holds-camera-in-forest.mp4
```
name字段代表该短视频的prompt,contentUrl代表该视频文件的名称
```text
--video_folder 传入一个视频目录,该目录中包含了csv文件中,contentUrl指代的所有视频文件
```
使用全参数进行推理方式如下:
```shell
# 该文件在swift/examples/pytorch/animatediff/scripts/full中
# Experimental environment: A100
# 18GB GPU memory
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python animatediff_infer.py \
--model_id_or_path wyj123456/Realistic_Vision_V5.1_noVAE \
--sft_type full \
--ckpt_dir /output/path/like/checkpoints/iter-xxx \
--eval_human true
```
其中的--ckpt_dir 传入训练时输出的文件夹即可。
## LoRA训练
### 运行命令
全参数训练会从0开始训练整个Motion-Adapter结构,用户可以使用现有的模型使用少量视频进行微调,只需要运行下面的命令:
```shell
# 该文件在swift/examples/pytorch/animatediff/scripts/lora中
# Experimental environment: A100
# 20GB GPU memory
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python animatediff_sft.py \
--model_id_or_path wyj123456/Realistic_Vision_V5.1_noVAE \
--csv_path /mnt/workspace/yzhao/tastelikefeet/webvid/results_2M_train.csv \
--video_folder /mnt/workspace/yzhao/tastelikefeet/webvid/videos2 \
--motion_adapter_id_or_path Shanghai_AI_Laboratory/animatediff-motion-adapter-v1-5-2 \
--sft_type lora \
--lr_scheduler_type constant \
--trainable_modules .*motion_modules.* \
--batch_size 1 \
--eval_steps 200 \
--dataset_sample_size 10000 \
--gradient_accumulation_steps 16
```
视频数据参数同上。
推理命令如下:
```shell
# 该文件在swift/examples/pytorch/animatediff/scripts/lora中
# Experimental environment: A100
# 18GB GPU memory
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python animatediff_infer.py \
--model_id_or_path wyj123456/Realistic_Vision_V5.1_noVAE \
--motion_adapter_id_or_path Shanghai_AI_Laboratory/animatediff-motion-adapter-v1-5-2 \
--sft_type lora \
--ckpt_dir /output/path/like/checkpoints/iter-xxx \
--eval_human true
```
其中的--ckpt_dir 传入训练时输出的文件夹即可。
## 参数列表
下面给出训练和推理分别支持的参数列表及其含义:
### 训练参数
```text
motion_adapter_id_or_path: Optional[str] = None # motion adapter的模型id或模型路径,指定这个参数可以基于现有的官方模型效果继续训练
motion_adapter_revision: Optional[str] = None # motion adapter的模型revision,仅在motion_adapter_id_or_path是模型id时有用
model_id_or_path: str = None # sd基模型的模型id或模型路径
model_revision: str = None # sd基模型的revision,仅在model_id_or_path是模型id时有用
dataset_sample_size: int = None # 数据集训练条数,默认代表全量训练
sft_type: str = field(
default='lora', metadata={'choices': ['lora', 'full']}) # 训练方式,支持lora和全参数
output_dir: str = 'output' # 输出文件夹
ddp_backend: str = field(
default='nccl', metadata={'choices': ['nccl', 'gloo', 'mpi', 'ccl']}) # 如使用ddp训练,ddp backend
seed: int = 42 # 随机种子
lora_rank: int = 8 # lora 参数
lora_alpha: int = 32 # lora 参数
lora_dropout_p: float = 0.05 # lora 参数
lora_dtype: str = 'fp32' # lora模块dtype类型,如果为`AUTO`则跟随原始模块的dtype设定
gradient_checkpointing: bool = False # 是否开启gc,默认不开启。注:当前版本diffusers有问题,不支持该参数为True
batch_size: int = 1 # batchsize
num_train_epochs: int = 1 # epoch数
# if max_steps >= 0, override num_train_epochs
learning_rate: Optional[float] = None # 学习率
weight_decay: float = 0.01 # adamw参数
gradient_accumulation_steps: int = 16 # ga大小
max_grad_norm: float = 1. # grad norm大小
lr_scheduler_type: str = 'cosine' # lr_scheduler的类型
warmup_ratio: float = 0.05 # 是否warmup及warmup占比
eval_steps: int = 50 # eval step间隔
save_steps: Optional[int] = None # save step间隔
dataloader_num_workers: int = 1 # dataloader workers数量
push_to_hub: bool = False # 是否推送到modelhub
# 'user_name/repo_name' or 'repo_name'
hub_model_id: Optional[str] = None # modelhub id
hub_private_repo: bool = False
push_hub_strategy: str = field( # 推送策略,推送最后一个还是每个都推送
default='push_best',
metadata={'choices': ['push_last', 'all_checkpoints']})
# None: use env var `MODELSCOPE_API_TOKEN`
hub_token: Optional[str] = field( # modelhub的token
default=None,
metadata={
'help':
'SDK token can be found in https://modelscope.cn/my/myaccesstoken'
})
ignore_args_error: bool = False # True: notebook compatibility
text_dropout_rate: float = 0.1 # drop一定比例的文本保证模型鲁棒性
validation_prompts_path: str = field( # 评测过程使用的prompt文件目录,默认使用swift/aigc/configs/validation.txt
default=None,
metadata={
'help':
'The validation prompts file path, use aigc/configs/validation.txt is None'
})
trainable_modules: str = field( # 可训练模块,建议使用默认值
default='.*motion_modules.*',
metadata={
'help':
'The trainable modules, by default, the .*motion_modules.* will be trained'
})
mixed_precision: bool = True # 混合精度训练
enable_xformers_memory_efficient_attention: bool = True # 使用xformers
num_inference_steps: int = 25 #
guidance_scale: float = 8.
sample_size: int = 256
sample_stride: int = 4 # 训练视频最大长度秒数
sample_n_frames: int = 16 # 每秒帧数
csv_path: str = None # 输入数据集
video_folder: str = None # 输入数据集
motion_num_attention_heads: int = 8 # motion adapter参数
motion_max_seq_length: int = 32 # motion adapter参数
num_train_timesteps: int = 1000 # 推理pipeline参数
beta_start: int = 0.00085 # 推理pipeline参数
beta_end: int = 0.012 # 推理pipeline参数
beta_schedule: str = 'linear' # 推理pipeline参数
steps_offset: int = 1 # 推理pipeline参数
clip_sample: bool = False # 推理pipeline参数
use_wandb: bool = False # 是否使用wandb
```
### 推理参数
```text
motion_adapter_id_or_path: Optional[str] = None # motion adapter的模型id或模型路径,指定这个参数可以基于现有的官方模型效果继续训练
motion_adapter_revision: Optional[str] = None # motion adapter的模型revision,仅在motion_adapter_id_or_path是模型id时有用
model_id_or_path: str = None # sd基模型的模型id或模型路径
model_revision: str = None # sd基模型的revision,仅在model_id_or_path是模型id时有用
sft_type: str = field(
default='lora', metadata={'choices': ['lora', 'full']}) # 训练方式,支持lora和全参数
ckpt_dir: Optional[str] = field(
default=None, metadata={'help': '/path/to/your/vx-xxx/checkpoint-xxx'}) # 训练的输出文件夹
eval_human: bool = False # False: eval val_dataset # 是否使用人工输入评测
seed: int = 42 # 随机种子
merge_lora: bool = False # Merge lora into the MotionAdapter and save the model.
replace_if_exists: bool = False # Replace the files if the output merged dir exists when `merge_lora` is True.
# other
ignore_args_error: bool = False # True: notebook compatibility
validation_prompts_path: str = None # 用于validation的文件,eval_human=False时使用,每一行一个prompt
output_path: str = './generated' # 输出gif的目录
enable_xformers_memory_efficient_attention: bool = True # 使用xformers
num_inference_steps: int = 25 #
guidance_scale: float = 8.
sample_size: int = 256
sample_stride: int = 4 # 训练视频最大长度秒数
sample_n_frames: int = 16 # 每秒帧数
motion_num_attention_heads: int = 8 # motion adapter参数
motion_max_seq_length: int = 32 # motion adapter参数
num_train_timesteps: int = 1000 # 推理pipeline参数
beta_start: int = 0.00085 # 推理pipeline参数
beta_end: int = 0.012 # 推理pipeline参数
beta_schedule: str = 'linear' # 推理pipeline参数
steps_offset: int = 1 # 推理pipeline参数
clip_sample: bool = False # 推理pipeline参数
```
# Res-Tuning组件
<div align="center">
## [NeurIPS 2023] Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone
### [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/)
</div>
Res-Tuning 是一种灵活高效的微调tuner。我们把tuner的设计从模型网络结构中解耦出来以便灵活地组合,
并进一步扩展实现了一种新的节省内存的旁路tuner,大大减少了显存消耗和多任务推理成本。
目前Res-Tuning在[SWIFT](https://github.com/modelscope/swift)中以可插拔的tuner算法组件提供,开发者可以直接使用它。
### 支持的组件列表
- [x] Res-Adapter
- [x] Res-Tuning-Bypass
- [ ] Res-Prefix
- [ ] Res-Prompt
### 使用方式
#### Demo
- 可以使用我们提供的 [可视化例子](https://github.com/modelscope/swift/blob/main/examples/pytorch/cv/notebook/swift_vision.ipynb).
#### 初始化Tuner
```Python
from swift import ResTuningConfig
config = ResTuningConfig(
dims=768,
root_modules=r'.*blocks.0$',
stem_modules=r'.*blocks\.\d+$',
target_modules=r'norm',
tuner_cfg='res_adapter'
)
```
- dims: The dimensions of the hidden states.
- root_modules: The root module to be replaced.
- stem_modules: The stem modules to be replaced.
- target_modules: The target module to be replaced.
- tuner_cfg: The configuration of the tuning module.
#### 加载模型
```Python
from swift import Swift
import timm, torch
model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=100)
model_tune = Swift.prepare_model(model, config)
print(model_tune.get_trainable_parameters())
print(model(torch.ones(1, 3, 224, 224)).shape)
```
### 引用
```
@inproceedings{jiang2023restuning,
title={Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone},
author={Jiang, Zeyinzi and Mao, Chaojie and Huang, Ziyuan and Ma, Ao and Lv, Yiliang and Shen, Yujun and Zhao, Deli and Zhou, Jingren},
booktitle={Advances in Neural Information Processing Systems},
year={2023}
}
```
## 🔥SCEdit
SCEdit由阿里巴巴通义实验室视觉智能团队(Alibaba TongYi Vision Intelligence Lab)所提出,是一个高效的生成式微调框架。该框架不仅支持文生图下游任务的微调能力,**相比LoRA节省30%-50%的训练显存开销**,实现快速迁移到特定的生成场景中;而且还可以**直接扩展到可控图像生成任务中,仅需ControlNet条件生成7.9%的参数量并节省30%的显存开销**,支持边缘图、深度图、分割图、姿态、颜色图、图像补全等条件生成任务。
我们使用了[风格迁移数据集](https://modelscope.cn/datasets/damo/style_custom_dataset/dataPeview)中的3D风格数据进行了训练,并使用相同的`Prompt: A boy in a camouflage jacket with a scarf`进行测试,具体的定性和定量的结果如下:
| Method | bs | ep | Target Module | Param. (M) | Mem. (MiB) | 3D style |
| --------- | ---- | ---- | ------------- | ------------- | ---------- | ------------------------------------------------------------ |
| LoRA/r=64 | 1 | 50 | q/k/v/out/mlp | 23.94 (2.20%) | 8440MiB | <img src="https://intranetproxy.alipay.com/skylark/lark/0/2023/png/167218/1703665229562-0f33bbb0-c492-41b4-9f37-3ae720dca80d.png" alt="img" style="zoom:20%;" /> |
| SCEdit | 1 | 50 | up_blocks | 19.68 (1.81%) | 7556MiB | <img src="https://intranetproxy.alipay.com/skylark/lark/0/2023/png/167218/1703665933913-74b98741-3b57-46a4-9871-539df3a0112c.png" alt="img" style="zoom:20%;" /> |
| LoRA/r=64 | 10 | 100 | q/k/v/out/mlp | 23.94 (2.20%) | 26300MiB | <img src="https://intranetproxy.alipay.com/skylark/lark/0/2023/png/167218/1703750608529-de20d0e7-bf9c-4928-8e59-73cc54f2c8d7.png" alt="img" style="zoom:20%;" /> |
| SCEdit | 10 | 100 | up_blocks | 19.68 (1.81%) | 18634MiB | <img src="https://intranetproxy.alipay.com/skylark/lark/0/2023/png/167218/1703663033092-94492e44-341f-4259-9df4-13c168e3b5d6.png" alt="img" style="zoom:20%;" /> |
| LoRA/r=64 | 30 | 200 | q/k/v/out/mlp | 23.94 (2.20%) | 69554MiB | <img src="https://intranetproxy.alipay.com/skylark/lark/0/2023/png/167218/1703750626635-2e368d7b-5e99-4a06-b189-8615f302bcd7.png" alt="img" style="zoom:20%;" /> |
| SCEdit | 30 | 200 | up_blocks | 19.68 (1.81%) | 43350MiB | <img src="https://intranetproxy.alipay.com/skylark/lark/0/2023/png/167218/1703662246942-1102b1f4-93ab-4653-b943-3302f2a5259e.png" alt="img" style="zoom:20%;" /> |
使用SCEdit执行训练任务并复现上述结果:
```shell
# 先执行下面章节的安装步骤
cd examples/pytorch/multi_modal/notebook
python text_to_image_synthesis.py
```
# 安装和使用
## Wheel包安装
可以使用pip进行安装:
```shell
# 全量能力
pip install 'ms-swift[all]' -U
# 仅使用LLM
pip install 'ms-swift[llm]' -U
# 仅使用AIGC
pip install 'ms-swift[aigc]' -U
# 仅使用adapters
pip install ms-swift -U
```
## 源代码安装
```shell
git clone https://github.com/modelscope/swift.git
cd swift
pip install -e '.[all]'
```
## Notebook环境
Swift支持训练的绝大多数模型都可以在`A10`显卡上使用,用户可以使用ModelScope官方提供的免费显卡资源:
1. 进入[ModelScope](https://www.modelscope.cn)官方网站并登录
2. 点击左侧的`我的Notebook`并开启一个免费GPU实例
3. 愉快地薅A10显卡羊毛
## Build文档
Swift支持完整的API Doc文档,在swift根目录下执行:
```shell
make docs
```
等待执行完成后,查看`docs/build/html/index.html`即可。
# 基本使用
tuner是指附加在模型上的额外结构部分,用于减少训练参数量或者提高训练精度。目前SWIFT支持的tuners有:
1. LoRA: [LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/abs/2106.09685)
2. LoRA+: [LoRA+: Efficient Low Rank Adaptation of Large Models](https://arxiv.org/pdf/2402.12354.pdf)
3. LLaMA PRO: [LLAMA PRO: Progressive LLaMA with Block Expansion](https://arxiv.org/pdf/2401.02415.pdf)
4. GaLore: [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
5. LISA: [LISA: Layerwise Importance Sampling for Memory-Efficient Large Language Model Fine-Tuning](https://arxiv.org/abs/2403.17919)
6. UnSloth: https://github.com/unslothai/unsloth
7. SCEdit: [SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing](https://arxiv.org/abs/2312.11392) < [arXiv](https://arxiv.org/abs/2312.11392) | [Project Page](https://scedit.github.io/) >
8. NEFTune: [Noisy Embeddings Improve Instruction Finetuning](https://arxiv.org/abs/2310.05914)
9. LongLoRA: [Efficient Fine-tuning of Long-Context Large Language Models](https://arxiv.org/abs/2309.12307)
10. Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751)
11. Vision Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119)
12. Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503)
13. Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) | [Usage](ResTuning.md) >
14. [PEFT](https://github.com/huggingface/peft)提供的tuners, 如IA3, AdaLoRA等
## 在训练中使用
调用`Swift.prepare_model()`来将tuners添加到模型上:
```python
from modelscope import Model
from swift import Swift, LoraConfig
import torch
model = Model.from_pretrained('ZhipuAI/chatglm3-6b', torch_dtype=torch.bfloat16, device_map='auto')
lora_config = LoraConfig(
r=16,
target_modules=['query_key_value'],
lora_alpha=32,
lora_dropout=0.)
model = Swift.prepare_model(model, lora_config)
```
也可以同时使用多个tuners:
```python
from modelscope import Model
from swift import Swift, LoraConfig, AdapterConfig
import torch
model = Model.from_pretrained('ZhipuAI/chatglm3-6b', torch_dtype=torch.bfloat16, device_map='auto')
lora_config = LoraConfig(
r=16,
target_modules=['query_key_value'],
lora_alpha=32,
lora_dropout=0.)
adapter_config = AdapterConfig(
dim=model.config.hidden_size,
target_modules=['mlp'],
method_name='forward',
hidden_pos=0,
adapter_length=32,
)
model = Swift.prepare_model(model, {'first_tuner': lora_config, 'second_tuner': adapter_config})
# use model to do other things
```
在使用多个tuners时,传入的第二个参数需要是Dict,key是tuner名字,value是tuner配置。
训练后可以调用:
```python
model.save_pretrained(save_directory='./output')
```
来存储模型checkpoint。模型的checkpoint文件只会包括tuners的权重,不会包含模型本身的权重。存储后的结构如下:
> outputs
>
> ​ |-- configuration.json
>
> ​ |-- first_tuner
>
> ​ |-- adapter_config.json
>
> ​ |-- adapter_model.bin
>
> ​ |-- second_tuner
>
> ​ |-- adapter_config.json
>
> ​ |-- adapter_model.bin
>
> ​ |-- ...
如果只传入单独的config,则会使用默认的名称`default`
> outputs
>
> ​ |-- configuration.json
>
> ​ |-- default
>
> ​ |-- adapter_config.json
>
> ​ |-- adapter_model.bin
>
> ​ |-- ...
### 完整的训练代码
```python
# A100 18G memory
from swift import Seq2SeqTrainer, Seq2SeqTrainingArguments
from modelscope import MsDataset, AutoTokenizer
from modelscope import AutoModelForCausalLM
from swift import Swift, LoraConfig
from swift.llm import get_template, TemplateType
import torch
# 拉起模型
model = AutoModelForCausalLM.from_pretrained('ZhipuAI/chatglm3-6b', torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True)
lora_config = LoraConfig(
r=16,
target_modules=['query_key_value'],
lora_alpha=32,
lora_dropout=0.05)
model = Swift.prepare_model(model, lora_config)
tokenizer = AutoTokenizer.from_pretrained('ZhipuAI/chatglm3-6b', trust_remote_code=True)
dataset = MsDataset.load('AI-ModelScope/alpaca-gpt4-data-en', split='train')
template = get_template(TemplateType.chatglm3, tokenizer, max_length=1024)
def encode(example):
inst, inp, output = example['instruction'], example.get('input', None), example['output']
if output is None:
return {}
if inp is None or len(inp) == 0:
q = inst
else:
q = f'{inst}\n{inp}'
example, kwargs = template.encode({'query': q, 'response': output})
return example
dataset = dataset.map(encode).filter(lambda e: e.get('input_ids'))
dataset = dataset.train_test_split(test_size=0.001)
train_dataset, val_dataset = dataset['train'], dataset['test']
train_args = Seq2SeqTrainingArguments(
output_dir='output',
learning_rate=1e-4,
num_train_epochs=2,
eval_steps=500,
save_steps=500,
evaluation_strategy='steps',
save_strategy='steps',
dataloader_num_workers=4,
per_device_train_batch_size=1,
gradient_accumulation_steps=16,
logging_steps=10,
)
trainer = Seq2SeqTrainer(
model=model,
args=train_args,
data_collator=template.data_collator,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=tokenizer)
trainer.train()
```
## 在推理时使用
使用`Swift.from_pretrained()`来拉起训练后存储的checkpoint:
```python
from modelscope import Model
from swift import Swift
import torch
model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
model = Swift.from_pretrained(model, './output')
```
### 完整的推理代码
```python
# A100 14G memory
import torch
from modelscope import AutoModelForCausalLM, GenerationConfig
from modelscope import AutoTokenizer
from swift import Swift
from swift.llm import get_template, TemplateType, to_device
# 拉起模型
model = AutoModelForCausalLM.from_pretrained('ZhipuAI/chatglm3-6b', torch_dtype=torch.bfloat16,
device_map='auto', trust_remote_code=True)
model = Swift.from_pretrained(model, 'output/checkpoint-xxx')
tokenizer = AutoTokenizer.from_pretrained('ZhipuAI/chatglm3-6b', trust_remote_code=True)
template = get_template(TemplateType.chatglm3, tokenizer, max_length=1024)
examples, tokenizer_kwargs = template.encode({'query': 'How are you?'})
if 'input_ids' in examples:
input_ids = torch.tensor(examples['input_ids'])[None]
examples['input_ids'] = input_ids
token_len = input_ids.shape[1]
generation_config = GenerationConfig(
max_new_tokens=1024,
temperature=0.3,
top_k=25,
top_p=0.8,
do_sample=True,
repetition_penalty=1.0,
num_beams=10,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id)
device = next(model.parameters()).device
examples = to_device(examples, device)
generate_ids = model.generate(
generation_config=generation_config,
**examples)
generate_ids = template.get_generate_ids(generate_ids, token_len)
print(tokenizer.decode(generate_ids, **tokenizer_kwargs))
# I'm an AI language model, so I don't have feelings or physical sensations. However, I'm here to assist you with any questions or tasks you may have. How can I help you today?
```
# 接口列表
## Swift类静态接口
- `Swift.prepare_model(model, config, **kwargs)`
- 接口作用:加载某个tuner到模型上,如果是PeftConfig的子类,则使用Peft库的对应接口加载tuner。在使用SwiftConfig的情况下,本接口可以传入SwiftModel实例并重复调用,此时和config传入字典的效果相同。
- 本接口支持并行加载不同类型的多个tuners共同使用
- 参数:
- `model`: `torch.nn.Module``SwiftModel`的实例,被加载的模型
- `config`: `SwiftConfig``PeftConfig`的实例,或者一个自定义tuner名称对config的字典
- 返回值:`SwiftModel``PeftModel`的实例
- `Swift.merge_and_unload(model)`
- 接口作用:将LoRA weights合并回原模型,并将LoRA部分完全卸载
- 参数:
- model: `SwiftModel``PeftModel`的实例,已加载LoRA的模型实例
- 返回值:None
- `Swift.merge(model)`
- 接口作用:将LoRA weights合并回原模型,不卸载LoRA部分
- 参数:
- model: `SwiftModel``PeftModel`的实例,已加载LoRA的模型实例
- 返回值:None
- `Swift.unmerge(model)`
- 接口作用:将LoRA weights从原模型weights中拆分回LoRA结构
- 参数:
- model: `SwiftModel``PeftModel`的实例,已加载LoRA的模型实例
- 返回值:None
- `Swift.save_to_peft_format(ckpt_dir, output_dir)`
- 接口作用:将存储的LoRA checkpoint转换为Peft兼容的格式。主要改变有:
- `default`会从对应的`default`文件夹中拆分到output_dir根目录中
- weights中的`{tuner_name}.`字段会被移除,如`model.layer.0.self.in_proj.lora_A.default.weight`会变为`model.layer.0.self.in_proj.lora_A.weight`
- weights中的key会增加`basemodel.model`前缀
- 注意:只有LoRA可以被转换,其他类型tuner由于Peft本身不支持,因此会报转换错误。此外,由于LoRAConfig中存在额外参数,如`dtype`,因此在这些参数有设定的情况下,不支持转换为Peft格式,此时可以手动删除adapter_config.json中的对应字段
- 参数:
- ckpt_dir:原weights目录
- output_dir:目标weights目录
- 返回值:None
- `Swift.from_pretrained(model, model_id, adapter_name, revision, **kwargs)`
- 接口作用:从存储的weights目录中加载起tuner到模型上,如果adapter_name不传,则会将model_id目录下所有的tuners都加载起来。同`prepare_model`相同,本接口可以重复调用
- 参数:
- model:`torch.nn.Module``SwiftModel`的实例,被加载的模型
- model_id:`str`类型,待加载的tuner checkpoint, 可以是魔搭hub的id,或者训练产出的本地目录
- adapter_name:`str``List[str]``Dict[str, str]`类型或`None`,待加载tuner目录中的tuner名称,如果为`None`则加载所有名称的tuners,如果是`str``List[str]`则只加载某些具体的tuner,如果是`Dict`,则将`key`指代的tuner加载起来后换成`value`的名字
- revision: 如果model_id是魔搭的id,则revision可以指定对应版本号
## SwiftModel接口
下面列出用户可能调用的接口列表,其他内部接口或不推荐使用的接口可以通过`make docs`命令查看API Doc文档。
- `SwiftModel.create_optimizer_param_groups(self, **defaults)`
- 接口作用:根据加载的tuners创建parameter groups,目前仅对`LoRA+`算法有作用
- 参数:
- defaults:`optimizer_groups`的默认参数,如`lr``weight_decay`
- 返回值:
- 创建的`optimizer_groups`
- `SwiftModel.add_weighted_adapter(self, ...)`
- 接口作用:将已有的LoRA tuners合并为一个
- 参数:
- 本接口是PeftModel.add_weighted_adapter的透传,参数可以参考:[add_weighted_adapter文档](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter)
- `SwiftModel.save_pretrained(self, save_directory, safe_serialization, adapter_name)`
- 接口作用:存储tuner weights
- 参数:
- save_directory:存储目录
- safe_serialization: 是否使用safe_tensors,默认为False
- adapter_name:存储的adapter tuner,如果不传则默认存储所有的tuners
- `SwiftModel.set_active_adapters(self, adapter_names, offload=None)`
- 接口作用:设置当前激活的adapters,不在列表中的adapters会被失活
-`推理`时支持环境变量`USE_UNIQUE_THREAD=0/1`,默认值`1`,如果为`0`则set_active_adapters只对当前线程生效,此时默认使用本线程激活的tuners,不同线程tuners互不干扰
- 参数:
- adapter_names:激活的tuners
- offload:失活的adapters如何处理,默认为`None`代表留在显存中,同时支持`cpu``meta`,代表offload到cpu和meta设备中以减轻显存消耗,在`USE_UNIQUE_THREAD=0`时offload不要传值以免影响其他线程
- 返回值:None
- `SwiftModel.activate_adapter(self, adapter_name)`
- 接口作用:激活一个tuner
-`推理`时支持环境变量`USE_UNIQUE_THREAD=0/1`,默认值`1`,如果为`0`则activate_adapter只对当前线程生效,此时默认使用本线程激活的tuners,不同线程tuners互不干扰
- 参数:
- adapter_name:待激活的tuner名字
- 返回值:None
- `SwiftModel.deactivate_adapter(self, adapter_name, offload)`
- 接口作用:失活一个tuner
-`推理`时环境变量`USE_UNIQUE_THREAD=0`时不要调用本接口
- 参数:
- adapter_name:待失活的tuner名字
- offload:失活的adapters如何处理,默认为`None`代表留在显存中,同时支持`cpu``meta`,代表offload到cpu和meta设备中以减轻显存消耗
- 返回值:None
- `SwiftModel.get_trainable_parameters(self)`
- 接口作用:返回训练参数信息
- 参数:无
- 返回值:训练参数信息,格式如下:
```text
trainable params: 100M || all params: 1000M || trainable%: 10.00% || cuda memory: 10GiB.
```
# 对Peft的兼容性
为了支持习惯Peft的用户,Swift提供了对于Peft的兼容性。用户可以从swift中import peft组件:
>PeftModel
>
>PeftConfig
>
>PeftModelForSeq2SeqLM
>
>PeftModelForSequenceClassification
>
>PeftModelForTokenClassification
>
>PeftModelForCausalLM
>
>PromptEncoderConfig
>
>PromptTuningConfig
>
>PrefixTuningConfig
>
>PromptLearningConfig
>
>LoraConfig
>
>get_peft_config
>
>get_peft_model_state_dict
>
>get_peft_model
以上组件均可以从swift中import:
```python
from swift import PeftModel, PeftConfig
```
Swift类也支持初始化Peft的tuner:
```python
from modelscope.models.nlp import SbertForSequenceClassification
from modelscope.models.nlp.structbert import SbertConfig
from swift import LoraConfig, Swift
model = SbertForSequenceClassification(SbertConfig())
lora_config = LoraConfig(target_modules=['query', 'key', 'value'])
model = Swift.prepare_model(model, lora_config)
```
Swift对Peft进行了浅封装,使Peft可以在from_pretrained时使用modelscope hub中的模型。
# 界面训练推理
目前SWIFT已经支持了界面化的训练和推理,参数支持和脚本训练相同。在安装SWIFT后,使用如下命令:
```shell
swift web-ui
```
开启界面训练和推理。
web-ui没有传入参数,所有可控部分都在界面中。但是有几个环境变量可以使用:
> WEBUI_SHARE=1/0 默认为0 控制gradio是否是share状态
>
> SWIFT_UI_LANG=en/zh 控制web-ui界面语言
>
> WEBUI_SERVER server_name参数,web-ui host ip,0.0.0.0代表所有ip均可访问,127.0.0.1代表只允许本机访问
>
> WEBUI_PORT web-ui的端口号
>
> USE_INFERENCE=1/0 默认0. 控制gradio的推理页面是直接加载模型推理或者部署(USE_INFERENCE=0)
# Agent微调最佳实践
用消费级显卡训练属于自己的Agent!
SWIFT支持了开源模型,尤其是中小型模型(7B、14B等)对Agent场景的训练,并将[loss-scale技术](https://arxiv.org/pdf/2309.00986.pdf)应用到agent训练中,使中小模型API Call能力更稳定,并支持使用单张商业级显卡进行Agent推理和部署,可以直接在生产场景中全链路闭环落地使用。
## 目录
- [环境安装](#环境安装)
- [数据准备](#数据准备)
- [微调](#微调)
- [推理](#推理)
- [总结](#总结)
- [搭配Modelscope-Agent使用](#搭配Modelscope-Agent使用)
## 环境安装
```bash
# 设置pip全局镜像 (加速下载)
pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
# 安装ms-swift
git clone https://github.com/modelscope/swift.git
cd swift
pip install -e '.[llm]'
# 环境对齐 (通常不需要运行. 如果你运行错误, 可以跑下面的代码, 仓库使用最新环境测试)
pip install -r requirements/framework.txt -U
pip install -r requirements/llm.txt -U
```
## 数据准备
为训练Agent能力,魔搭官方提供了两个开源数据集:
- [魔搭通用问答知识数据集](https://www.modelscope.cn/datasets/iic/ms_bench/summary) 该数据集包含了38万条通用知识多轮对话数据
- [魔搭通用Agent训练数据集](https://www.modelscope.cn/datasets/iic/ms_agent/summary) 该数据集包含了3万条Agent格式的API调用数据
该数据集数据格式如下:
```json
{
"id": "MS_Agent_Bench_126374",
"conversations": [{
"from": "system",
"value": "Answer the following questions as best you can. You have access to the following APIs:\n1. hm_recipe_recommend: Call this tool to interact with the hmreciperecommend API. What is the hmreciperecommend API useful for? . Parameters: [{\"name\": \"keywords_dict\", \"description\": \"盒马推荐菜谱关键词字典。\", \"required\": \"True\"}]\n\n2. hm_product_marketing: Call this tool to interact with the hmproductmarketing API. What is the hmproductmarketing API useful for? . Parameters: [{\"name\": \"sku_code_list\", \"description\": \"盒马商品的编码列表\", \"required\": \"True\"}]\n\n3. hm_product_info: Call this tool to interact with the hmproductinfo API. What is the hmproductinfo API useful for? . Parameters: [{\"name\": \"sku_code_list\", \"description\": \"盒马商品的编码列表\", \"required\": \"True\"}, {\"name\": \"sku_name_list\", \"description\": \"盒马商品的名称列表\", \"required\": \"True\"}, {\"name\": \"property_list\", \"description\": \"盒马商品的属性列表\", \"required\": \"True\"}]\n\n4. hm_product_recommend: Call this tool to interact with the hmproductrecommend API. What is the hmproductrecommend API useful for? . Parameters: [{\"name\": \"keywords_dict\", \"description\": \"盒马推荐商品关键词字典。\", \"required\": \"True\"}]\n\nUse the following format:\n\nThought: you should always think about what to do\nAction: the action to take, should be one of the above tools[hm_recipe_recommend, hm_product_marketing, hm_product_info, hm_product_recommend]\nAction Input: the input to the action\nObservation: the result of the action\n... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\nBegin!"
}, {
"from": "user",
"value": "你好,请问你们有哪些手机可以推荐?"
}, {
"from": "assistant",
"value": "Action: hm_product_recommend\nAction Input: {'keywords_dict': {'商品类目': '手机'}}\nObservation: [{'sku_code': '10518244314', 'product_name': '小米11 Ultra 5G手机 256GB', 'bussiness_type': '标品', 'product_price': '5999', 'product_description': '拥有120Hz高刷新率和骁龙888处理器,还配备了高品质相机系统'}, {'sku_code': '10222726664', 'product_name': 'vivo X60 Pro 5G手机 256GB', 'bussiness_type': '标品', 'product_price': '4498', 'product_description': '搭载Exynos 1080处理器与ZEISS联合调校的相机系统,拍照实力十足'}, {'sku_code': '18227073452', 'product_name': '华为畅享 20 Pro 5G手机 128GB', 'bussiness_type': '标品', 'product_price': '2699', 'product_description': '拥有4000mAh的电池容量和优秀的相机而价格非常亲民'}]\nThought: I now know the final answer\nFinal Answer: 您好,我们有以下几款手机供您选择:1.小米11 Ultra 5G手机 256GB:价格为5999元,拥有120Hz高刷新率和骁龙888处理器,还配备了高品质相机系统;2. vivo X60 Pro 5G手机 256GB:售价为4498元,搭载Exynos 1080处理器与ZEISS联合调校的相机系统,拍照实力十足;3. 华为畅享 20 Pro 5G手机 128GB:价格亲民,只需2699元即可拥有优秀的相机和4000mAh的电池容量。这些手机都具有不同的特点和功能,您可以根据自己的需求进行选择。如果您有其他问题或需求,请随时告诉我。"
}, {
"from": "user",
"value": "我想买一款拍照好的手机,有推荐吗?"
}, {
"from": "assistant",
"value": "Action: hm_product_recommend\nAction Input: {'keywords_dict': {'商品类目': '手机', '适用场景': '拍照'}}\nObservation: [{'sku_code': '15956662254', 'product_name': '小米11 Ultra 5G手机 256GB', 'bussiness_type': '标品', 'product_price': '5999', 'product_description': '拥有120Hz高刷新率和骁龙888处理器,还配备了高品质相机系统'}, {'sku_code': '19424216918', 'product_name': 'vivo X60 Pro 5G手机 256GB', 'bussiness_type': '标品', 'product_price': '4498', 'product_description': '搭载Exynos 1080处理器与ZEISS联合调校的相机系统,拍照实力十足'}, {'sku_code': '14347982715', 'product_name': 'OPPO Ace2 5G手机 256GB', 'bussiness_type': '标品', 'product_price': '3999', 'product_description': '拍照实力十足,支持65W超级闪充'}]\nThought: I now know the final answer\nFinal Answer: 如果您想要一款拍照表现出色的手机,我为您推荐以下几款产品:1.小米11 Ultra 5G手机 256GB:价格为5999元,配备了高品质相机系统,可以拍摄出色的照片;2. vivo X60 Pro 5G手机 256GB:售价为4498元,相机系统由ZEISS联合调校,拍照实力十足;3. OPPO Ace2 5G手机 256GB:价格为3999元,不仅拍照实力十足,还支持65W超级闪充,让你的手机充电更加便捷。这些手机都有不同的特点和功能,您可以根据自己的需求进行选择。如果您还有其他问题或需求,请随时告诉我。"
}, {
"from": "user",
"value": "小米11 Ultra拍照性能怎么样?"
}, {
"from": "assistant",
"value": "Action: hm_product_info\nAction Input: {'sku_code_list': ['15956662254'], 'sku_name_list': ['小米11 Ultra 5G手机 256GB'], 'property_list': [['拍照性能']]}\nObservation: [{'sku_code': '15956662254', 'property': {'拍照性能': '小米11 Ultra 5G手机配备了1亿像素的主摄像头,还有120X超长焦电影摄像头,以及20MP超广角摄像头等组合,拍照性能非常出色。'}}]\nThought: I now know the final answer\nFinal Answer: 小米11 Ultra 5G手机配备了1亿像素的主摄像头,还有120X超长焦电影摄像头,以及20MP超广角摄像头等组合,拍照性能非常出色。如果您还有其他问题或需求,请随时告诉我。"
}]
}
```
Agent数据集的system字段具体格式如下(将\\"字符转换为"字符, \n转换为换行):
```text
Answer the following questions as best you can. You have access to the following APIs:
1. hm_recipe_recommend: Call this tool to interact with the hmreciperecommend API. What is the hmreciperecommend API useful for? . Parameters: [{"name": "keywords_dict", "description": "盒马推荐菜谱关键词字典。", "required": "True"}]
2. hm_product_marketing: Call this tool to interact with the hmproductmarketing API. What is the hmproductmarketing API useful for? . Parameters: [{"name": "sku_code_list", "description": "盒马商品的编码列表", "required": "True"}]
3. hm_product_info: Call this tool to interact with the hmproductinfo API. What is the hmproductinfo API useful for? . Parameters: [{"name": "sku_code_list", "description": "盒马商品的编码列表", "required": "True"}, {"name": "sku_name_list", "description": "盒马商品的名称列表", "required": "True"}, {"name": "property_list", "description": "盒马商品的属性列表", "required": "True"}]
4. hm_product_recommend: Call this tool to interact with the hmproductrecommend API. What is the hmproductrecommend API useful for? . Parameters: [{"name": "keywords_dict", "description": "盒马推荐商品关键词字典。", "required": "True"}]
Use the following format:
Thought: you should always think about what to do
Action: the action to take, should be one of the above tools[hm_recipe_recommend, hm_product_marketing, hm_product_info, hm_product_recommend]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
```
API格式:
```text
Answer the following questions as best you can. You have access to the following APIs:
序号: API名称: API作用 API参数
...
Use the following format:
Thought: you should always think about what to do
Action: the action to take, should be one of the above tools[API名称列表]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
```
Agent数据集调用API的response的结构如下:
```text
Action: hm_product_recommend
Action Input: {'keywords_dict': {'商品类目': '手机', '适用场景': '拍照'}}
Observation: [{'sku_code': '15956662254', 'product_name': '小米11 Ultra 5G手机 256GB', 'bussiness_type': '标品', 'product_price': '5999', 'product_description': '拥有120Hz高刷新率和骁龙888处理器,还配备了高品质相机系统'}, {'sku_code': '19424216918', 'product_name': 'vivo X60 Pro 5G手机 256GB', 'bussiness_type': '标品', 'product_price': '4498', 'product_description': '搭载Exynos 1080处理器与ZEISS联合调校的相机系统,拍照实力十足'}, {'sku_code': '14347982715', 'product_name': 'OPPO Ace2 5G手机 256GB', 'bussiness_type': '标品', 'product_price': '3999', 'product_description': '拍照实力十足,支持65W超级闪充'}]
Thought: I now know the final answer
Final Answer: 如果您想要一款拍照表现出色的手机,我为您推荐以下几款产品:1.小米11 Ultra 5G手机 256GB:价格为5999元,配备了高品质相机系统,可以拍摄出色的照片;2. vivo X60 Pro 5G手机 256GB:售价为4498元,相机系统由ZEISS联合调校,拍照实力十足;3. OPPO Ace2 5G手机 256GB:价格为3999元,不仅拍照实力十足,还支持65W超级闪充,让你的手机充电更加便捷。这些手机都有不同的特点和功能,您可以根据自己的需求进行选择。如果您还有其他问题或需求,请随时告诉我。
```
- Action:实际调用的API名称
- Action Input: 实际的输入参数
- Observation: 该部分是实际调用结果,训练时不参与loss,推理时需要外部调用后填入模型
- Thought: 模型思考输出
- Final Answer: 模型的最终回答
## 微调
在Agent训练中,为了避免训练后造成严重知识遗忘,我们的数据配比为[ms-agent](https://www.modelscope.cn/datasets/iic/ms_agent/summary):[ms-bench](https://www.modelscope.cn/datasets/iic/ms_bench/summary)数据集1比2,其中ms_agent共30000条,随机抽样ms_bench数据集60000条,同时为了改变模型认知,增加自我认知数据3000条。
| 数据集 | 条数 |
| ---------------- | --------------- |
| ms-agent | 30000(全数据集) |
| ms-bench | 60000(抽样) |
| self-recognition | 3000(重复抽样) |
我们也支持使用自己的Agent数据集。数据集格式需要符合[自定义数据集](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E8%87%AA%E5%AE%9A%E4%B9%89%E4%B8%8E%E6%8B%93%E5%B1%95.md#%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86)的要求。更具体地,Agent的response/system应该符合上述的Action/Action Input/Observation格式。
我们将**MLP****Embedder**加入了lora_target_modules. 你可以通过指定`--lora_target_modules ALL`在所有的linear层(包括qkvo以及mlp和embedder)加lora. 这**通常是效果最好的**.
微调使用了qwen-7b-chat模型,超参数如下:
| 超参数 | 值 |
| --------------------------- | -------- |
| LR | 5e-5 |
| Epoch | 2 |
| lora_rank | 8 |
| lora_alpha | 32 |
| lora_target_modules | ALL |
| batch_size | 2 |
| gradient_accumulation_steps | 32 total |
运行命令和其他超参数如下:
```shell
# Experimental environment: 8GPU
nproc_per_node=8
PYTHONPATH=../../.. \
torchrun \
--nproc_per_node=$nproc_per_node \
--master_port 29500 \
llm_sft.py \
--model_id_or_path qwen/Qwen-7B-Chat \
--model_revision master \
--sft_type lora \
--tuner_backend peft \
--dtype AUTO \
--output_dir output \
--dataset ms-agent \
--train_dataset_mix_ratio 2.0 \
--train_dataset_sample -1 \
--num_train_epochs 2 \
--max_length 1500 \
--check_dataset_strategy warning \
--lora_rank 8 \
--lora_alpha 32 \
--lora_dropout_p 0.05 \
--lora_target_modules ALL \
--self_cognition_sample 3000 \
--model_name 卡卡罗特 \
--model_author 陶白白 \
--gradient_checkpointing true \
--batch_size 2 \
--weight_decay 0.1 \
--learning_rate 5e-5 \
--gradient_accumulation_steps $(expr 32 / $nproc_per_node) \
--max_grad_norm 0.5 \
--warmup_ratio 0.03 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 10
```
在官方实验中,训练过程使用了8GPU硬件环境,**训练时长3小时**
> [!NOTE]
>
> 1. 该训练使用消费级单显卡也可以运行(对应**占用显存22G**),用户将DDP命令改为单卡命令即可
>
> 2. LoRA训练的遗忘问题并不严重,可以适当调低ms-bench数据集的比例,提高训练速度
## 推理
我们针对通用知识和Agent进行评测。下面列出了一个简单的评测结果。
### 原始模型
#### 通用知识
> 西湖醋鱼怎么做
![image-20240201122323540](../../resources/image-20240201122323540.png)
> 新冠和普通感冒有什么区别
![image-20240201122441874](../../resources/image-20240201122441874.png)
#### Agent能力
我们使用一个火焰报警场景作为测试用例:
```text
Answer the following questions as best you can. You have access to the following APIs:
1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{"name": "image", "description": "The input image to recognize fire", "required": "True"}]
2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building's administraters. Parameters: []
3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []
4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []
Use the following format:
Thought: you should always think about what to do
Action: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
```
![image-20240201122625473](../../resources/image-20240201122625473.png)
![image-20240201122725477](../../resources/image-20240201122725477.png)
![image-20240201131811038](../../resources/image-20240201131811038.png)
可以看到,人工输入Observation后模型答案并不正确。
### 训练后
#### 通用知识
> 西湖醋鱼怎么做
![image-20240201132124061](../../resources/image-20240201132124061.png)
![image-20240201132139698](../../resources/image-20240201132139698.png)
> 新冠和普通感冒有什么区别
![image-20240201132308260](../../resources/image-20240201132308260.png)
#### Agent能力
![image-20240201132421298](../../resources/image-20240201132421298.png)
![image-20240201132454465](../../resources/image-20240201132454465.png)
可以看到,训练后模型可以正确调用API并给出最终答案。
#### 自我认知
![image-20240201133359457](../../resources/image-20240201133359457.png)
### 在命令行中使用Agent
目前命令行的Agent推理支持需要指定`--eval_human true`,因为该参数为false的时候会读取数据集内容,此时无法手动传入`Observation:`后面的API调用结果。
```shell
# 使用训练后的模型
swift infer --ckpt_dir output/qwen-7b-chat/vx-xxx/checkpoint-xxx --eval_human true --stop_words Observation: --infer_backend pt
# 也可以使用原始模型,如qwn-7b-chat或chatglm3-6b-32k等运行agent
# swift infer --model_type qwen-7b-chat --eval_human true --stop_words Observation: --infer_backend pt
# swift infer --model_type chatglm3-6b-32k --eval_human true --stop_words Observation: --infer_backend pt
```
运行命令后,改变system字段:
```shell
# 单行system
<<< reset-system
<<< Answer the following questions as best you can. You have access to the following APIs:\n1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{"name": "image", "description": "The input image to recognize fire", "required": "True"}]\n\n2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building's administraters. Parameters: []\n\n3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []\n\n4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []\n\nUse the following format:\n\nThought: you should always think about what to do\nAction: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]\nAction Input: the input to the action\nObservation: the result of the action\n... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\nBegin!
```
如果需要以多行方式输入,可以用下面的命令(多行信息以#号结束):
```shell
# 多行system
<<< multi-line
<<<[M] reset-system#
<<<[MS] Answer the following questions as best you can. You have access to the following APIs:
1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{"name": "image", "description": "The input image to recognize fire", "required": "True"}]
2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building's administraters. Parameters: []
3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []
4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []
Use the following format:
Thought: you should always think about what to do
Action: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!#
```
下面就可以进行Agent问答(注意如果使用多行模式输入行尾额外增加#号):
```shell
<<< 输入图片是/tmp/1.jpg,协助判断图片中是否存在着火点
Thought: I need to use the fire\_recognition API to analyze the input image and determine if there are any signs of fire.
Action: Use the fire\_recognition API to analyze the input image.
Action Input: /tmp/1.jpg
Observation:
<<< [{'coordinate': [101.1, 200.9], 'on_fire': True}]
Thought: The fire\_recognition API has returned a result indicating that there is fire in the input image.
Final Answer: There is fire in the input image.
```
可以看到,模型已经返回了API调用的结果分析。用户可以继续问问题进行多轮Agent场景。也可以指定`--infer_backend vllm``--stream true`来使用vllm和流式推理。
### 在部署中使用Agent
由于部署不支持history管理,因此agent的API调用结果拼接需要用户自行进行,下面给出一个OpenAI格式可运行的代码范例。
服务端:
```shell
# 使用训练后的模型
swift deploy --ckpt_dir output/qwen-7b-chat/vx-xxx/checkpoint-xxx --stop_words Observation:
# 也可以使用原始模型,如qwen-7b-chat或chatglm3-6b-32k等运行agent
# swift deploy --model_type qwn-7b-chat --stop_words Observation:
# swift deploy --model_type chatglm3-6b-32k --stop_words Observation:
```
客户端:
```python
from openai import OpenAI
client = OpenAI(
api_key='EMPTY',
base_url='http://localhost:8000/v1',
)
model_type = client.models.list().data[0].id
print(f'model_type: {model_type}')
system = """Answer the following questions as best you can. You have access to the following APIs:
1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{\"name\": \"image\", \"description\": \"The input image to recognize fire\", \"required\": \"True\"}]
2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building's administraters. Parameters: []
3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []
4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []
Use the following format:
Thought: you should always think about what to do
Action: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!"""
messages = [{
'role': 'system',
'content': system
}, {
'role': 'user',
'content': '输入图片是/tmp/1.jpg,协助判断图片中是否存在着火点'
}]
resp = client.chat.completions.create(
model=model_type,
messages=messages,
stop=['Observation:'],
seed=42)
response = resp.choices[0].message.content
print(f'response: {response}')
# # 流式
messages.append({'role': 'assistant', 'content': response + "\n[{'coordinate': [101.1, 200.9], 'on_fire': True}]"})
print(messages)
stream_resp = client.chat.completions.create(
model=model_type,
messages=messages,
stop=['Observation:'],
stream=True,
seed=42)
print('response: ', end='')
for chunk in stream_resp:
print(chunk.choices[0].delta.content, end='', flush=True)
print()
## Output:
# model_type: qwen-7b-chat
# response: Thought: I need to check if there is fire in the image
# Action: Use fire\_recognition API
# Action Input: /tmp/1.jpg
# Observation:
# [{'role': 'system', 'content': 'Answer the following questions as best you can. You have access to the following APIs:\n1. fire_recognition: Call this tool to interact with the fire recognition API. This API is used to recognize whether there is fire in the image. Parameters: [{"name": "image", "description": "The input image to recognize fire", "required": "True"}]\n\n2. fire_alert: Call this tool to interact with the fire alert API. This API will start an alert to warn the building\'s administraters. Parameters: []\n\n3. call_police: Call this tool to interact with the police calling API. This API will call 110 to catch the thief. Parameters: []\n\n4. call_fireman: Call this tool to interact with the fireman calling API. This API will call 119 to extinguish the fire. Parameters: []\n\nUse the following format:\n\nThought: you should always think about what to do\nAction: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman]\nAction Input: the input to the action\nObservation: the result of the action\n... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\nBegin!'}, {'role': 'user', 'content': '输入图片是/tmp/1.jpg,协助判断图片中是否存在着火点'}, {'role': 'assistant', 'content': "Thought: I need to check if there is fire in the image\nAction: Use fire\\_recognition API\nAction Input: /tmp/1.jpg\nObservation:\n[{'coordinate': [101.1, 200.9], 'on_fire': True}]"}]
# response:
# Final Answer: There is fire in the image at coordinates [101.1, 200.9]
```
## 搭配Modelscope-Agent使用
结合[Modelscope-Agent](https://github.com/modelscope/modelscope-agent),微调模型用于搭建Agent
本节针对Modelscope-Agent中的交互式框架AgentFabric,微调小模型qwen-7b-chat使其具有function call能力
由于ms-agent中的system prompt与Modelscope-Agent中的system prompt格式不匹配,直接训练效果不佳,为此我们根据ms-agent转换格式得到新数据集[ms_agent_for_agentfabric](https://modelscope.cn/datasets/AI-ModelScope/ms_agent_for_agentfabric/summary),现已集成到SWIFT中。
其中`ms-agent-for-agentfabric-default`包含3万条由ms-agent转换的数据集,`ms-agent-for-agentfabric-additional`包含488条由开源的AgentFabric框架实际调用访问数据筛选得到
### 微调
`dataset`换为`ms-agent-for-agentfabric-default``ms-agent-for-agentfabric-addition`
```shell
# Experimental environment: 8GPU
nproc_per_node=8
PYTHONPATH=../../.. \
torchrun \
--nproc_per_node=$nproc_per_node \
--master_port 29500 \
llm_sft.py \
--model_id_or_path qwen/Qwen-7B-Chat \
--model_revision master \
--sft_type lora \
--tuner_backend swift \
--dtype AUTO \
--output_dir output \
--dataset ms-agent-for-agentfabric-default ms-agent-for-agentfabric-addition \
--train_dataset_mix_ratio 2.0 \
--train_dataset_sample -1 \
--num_train_epochs 2 \
--max_length 1500 \
--check_dataset_strategy warning \
--lora_rank 8 \
--lora_alpha 32 \
--lora_dropout_p 0.05 \
--lora_target_modules ALL \
--self_cognition_sample 3000 \
--model_name 卡卡罗特 \
--model_author 陶白白 \
--gradient_checkpointing true \
--batch_size 2 \
--weight_decay 0.1 \
--learning_rate 5e-5 \
--gradient_accumulation_steps $(expr 32 / $nproc_per_node) \
--max_grad_norm 0.5 \
--warmup_ratio 0.03 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 10
```
merge lora
```
CUDA_VISIBLE_DEVICES=0 swift export \
--ckpt_dir '/path/to/qwen-7b-chat/vx-xxx/checkpoint-xxx' --merge_lora true
```
### AgentFabric
#### 环境安装
```bash
git clone https://github.com/modelscope/modelscope-agent.git
cd modelscope-agent && pip install -r requirements.txt && pip install -r apps/agentfabric/requirements.txt
```
#### 部署模型
使用以下任意一种方式部署模型
##### swift deploy
```bash
CUDA_VISIBLE_DEVICES=0 swift deploy --ckpt_dir /path/to/qwen-7b-chat/vx-xxx/checkpoint-xxxx-merged
```
##### vllm
```bash
python -m vllm.entrypoints.openai.api_server --model /path/to/qwen-7b-chat/vx-xxx/checkpoint-xxxx-merged --trust-remote-code
```
#### 添加本地模型配置
`/path/to/modelscope-agent/apps/agentfabric/config/model_config.json`中,新增合并后的本地模型
```
"my-qwen-7b-chat": {
"type": "openai",
"model": "/path/to/qwen-7b-chat/vx-xxx/checkpoint-xxxx-merged",
"api_base": "http://localhost:8000/v1",
"is_chat": true,
"is_function_call": false,
"support_stream": false
}
```
注意,如果使用`swift deploy`部署,需要将`"model"`的值设为`qwen-7b-chat`
#### 启动AgentFabric
在以下实践中,会调用[Wanx Image Generation](https://help.aliyun.com/zh/dashscope/opening-service?spm=a2c4g.11186623.0.0.50724937O7n40B)[高德天气](https://lbs.amap.com/api/webservice/guide/create-project/get-key),需要手动设置API KEY, 设置后启动AgentFabric
```bash
export PYTHONPATH=$PYTHONPATH:/path/to/your/modelscope-agent
export DASHSCOPE_API_KEY=your_api_key
export AMAP_TOKEN=your_api_key
cd modelscope-agent/apps/agentfabric
python app.py
```
进入AgentFabric后,在配置(Configure)的模型中选择本地模型`my-qwen-7b-chat`
内置能力选择agent可以调用的API, 这里选择`Wanx Image Generation``高德天气`
点击更新配置,等待配置完成后在右侧的输入栏中与Agent交互
> 天气查询
![agentfabric_1](../../resources/agentfabric_1.png)
![agentfabric_2](../../resources/agentfabric_2.png)
> 文生图
![agentfabric_3](../../resources/agentfabric_3.png)
![agentfabric_4](../../resources/agentfabric_4.png)
可以看到微调后的模型可以正确理解指令并调用工具
## 总结
通过SWIFT支持的Agent训练能力,我们使用ms-agent和ms-bench对qwen-7b-chat模型进行了微调。可以看到微调后模型保留了通用知识问答能力,并在system字段增加了API的情况下可以正确调用并完成任务。需要注意的是:
1. 训练从LoRA变为全参数训练,知识遗忘问题会更加严重,数据集混合比例需要实际测试调整
2. 部分模型可能在训练后仍然调用效果不佳,可以测试该模型本身预训练能力是否扎实
3. Agent训练集格式、语种有细节改变后,对应推理阶段的格式也需要相应调整,否则可能效果不佳
4. 重要位置的`\n`等特殊字符比较重要,请注意推理和训练格式统一
# Agent部署最佳实践
## 目录
- [环境安装](#环境安装)
- [tools字段](#tools字段)
- [部署](#部署)
- [总结](#总结)
## 环境安装
```bash
# 设置pip全局镜像 (加速下载)
pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
# 安装ms-swift
pip install 'ms-swift[llm]' -U
# 环境对齐 (通常不需要运行. 如果你运行错误, 可以跑下面的代码, 仓库使用最新环境测试)
pip install -r requirements/framework.txt -U
pip install -r requirements/llm.txt -U
```
## tools字段
tools字段提供了模型可以调用的API信息。支持OpenAI和ToolBench格式,需要提供tools的名字,描述和参数,示例如下
OpenAI tools格式
```json
{
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}
]
}
```
ToolBench tools 格式
```json
{
"tools": [
{
"name": "url_for_newapi",
"description": "This is the subfunction for tool \"newapi\", you can use this tool.The description of this function is: \"url_for_newapi\"",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "",
"example_value": "https://www.instagram.com/reels/CtB6vWMMHFD/"
}
},
"required": [
"url"
],
"optional": [
"url"
]
}
},
{
"name": "n_for_newapi",
"description": "This is the subfunction for tool \"newapi\", you can use this tool.The description of this function is: \"n_for_newapiew var\"",
"parameters": {
"type": "object",
"properties": {
"language": {
"type": "string",
"description": "",
"example_value": "https://www.instagram.com/reels/Csb0AI3IYUN/"
}
},
"required": [
"language"
],
"optional": []
}
},
{
"name": "Finish",
"description": "If you believe that you have obtained a result that can answer the task, please call this function to provide the final answer. Alternatively, if you recognize that you are unable to proceed with the task in the current state, call this function to restart. Remember: you must ALWAYS call this function at the end of your attempt, and the only part that will be shown to the user is the final answer, so it should contain sufficient information.",
"parameters": {
"type": "object",
"properties": {
"return_type": {
"type": "string",
"enum": [
"give_answer",
"give_up_and_restart"
]
},
"final_answer": {
"type": "string",
"description": "The final answer you want to give the user. You should have this field if \"return_type\"==\"give_answer\""
}
},
"required": [
"return_type"
]
}
}
],
}
```
在推理过程中,会将tools的信息转换成对应的tools system prompt。如果已经存在system prompt,则会拼接在已有的之后。
目前支持英文ReAct,中文ReAct和ToolBench三种tools system prompt,示例如下
ReAct-EN
```
Answer the following questions as best you can. You have access to the following tools:
{'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}
Use the following format:
Thought: you should always think about what to do
Action: the action to take, should be one of [get_current_weather]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Final Answer: the final answer to the original input question
Begin!
```
ReAct-ZH
```
尽你所能回答以下问题。你拥有如下工具:
{'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}
以下格式回答:
Thought: 思考你应该做什么
Action: 工具的名称,必须是[get_current_weather]之一
Action Input: 工具的输入
Observation: 工具返回的结果
... (Thought/Action/Action Input/Observation的过程可以重复零次或多次)
Final Answer: 对输入问题的最终答案
开始!
```
ToolBench
```
You can use many tools(functions) to do the following task.
First I will give you the task description, and your task start.
At each step, you need to give your thought to analyze the status now and what to do next, with a function call to actually excute your step. Your output should follow this format:
Thought:
Action:
Action Input:
After the call, you will get the call result, and you are now in a new state.
Then you will analyze your status now, then decide what to do next...
After many (Thought-call) pairs, you finally perform the task, then you can give your finial answer.
Remember:
1.the state change is irreversible, you can\'t go back to one of the former state, if you want to restart the task, say "I give up and restart".
2.All the thought is short, at most in 5 sentence.
3.You can do more then one trys, so if your plan is to continusly try some conditions, you can do one of the conditions per try.
Let\'s Begin!
Task description: You should use functions to help handle the real time user querys. Remember:
1.ALWAYS call "Finish" function at the end of the task. And the final answer should contain enough information to show to the user,If you can\'t handle the task, or you find that function calls always fail(the function is not valid now), use function Finish->give_up_and_restart.
2.Do not use origin tool names, use only subfunctions\' names.
Specifically, you have access to the following APIs: {\'name\': \'get_current_weather\', \'description\': \'Get the current weather in a given location\', \'parameters\': {\'type\': \'object\', \'properties\': {\'location\': {\'type\': \'string\', \'description\': \'The city and state, e.g. San Francisco, CA\'}, \'unit\': {\'type\': \'string\', \'enum\': [\'celsius\', \'fahrenheit\']}}, \'required\': [\'location\']}}
```
默认使用ReAct-EN格式,你也可以在参数中指定`--tools_prompt``react_zh``toolbench` 来选择中文ReAct或ToolBench格式
如果你有更好用的tools system prompt,欢迎告知或贡献给我们。
## 部署
以下以vLLM部署,非流式调用,ReAct prompt为例.
部署Agent时,需要型本身必须具备较强的指令遵循能力,或者已在Agent数据集上进行了训练。如果现有模型未能根据tools字段进行工具选择和参数设置,建议采用更高性能的模型,或者参照[Agent微调实践](./Agent微调最佳实践.md)训练模型
部署模型,这里我们选择`llama3-8b-instruct`模型作为示范
```shell
swift deploy \
--model_type llama3-8b-instruct \
--infer_backend vllm \
```
用curl命令调用接口,因为ReAct格式会以Observation:为结尾,我们需要在stop中指定`Observation:`作为stop words来截断模型回复。有些模型会将`Observation:\n`作为一个token,这里我们也将其作为stop words。
如果你使用ToolBench prompt, 则无需指定stop words(当然加上也没有关系)
```shell
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama3-8b-instruct",
"messages": [
{
"role": "user",
"content": "What'\''s the weather like in Boston today?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}
],
"stream": false,
"stop": ["Observation:", "Observation:\n"]
}'
```
你也可以通过指定`tool_choice`字段来选择tools中的tool,比如`"tool_choice":{"type": "function", "function": {"name": "my_function"}}`. 默认选择所有tools,也可以设置为None来屏蔽tools字段
调用结果
```json
{"model":"llama3-8b-instruct","choices":[[{"index":0,"message":{"role":"assistant","content":"Question: What's the weather like in Boston today?\n\nThought: I need to get the current weather in Boston to answer this question.\n\nAction: get_current_weather\n\nAction Input: {'location': 'Boston, MA', 'unit': 'fahrenheit'}\n\nObservation:","tool_calls":{"id":"toolcall-f534d907ae254f2ab96e06c25179ddf9","function":{"arguments":" {'location': 'Boston, MA', 'unit': 'fahrenheit'}\n\n","name":"get_current_weather"},"type":"function"}},"finish_reason":"stop"}]],"usage":{"prompt_tokens":262,"completion_tokens":54,"total_tokens":316},"id":"chatcmpl-8630e8d675c941c0aca958a37633a3c9","object":"chat.completion","created":1717590756}
```
在返回结果的tool_calls中,可以获得调用的函数以及参数信息。
假设调用返回的结果为`The weather in Boston today is 32°F (0°C), with clear skies`, 我们将结果在role tool字段填入message传入
```shell
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama3-8b-instruct",
"messages": [
{
"role": "user",
"content": "What'\''s the weather like in Boston today?"
},
{
"role": "assistant",
"content": "Question: What'\''s the weather like in Boston today?\n\nThought: I need to get the current weather in Boston.\n\nAction: get_current_weather\n\nAction Input: {\"location\": \"Boston, MA\", \"unit\": \"fahrenheit\"}\n\nObservation:"
},
{
"role": "tool",
"content": "{\"result\": \"The weather in Boston today is 32°F (0°C), with clear skies\"}\\n\\n"
}
],
"stream": false,
"stop": ["Observation:", "Observation:\n"]
}'
```
对于ReAct格式,我们会将其拼接结果拼接回上一轮模型返回最后的`Observations:`字段之后。
对于ToolBench格式,根据模型template对其处理。如果模型template没有指定对该字段的特殊处理方式,则视为user输入。
如果你有更好用的处理方法,也欢迎告知或贡献给我们。
调用结果
```json
{"model":"llama3-8b-instruct","choices":[{"index":0,"message":{"role":"assistant","content":"\n\nAnswer: The weather in Boston today is 32°F (0°C), with clear skies.","tool_calls":null},"finish_reason":null}],"usage":{"prompt_tokens":93,"completion_tokens":21,"total_tokens":114},"id":"chatcmpl-5e63cee5155f48a48d1366001d16502b","object":"chat.completion","created":1717590962}
```
如果你想要结合代码和tools完成整个链路闭环,推荐阅读[OpenAI教程](https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models)
# Benchmark
## 目录
- [参数设置](#参数设置)
- [量化](#量化)
- [Model Type & Max Length](#model-type--max-length)
- [Batch Size](#batch-size)
- [Use Flash Attn & Gradient Checkpointing](#use-flash-attn--gradient-checkpointing)
- [LoRA Rank & LoRA Target Modules](#lora-rank--lora-target-modules)
- [Gradient Accumulation Steps](#gradient-accumulation-steps)
- [Tuners](#Tuners)
- [Export](#Export)
- [AWQ](#AWQ)
- [AQLM](#AQLM)
- [Sequence Parallel](#Sequence-Parallel)
## 参数设置
实验环境:
- A100
- CUDA 11.8
- python 3.10
- torch 2.1.1
- flash_attn 2.3.4
- xformers 0.0.23
- auto_gptq 0.5.1
- bitsandbytes 0.41.3.post2
以下为所有实验的相同命令行设置部分:
```bash
--dataset_test_ratio 0 \
--dataset cls-fudan-news-zh \
--save_strategy no \
--check_dataset_strategy warning \
--preprocess_num_proc 4 \
```
如果未指定以下参数, 则使用以下默认值:
```bash
--max_length 2048 \
--batch_size 1 \
--gradient_checkpointing true \
--use_flash_attn true \
--lora_rank 8 \
--lora_target_modules DEFAULT \
--quantization_bit 0 \
--gradient_accumulation_steps 16 \
```
对应测试数据集的token数统计量(由qwen的tokenizer获取): 3234.4±2547.5, min=91, max=19548.
实验使用脚本可以查看`scripts/benchmark/test_memory_time/`.
## 量化
测试脚本为:
```bash
swift sft \
--model_type {MODEL_TYPE} \
--quantization_bit {QUANTIZATION_BIT} \
--sft_type lora \
...
```
<table>
<tr>
<td>Model Type [LoRA]</td>
<td>Quantization</td>
<td>Training Speed (samples/s)</td>
<td>GPU Memory (GiB)</td>
</tr>
<tr>
<td rowspan="4">qwen-7b-chat</td>
<td>bf16</td>
<td>4.31</td>
<td>27.74</td>
</tr>
<tr>
<td>int4 (gptq)</td>
<td>2.05</td>
<td>19.21</td>
</tr>
<tr>
<td>int8 (gptq)</td>
<td>1.97</td>
<td>22.20</td>
</tr>
<tr>
<td>int4 (bnb)</td>
<td>2.41</td>
<td>23.85</td>
</tr>
<tr>
<td rowspan="4">qwen-14b-chat</td>
<td>bf16</td>
<td>2.60</td>
<td>40.14</td>
</tr>
<tr>
<td>int4 (gptq)</td>
<td>1.15</td>
<td>23.30</td>
</tr>
<tr>
<td>int8 (gptq)</td>
<td>1.08</td>
<td>29.13</td>
</tr>
<tr>
<td>int4 (bnb)</td>
<td>1.36</td>
<td>30.05</td>
</tr>
<tr>
<td rowspan="4">qwen-72b-chat</td>
<td>bf16</td>
<td>0.59 (2*A100)</td>
<td>73.71+78.54</td>
</tr>
<tr>
<td>int4 (gptq)</td>
<td>0.23</td>
<td>54.86</td>
</tr>
<tr>
<td>int8 (gptq)</td>
<td>0.21</td>
<td>78.44</td>
</tr>
<tr>
<td>int4 (bnb)</td>
<td>0.28</td>
<td>74.87</td>
</tr>
</table>
## Model Type & Max Length
### LoRA
测试脚本为:
```bash
swift sft \
--model_type {MODEL_TYPE} \
--max_length {MAX_LENGTH} \
--sft_type lora \
...
```
<table>
<tr>
<td>Model Type [LoRA]</td>
<td>Max Length</td>
<td>Training Speed (samples/s)</td>
<td>GPU Memory (GiB)</td>
</tr>
<tr>
<td rowspan="5">qwen-1_8b-chat</td>
<td>512</td>
<td>9.88</td>
<td>6.99</td>
</tr>
<tr>
<td>1024</td>
<td>9.90</td>
<td>10.71</td>
</tr>
<tr>
<td>2048</td>
<td>8.77</td>
<td>16.35</td>
</tr>
<tr>
<td>4096</td>
<td>5.92</td>
<td>23.80</td>
</tr>
<tr>
<td>8192</td>
<td>4.19</td>
<td>37.03</td>
</tr>
<tr>
<td rowspan="5">qwen-7b-chat</td>
<td>512</td>
<td>7.43</td>
<td>18.01</td>
</tr>
<tr>
<td>1024</td>
<td>6.51</td>
<td>21.73</td>
</tr>
<tr>
<td>2048</td>
<td>4.31</td>
<td>27.74</td>
</tr>
<tr>
<td>4096</td>
<td>2.05</td>
<td>35.31</td>
</tr>
<tr>
<td>8192</td>
<td>1.34</td>
<td>48.41</td>
</tr>
<tr>
<td rowspan="5">qwen-14b-chat</td>
<td>512</td>
<td>5.63</td>
<td>30.14</td>
</tr>
<tr>
<td>1024</td>
<td>4.36</td>
<td>34.43</td>
</tr>
<tr>
<td>2048</td>
<td>2.60</td>
<td>40.14</td>
</tr>
<tr>
<td>4096</td>
<td>1.17</td>
<td>47.95</td>
</tr>
<tr>
<td>8192</td>
<td>0.79</td>
<td>60.74</td>
</tr>
<tr>
<td rowspan="5">qwen-72b-chat (2*A100)</td>
<td>512</td>
<td>1.41</td>
<td>67.68+73.07</td>
</tr>
<tr>
<td>1024</td>
<td>1.02</td>
<td>70.25+77.11</td>
</tr>
<tr>
<td>2048</td>
<td>0.59</td>
<td>73.71+78.54</td>
</tr>
<tr>
<td>4096</td>
<td>-</td>
<td>OOM</td>
</tr>
<tr>
<td>8192</td>
<td>-</td>
<td>OOM</td>
</tr>
<tr>
<td rowspan="5">chatglm3-6b</td>
<td>512</td>
<td>6.72</td>
<td>13.94</td>
</tr>
<tr>
<td>1024</td>
<td>6.16</td>
<td>12.99</td>
</tr>
<tr>
<td>2048</td>
<td>4.20</td>
<td>17.20</td>
</tr>
<tr>
<td>4096</td>
<td>1.92</td>
<td>29.80</td>
</tr>
<tr>
<td>8192</td>
<td>1.24</td>
<td>66.82</td>
</tr>
<tr>
<td rowspan="5">yi-6b-chat</td>
<td>512</td>
<td>5.27</td>
<td>13.72</td>
</tr>
<tr>
<td>1024</td>
<td>5.07</td>
<td>15.44</td>
</tr>
<tr>
<td>2048</td>
<td>3.84</td>
<td>16.95</td>
</tr>
<tr>
<td>4096</td>
<td>1.99</td>
<td>28.25</td>
</tr>
<tr>
<td>8192</td>
<td>1.35</td>
<td>43.81</td>
</tr>
<tr>
<td rowspan="5">yi-34b-chat</td>
<td>512</td>
<td>2.32</td>
<td>66.72</td>
</tr>
<tr>
<td>1024</td>
<td>1.76</td>
<td>69.10</td>
</tr>
<tr>
<td>2048</td>
<td>1.05</td>
<td>71.34</td>
</tr>
<tr>
<td>4096</td>
<td>0.47</td>
<td>78.72</td>
</tr>
<tr>
<td>8192</td>
<td>0.31 (2*A100)</td>
<td>47.01+65.03</td>
</tr>
<tr>
<td rowspan="5">openbuddy-zephyr-7b-chat</td>
<td>512</td>
<td>5.17</td>
<td>14.99</td>
</tr>
<tr>
<td>1024</td>
<td>3.92</td>
<td>16.57</td>
</tr>
<tr>
<td>2048</td>
<td>3.08</td>
<td>19.89</td>
</tr>
<tr>
<td>4096</td>
<td>1.85</td>
<td>23.29</td>
</tr>
<tr>
<td>8192</td>
<td>0.92</td>
<td>52.14</td>
</tr>
<tr>
<td rowspan="5">baichuan2-7b-chat</td>
<td>512</td>
<td>6.09</td>
<td>18.18</td>
</tr>
<tr>
<td>1024</td>
<td>5.36</td>
<td>17.45</td>
</tr>
<tr>
<td>2048</td>
<td>3.43</td>
<td>19.18</td>
</tr>
<tr>
<td>4096</td>
<td>1.69</td>
<td>34.22</td>
</tr>
<tr>
<td>8192</td>
<td>1.16</td>
<td>45.47</td>
</tr>
<tr>
<td rowspan="5">baichuan2-13b-chat</td>
<td>512</td>
<td>5.32</td>
<td>31.01</td>
</tr>
<tr>
<td>1024</td>
<td>3.91</td>
<td>31.58</td>
</tr>
<tr>
<td>2048</td>
<td>1.77</td>
<td>32.40</td>
</tr>
<tr>
<td>4096</td>
<td>0.65</td>
<td>49.63</td>
</tr>
<tr>
<td>8192</td>
<td>0.36</td>
<td>76.17</td>
</tr>
</table>
### Full
测试脚本为:
```bash
swift sft \
--model_type {MODEL_TYPE} \
--max_length {MAX_LENGTH} \
--sft_type full \
...
```
<table>
<tr>
<td>Model Type [FULL]</td>
<td>Max Length</td>
<td>Training Speed (samples/s)</td>
<td>GPU Memory (GiB)</td>
</tr>
<tr>
<td rowspan="5">qwen-1_8b-chat</td>
<td>512</td>
<td>10.77</td>
<td>18.16</td>
</tr>
<tr>
<td>1024</td>
<td>10.39</td>
<td>18.62</td>
</tr>
<tr>
<td>2048</td>
<td>8.73</td>
<td>35.11</td>
</tr>
<tr>
<td>4096</td>
<td>5.45</td>
<td>31.62</td>
</tr>
<tr>
<td>8192</td>
<td>3.81</td>
<td>38.93</td>
</tr>
<tr>
<td rowspan="5">qwen-7b-chat</td>
<td>512</td>
<td>5.96</td>
<td>73.37</td>
</tr>
<tr>
<td>1024</td>
<td>5.00</td>
<td>73.64</td>
</tr>
<tr>
<td>2048</td>
<td>3.30</td>
<td>74.26</td>
</tr>
<tr>
<td>4096</td>
<td>1.64</td>
<td>78.76</td>
</tr>
<tr>
<td>8192</td>
<td>1.11 (2*A100)</td>
<td>61.34+73.00</td>
</tr>
<tr>
<td rowspan="5">qwen-14b-chat (2*A100)</td>
<td>512</td>
<td>3.66</td>
<td>60.42+72.31</td>
</tr>
<tr>
<td>1024</td>
<td>2.98</td>
<td>60.61+74.37</td>
</tr>
<tr>
<td>2048</td>
<td>1.93</td>
<td>60.70+78.22</td>
</tr>
<tr>
<td>4096</td>
<td>0.92</td>
<td>75.59+78.64</td>
</tr>
<tr>
<td>8192</td>
<td>0.62</td>
<td>76.59+77.68</td>
</tr>
</table>
## Batch Size
测试脚本为:
```bash
swift sft \
--batch_size {BATCH_SIZE} \
--model_type qwen-7b-chat \
--sft_type lora \
...
```
<table>
<tr>
<td>Model Type [LoRA]</td>
<td>Batch Size</td>
<td>Training Speed (samples/s)</td>
<td>GPU Memory (GiB)</td>
</tr>
<tr>
<td rowspan="4">qwen-7b-chat</td>
<td>1</td>
<td>4.31</td>
<td>27.74</td>
</tr>
<tr>
<td>2</td>
<td>3.60</td>
<td>43.11</td>
</tr>
<tr>
<td>4</td>
<td>3.02</td>
<td>63.81</td>
</tr>
<tr>
<td>8</td>
<td>2.77</td>
<td>76.14</td>
</tr>
</table>
## Use Flash Attn & Gradient Checkpointing
测试脚本为:
```bash
swift sft \
--use_flash_attn {USE_FLASH_ATTN} \
--gradient_checkpointing {GRADIENT_CHECKPOINTING} \
--model_type qwen-7b-chat \
--sft_type lora \
...
```
<table>
<tr>
<td>Model Type [LoRA]</td>
<td>Use Flash Attn</td>
<td>Gradient Checkpointing</td>
<td>Training Speed (samples/s)</td>
<td>GPU Memory (GiB)</td>
</tr>
<tr>
<td rowspan="4">qwen-7b-chat</td>
<td>&#x2714;</td>
<td>&#x2714;</td>
<td>4.31</td>
<td>27.74</td>
</tr>
<tr>
<td>&#x2714;</td>
<td>&#x2718;</td>
<td>6.19</td>
<td>37.70</td>
</tr>
<tr>
<td>&#x2718;</td>
<td>&#x2714;</td>
<td>3.13</td>
<td>27.71</td>
</tr>
<tr>
<td>&#x2718;</td>
<td>&#x2718;</td>
<td>4.45</td>
<td>57.67</td>
</tr>
</table>
## LoRA Rank & LoRA Target Modules
测试脚本为:
```bash
swift sft \
--lora_rank {LORA_RANK} \
--lora_target_modules {LORA_TARGET_MODULES} \
--model_type qwen-7b-chat \
--sft_type lora \
...
```
<table>
<tr>
<td>Model Type [LoRA]</td>
<td>LoRA Rank</td>
<td>LoRA Target Modules</td>
<td>Training Speed (samples/s)</td>
<td>GPU Memory (GiB)</td>
<td>Trainable Params (M)</td>
</tr>
<tr>
<td rowspan="4">qwen-7b-chat</td>
<td>2</td>
<td>DEFAULT (c_attn)</td>
<td>4.27</td>
<td>27.72</td>
<td>1.05</td>
</tr>
<tr>
<td>8</td>
<td>DEFAULT</td>
<td>4.31</td>
<td>27.74</td>
<td>4.19</td>
</tr>
<tr>
<td>64</td>
<td>DEFAULT</td>
<td>4.19</td>
<td>27.85</td>
<td>33.55</td>
</tr>
<tr>
<td>8</td>
<td>ALL (all linear)</td>
<td>3.22</td>
<td>27.87</td>
<td>17.89</td>
</tr>
</table>
## Gradient Accumulation Steps
测试脚本为:
```bash
swift sft \
--gradient_accumulation_steps {GRADIENT_ACCUMULATION_STEPS} \
--model_type qwen-7b-chat \
--sft_type lora \
...
```
<table>
<tr>
<td>Model Type [LoRA]</td>
<td>Gradient Accumulation Steps</td>
<td>Training Speed (samples/s)</td>
<td>GPU Memory (GiB)</td>
</tr>
<tr>
<td rowspan="7">qwen-7b-chat</td>
<td>1</td>
<td>4.26</td>
<td>27.73</td>
</tr>
<tr>
<td>2</td>
<td>4.32</td>
<td>27.74</td>
</tr>
<tr>
<td>4</td>
<td>4.31</td>
<td>27.74</td>
</tr>
<tr>
<td>8</td>
<td>4.32</td>
<td>27.74</td>
</tr>
<tr>
<td>16</td>
<td>4.33</td>
<td>27.74</td>
</tr>
<tr>
<td>32</td>
<td>4.30</td>
<td>27.74</td>
</tr>
<tr>
<td>64</td>
<td>4.32</td>
<td>27.74</td>
</tr>
</table>
## Tuners
| exp_name | model_type | dataset | ms-bench mix ratio | tuner | tuner_params | trainable params(M) | flash_attn | gradient_checkpointing | hypers | memory | train speed(samples/s) | infer speed(tokens/s) | train_loss | eval_loss | gsm8k weighted acc | arc weighted acc | ceval weighted acc |
| -------- | ---------- | ------- | -------------------| ----- | ------------ | ------------------- | -----------| ---------------------- | ------ | ------ | ---------------------- | --------------------- | ---------- | --------- | ------------------ | ---------------- | ------------------ |
|adalora|qwen-7b-chat|ms-agent|2.0|adalora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=False|26.8389(0.3464%)|True|True|lr=5e-05/epoch=2|32.55GiB|0.92(87543 samples/95338.71 seconds)|17.33(2345 tokens/135.29 seconds)|0.57|1.07|0.391|0.665|0.569|
|adapter|qwen-7b-chat|ms-agent|2.0|adapter||33.6896(0.4344%)|True|True|lr=5e-05/epoch=2|32.19GiB|1.48(87543 samples/59067.71 seconds)|26.63(4019 tokens/150.90 seconds)|0.55|1.03|0.438|0.662|0.565|
|dora|qwen-7b-chat|ms-agent|2.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=True|19.2512(0.2487%)|True|True|lr=5e-05/epoch=2|32.46GiB|0.51(87543 samples/171110.54 seconds)|4.29(2413 tokens/562.32 seconds)|0.53|1.01|0.466|0.683|**0.577**|
|full+galore128|qwen-7b-chat|ms-agent|2.0|full|galore_rank=128/galore_per_parameter=false/galore_with_embedding=false|7721.3245(100.0000%)|True|True|lr=5e-05/epoch=2|47.02GiB|1.10(87543 samples/79481.96 seconds)|28.96(2400 tokens/82.88 seconds)|0.55|1.00|0.358|**0.688**|**0.577**|
|full+galore32|qwen-7b-chat|ms-agent|2.0|full|galore_rank=32/galore_per_parameter=false/galore_with_embedding=false|7721.3245(100.0000%)|True|True|lr=5e-05/epoch=2|47.05GiB|1.11(87543 samples/78989.74 seconds)|29.17(2431 tokens/83.35 seconds)|0.56|1.01|0.386|0.667|0.539|
|full+galore64|qwen-7b-chat|ms-agent|2.0|full|galore_rank=64/galore_per_parameter=false/galore_with_embedding=false|7721.3245(100.0000%)|True|True|lr=5e-05/epoch=2|46.91GiB|1.11(87543 samples/79200.36 seconds)|28.94(2448 tokens/84.60 seconds)|0.56|1.01|0.397|0.674|0.544|
|full+galore_emb|qwen-7b-chat|ms-agent|2.0|full|galore_rank=128/galore_per_parameter=false/galore_with_embedding=true|7721.3245(100.0000%)|True|True|lr=5e-05/epoch=2|44.53GiB|1.10(87543 samples/79775.02 seconds)|29.45(2433 tokens/82.62 seconds)|0.55|1.00|0.398|0.670|0.568|
|full+galore_perparam|qwen-7b-chat|ms-agent|2.0|full|galore_rank=128/galore_per_parameter=true/galore_with_embedding=false|7721.3245(100.0000%)|True|True|lr=5e-05/epoch=2|47.02GiB|1.25(87543 samples/69821.89 seconds)|29.02(2478 tokens/85.39 seconds)|0.54|1.00|0.372|0.669|0.524|
|full+no_mix|qwen-7b-chat|ms-agent|0.0|full||7721.3245(100.0000%)|True|True|lr=5e-05/epoch=2|72.56GiB|1.27(29698 samples/23356.97 seconds)|30.31(11738 tokens/387.29 seconds)|0.57|**0.44**|0.174|0.652|0.553|
|full|qwen-7b-chat|ms-agent|2.0|full||7721.3245(100.0000%)|True|True|lr=5e-05/epoch=2|73.53GiB|1.43(87543 samples/61022.97 seconds)|29.51(3382 tokens/114.62 seconds)|0.54|0.95|0.343|0.536|0.495|
|llamapro|qwen-7b-chat|ms-agent|2.0|llamapro|num_blocks=4|809.5826(9.4900%)|True|True|lr=5e-05/epoch=2|38.11GiB|1.53(87543 samples/57294.42 seconds)|25.80(2374 tokens/92.02 seconds)|0.53|1.00|0.434|0.645|0.357|
|lora+|qwen-7b-chat|ms-agent|2.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=16.0/use_rslora=False/use_dora=False|17.8913(0.2312%)|True|True|lr=5e-05/epoch=2|32.35GiB|0.95(87543 samples/91923.80 seconds)|18.81(3329 tokens/176.94 seconds)|0.53|0.98|0.432|0.647|0.344|
|lora+neftune|qwen-7b-chat|ms-agent|2.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=False/neftune_noise_alpha=15.0|17.8913(0.2312%)|True|True|lr=5e-05/epoch=2|32.35GiB|0.96(87543 samples/91525.50 seconds)|19.84(161792 tokens/8156.02 seconds)|0.53|1.02|0.456|0.671|0.401|
|lora+no_mix|qwen-7b-chat|ms-agent|0.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=False|17.8913(0.2312%)|True|True|lr=5e-05/epoch=2|30.86GiB|0.91(29698 samples/32570.15 seconds)|19.89(36308 tokens/1825.26 seconds)|0.53|0.53|0.470|0.666|0.574|
|lora|qwen-7b-chat|ms-agent|2.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=False|17.8913(0.2312%)|True|True|lr=5e-05/epoch=2|32.35GiB|0.95(87543 samples/91974.29 seconds)|18.11(2415 tokens/133.32 seconds)|0.53|1.01|0.462|0.676|0.304|
|qwen-7b-chat-eval|qwen-7b-chat|None|0.0|None||None(None)||||None||30.81(13765 tokens/446.83 seconds)|||**0.517**|0.679|0.568|
|rslora|qwen-7b-chat|ms-agent|2.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=True/use_dora=False|17.8913(0.2312%)|True|True|lr=5e-05/epoch=2|32.35GiB|0.94(87543 samples/92758.63 seconds)|18.87(2762 tokens/146.34 seconds)|**0.53**|0.99|0.451|0.679|0.339|
| full+lisa_2 | qwen-7b-chat | ms-agent | 2.0 | full | lisa_activated_layers=2/lisa_step_interval=20 | - | True | True | lr=5e-05/epoch=2 | 31.11GiB | 2.66(76837 samples/28881.28 seconds) | 36.10(134469 tokens/3725.21 seconds) | 0.62 | 1.06 | 0.349 | 0.653 | 0.592 |
| full+lisa_4 | qwen-7b-chat | ms-agent | 2.0 | full | lisa_activated_layers=4/lisa_step_interval=20 | - | True | True | lr=5e-05/epoch=2 | 31.87GiB | 2.63(76837 samples/29215.15 seconds) | 36.75(135477 tokens/3686.17 seconds) | 0.63 | 1.06 | 0.377 | 0.656 | **0.607** |
|lora+packing+ddp|qwen-7b-chat|ms-agent|2.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=False/packing=True|17.8913(0.2312%)|True|True|lr=5e-05/epoch=2|35.65GiB*2|1.56(7900 samples/5057.30 seconds)|26.20(421094 tokens/16073.09 seconds)|0.63|0.98|0.473|0.664|0.552|
|lora+packing+lazytokenize|qwen-7b-chat|ms-agent|2.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=False/packing=True|17.8913(0.2312%)|True|True|lr=5e-05/epoch=2|32.83GiB|7.69(78237 samples/10179.40 seconds)|25.86(307390 tokens/11888.17 seconds)|0.63|1.04|0.472|0.660|0.554|
|lora+packing|qwen-7b-chat|ms-agent|2.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=False/packing=True|17.8913(0.2312%)|True|True|lr=5e-05/epoch=2|28.06GiB|0.79(7900 samples/10048.53 seconds)|26.12(409507 tokens/15675.36 seconds)|0.61|0.95|0.492|0.676|0.539|
## unsloth
| exp_name | model_type | dataset | ms-bench mix ratio | tuner | tuner_params | trainable params(M) | flash_attn | gradient_checkpointing | hypers | memory | train speed(samples/s) | infer speed(tokens/s) | train_loss | eval_loss | gsm8k weighted acc | arc weighted acc | ceval weighted acc |
| --------------- | ------------------ | -------- | ------------------ | ----- | ------------ | ------------------- | ---------- | ---------------------- | ---------------- | -------- | ------------------------------------ | ------------------------------------- | ---------- | --------- | ------------------ | ---------------- | ------------------ |
| unsloth+lora+q4 | llama3-8b-instruct | ms-agent | 2.0 | lora | | 4.7186(0.1038%) | True | True | lr=5e-05/epoch=2 | 21.69GiB | 1.76(76839 samples/43763.01 seconds) | 15.22(160885 tokens/10570.90 seconds) | 0.58 | 1.03 | 0.668 | 0.755 | 0.501 |
## Export
| exp_name | model_type | calibration dataset | quantization method | quantization bits | infer speed(tokens/s) | gsm8k weighted acc | arc weighted acc | ceval weighted acc |
| -------- | ---------- | ------------------- | ------------------- | ----------------- | --------------------- | ------------------ | ---------------- | ------------------ |
|awq-ms-bench-mini|qwen-7b-chat|ms-bench-mini|awq|4|27.25(16501 tokens/605.47 seconds)|0.494|0.665|0.571|
|awq-pileval|qwen-7b-chat|pileval|awq|4|26.92(12994 tokens/482.72 seconds)|**0.497**|**0.675**|**0.577**|
|gptq-ms-bench-mini|qwen-7b-chat|ms-bench-mini|gptq|4|31.16(15349 tokens/492.54 seconds)|0.482|0.642|0.556|
|gptq-pileval|qwen-7b-chat|pileval|gptq|4|31.67(15185 tokens/479.54 seconds)|0.478|0.654|0.559|
## AWQ
| exp_name | model_type | dataset | ms-bench mix ratio | tuner | tuner_params | trainable params(M) | flash_attn | gradient_checkpointing | hypers | memory | train speed(samples/s) | infer speed(tokens/s) | train_loss | eval_loss | gsm8k weighted acc | arc weighted acc | ceval weighted acc |
| -------- | ---------- | ------- | -------------------| ----- | ------------ | ------------------- | -----------| ---------------------- | ------ | ------ | ---------------------- | --------------------- | ---------- | --------- | ------------------ | ---------------- | ------------------ |
|qwen1half-7b-chat-awq|qwen1half-7b-chat-awq|ms-agent|2.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=False|19.9885(1.5802%)|True|True|lr=5e-05/epoch=2|24.26GiB|0.45(87543 samples/194746.58 seconds)|16.08(2469 tokens/153.58 seconds)|**0.55**|**1.19**|**0.505**|**0.737**|**0.656**|
## AQLM
| exp_name | model_type | dataset | ms-bench mix ratio | tuner | tuner_params | trainable params(M) | flash_attn | gradient_checkpointing | hypers | memory | train speed(samples/s) | infer speed(tokens/s) | train_loss | eval_loss | gsm8k weighted acc | arc weighted acc | ceval weighted acc |
| -------- | ---------- | ------- | -------------------| ----- | ------------ | ------------------- | -----------| ---------------------- | ------ | ------ | ---------------------- | --------------------- | ---------- | --------- | ------------------ | ---------------- | ------------------ |
|llama2-7b-aqlm-2bit-1x16|llama2-7b-aqlm-2bit-1x16|dureader-robust-zh|0.0|lora|rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=False|19.9885(1.6510%)|True|True|lr=5e-05/epoch=2|4.04GiB|0.17(14994 samples/86140.71 seconds)||**0.48**|**0.74**||||
## Sequence Parallel
<table>
<tr>
<td>Model</td>
<td>Dataset</td>
<td>Hyper params</td>
<td>Total steps</td>
<td>Train speed</td>
<td>Gpu memory</td>
</tr>
<tr>
<td rowspan="4">chatglm3-6b-32k</td>
<td rowspan="4">long-alpaca-12k(8055 tokens * 12000 rows)</td>
<td>gpu=2/sequence_parallel_size=1(双GPU DDP基准测试)</td>
<td>5940</td>
<td>0.30iter/s(5h13min total)</td>
<td>27G*2</td>
</tr>
<tr>
<td>gpu=2/sequence_parallel_size=2(双GPU序列并行2)</td>
<td>11880</td>
<td>0.5iter/s(6h total)</td>
<td>20G*2</td>
</tr>
<tr>
<td>gpu=4/sequence_parallel_size=4(四GPU序列并行4)</td>
<td>11880</td>
<td>1iter/s(3h20min total)</td>
<td>18G*4</td>
</tr>
<tr>
<td>gpu=4/sequence_parallel_size=2(四GPU序列并行2)</td>
<td>5940</td>
<td>0.45iter/s(3h total)</td>
<td>21G*4</td>
</tr>
</table>
# DPO训练文档
## 目录
- [环境准备](#环境准备)
- [人类对齐训练](#人类对齐训练)
## 环境准备
GPU设备: A10, 3090, V100, A100均可,如果是显存<=24G的GPU最少需要双卡环境。由于人类对齐训练在一张卡上加载两个模型,因此比微调的显存多占用一个推理模型的显存使用量。
```bash
# 设置pip全局镜像 (加速下载)
pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
# 安装ms-swift
git clone https://github.com/modelscope/swift.git
cd swift
pip install -e '.[llm]'
# 环境对齐 (通常不需要运行. 如果你运行错误, 可以跑下面的代码, 仓库使用最新环境测试)
pip install -r requirements/framework.txt -U
pip install -r requirements/llm.txt -U
```
## 人类对齐训练
下面的shell脚本运行了一个人类对齐训练。首先需要切换到运行目录:
```shell
cd examples/pytorch/llm
```
运行下面的命令:
```shell
# Experimental environment: 4*A100
# Memory usage: 4 * 20G,双卡device_map * 2ddp
nproc_per_node=2
CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=$nproc_per_node \
MASTER_PORT=29500 \
swift dpo \
--model_type yi-6b-chat \
--ref_model_type yi-6b-chat \
--model_revision master \
--sft_type lora \
--tuner_backend swift \
--dtype AUTO \
--output_dir output \
--dataset hh-rlhf-cn:harmless_base_cn \
--num_train_epochs 3 \
--max_length 1024 \
--max_prompt_length 512 \
--check_dataset_strategy none \
--lora_rank 8 \
--lora_alpha 32 \
--lora_dropout_p 0.05 \
--lora_target_modules ALL \
--gradient_checkpointing true \
--batch_size 1 \
--weight_decay 0.1 \
--learning_rate 5e-5 \
--gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
--max_grad_norm 1.0 \
--warmup_ratio 0.03 \
--eval_steps 2000 \
--save_steps 2000 \
--save_total_limit 2 \
--logging_steps 10 \
```
### sh脚本
sh脚本可以查看[这里](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/dpo)
```bash
# 下面的脚本需要在此目录下执行
cd examples/pytorch/llm
```
**提示**:
- 如果用带有history的数据训练base模型,需要指定支持多轮对话的template(base模型往往不支持多轮对话),对于这种情况我们默认设置了`chatml`template,你也可以支持--model_type 来选择训练模型的template
- 我们默认在训练时设置`--gradient_checkpointing true`**节约显存**, 这会略微降低训练速度.
- 如果你使用的是**V100**等较老的GPU, 你需要设置`--dtype AUTO`或者`--dtype fp16`, 因为其不支持bf16.
- 如果你的机器是A100等高性能显卡, 且使用的是qwen系列模型, 推荐你安装[**flash-attn**](https://github.com/Dao-AILab/flash-attention), 这将会加快训练和推理的速度以及显存占用(A10, 3090, V100等显卡不支持flash-attn进行训练). 支持flash-attn的模型可以查看[LLM支持的模型](支持的模型和数据集.md#模型)
- 如果你需要断网进行训练, 请使用`--model_id_or_path <model_dir>`和设置`--check_model_is_latest false`. 具体参数含义请查看[命令行参数](命令行参数.md).
- 如果你想在训练时, 将权重push到ModelScope Hub中, 你需要设置`--push_to_hub true`.
```bash
# dpo训练 mistral-7b max_length=1024,bs=1
# 推荐的实验环境: V100, A10, 3090,2卡4卡或8卡
bash scripts/dpo/lora_ddp_mp/dpo.sh
bash scripts/dpo/lora_ddp_mp/infer.sh
```
由于DPO训练后会得到一个完整模型或者adapter的weights,因此LoRA合并、推理的步骤和微调步骤相同,因此请参考[微调文档](LLM微调文档.md#merge-lora)对应的步骤。
# Grok 300B训练和推理实战
本文介绍了使用8卡环境对Grok-MoE 300B模型进行微调和推理的流程。
## 目录
- [环境准备](#环境准备)
- [微调](#微调)
- [推理](#推理)
## 环境准备
```shell
git clone https://github.com/modelscope/swift.git
cd swift
pip install -e '.[llm]'
```
## 微调
### 实验环境
- GPU:8*A100 80G
- 镜像:ModelScope官方镜像1.13.1版本
- peft:0.10.0
### 数据集准备
Grok是base模型,因此我们使用了[问题生成数据集DuReader](https://www.modelscope.cn/datasets/modelscope/DuReader_robust-QG/summary)作为训练集。该数据集约15000条,max-length设置为512,训练数据约10000条(平均长度305±92 tokens)。
### 模型准备
Grok模型我们使用了[ColossalAI提供的版本](https://www.modelscope.cn/models/colossalai/grok-1-pytorch/summary),其中我们额外准备了[符合transformers标准的tokenizer](https://www.modelscope.cn/models/AI-ModelScope/grok-1-tokenizer/summary)
### 训练
由于Grok模型过大,device_map和deepspeed zero3非offload均无法运行训练,因此本次实验我们使用了LoRA+deepspeed zero3 offload模式运行训练。训练完整脚本如下:
```shell
# cd examples/pytorch/llm first
nproc_per_node=8
PYTHONPATH=../../.. \
torchrun \
--nproc_per_node=$nproc_per_node \
--master_port 29500 \
llm_sft.py \
--model_type grok-1 \
--sft_type lora \
--tuner_backend peft \
--dtype bf16 \
--output_dir output \
--ddp_backend nccl \
--dataset dureader-robust-zh \
--train_dataset_sample -1 \
--num_train_epochs 1 \
--max_length 512 \
--check_dataset_strategy warning \
--lora_rank 8 \
--lora_alpha 32 \
--lora_dropout_p 0.05 \
--lora_dtype AUTO \
--lora_target_modules DEFAULT \
--gradient_checkpointing true \
--batch_size 2 \
--weight_decay 0.1 \
--learning_rate 1e-4 \
--gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
--max_grad_norm 0.5 \
--warmup_ratio 0.03 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 10 \
--deepspeed zero3-offload \
```
完整的训练文件可以在[这里](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/grok-1/lora_ddp_ds)找到。
下面是训练过程的一些benchmark:
| 指标 | 值 |
| -------- | ------------------------------------------------------------ |
| 显存占用 | 8*21G |
| 训练速度 | 45s/it |
| 总iter数 | 340(10000(dataset_length)/16(gradient_accumulation)/2(batch_size)) |
<img src="../../resources/image-20240329122854204.png" alt="image-20240329122854204" style="zoom: 33%;" />
由于显存占用不到24G,理论上可以在RTX3090/A10环境中运行训练。
<img src="../../resources/grok_train_loss.png" alt="train_loss (1)" style="zoom:33%;" />
<img src="../../resources/grok_train_acc.png" alt="train_acc" style="zoom:33%;" />
训练时长约4小时。
### 推理
SWIFT框架目前并不支持deepspeed推理,因此我们仍然使用transformers的device_map进行推理支持。但由于模型过大,因此部分layers会被offload到CPU上,并影响LoRA加载使推理出错,因此我们针对peft的实现进行了一定patch(原Linear在meta设备上时不迁移LoRA,并在运行时动态迁移weights)。
推理脚本如下:
```shell
# cd examples/pytorch/llm first
PYTHONPATH=../../.. \
python llm_infer.py \
--ckpt_dir output/grok-1/vx-xxx-xxx/checkpoint-xxx \
--dtype bf16 \
--load_dataset_config true \
--max_new_tokens 64 \
--do_sample true \
--dtype bf16 \
--eval_human false \
--merge_lora false \
```
推理结果:
```text
[PROMPT]Task: Question Generation
Context: 我个人感觉是吕颂贤版,剧情和原著差别不大,虽然TVB演员颜值和风光没有大陆的好。但是香港特区人口和地域的限制,只能注重在演员的演技方面发挥很出色,楼主看过大陆排《笑傲江湖》吧!在台词上表现的很生硬没有香港的注重神色配台词,比如杜燕歌把吕颂贤表情和性格几乎和原著差别不大。武打几乎沿用徐克和程小东动作的风格很注重实际技巧,没有大陆版的在武打场面依靠电脑特效表现的太夸张了。李亚鹏版的武打动作和导演还是香港的元彬,大陆毕竟还是在武侠剧起步的比较晚,主要是还是靠明星大腕压阵而香港却是恰恰相反。
Answer: 吕颂贤版
Question:[OUTPUT]笑傲江湖哪个版本好看</s>
[LABELS]笑傲江湖哪个版本好看
--------------------------------------------------
[PROMPT]Task: Question Generation
Context: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。
Answer: 6-12周
Question:[OUTPUT]怀孕几个月开始反应</s>
[LABELS]怀孕多久会有反应
--------------------------------------------------
```
# HuggingFace生态兼容
默认我们会使用[ModelScope](https://modelscope.cn/my/overview)中的模型和数据集进行微调和推理。但是考虑到海外用户更熟悉[HuggingFace](https://huggingface.co/)生态,这里对其进行兼容。
你需要设置环境变量`USE_HF=1`,支持的HuggingFace模型和数据集可以参考[支持的模型和数据集](支持的模型和数据集.md),部分数据集只支持在ModelScope环境下使用。
以下是对`qwen1.5-7b-chat`的推理脚本:
```shell
# Experimental Environment: A10, 3090, V100
USE_HF=1 CUDA_VISIBLE_DEVICES=0 swift infer --model_type qwen1half-7b-chat
```
微调脚本:
```shell
# Experimental Environment: 2 * A100
# GPU Memory Requirement: 2 * 30GB
USE_HF=1 \
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
swift sft \
--model_type qwen1half-7b-chat \
--dataset blossom-math-zh \
--num_train_epochs 5 \
--sft_type lora \
--output_dir output \
```
微调后推理与部署等内容参考其他文档.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment