Commit 24534501 authored by mashun1

parallel_tool

parent c4ba4563
...@@ -176,3 +176,4 @@ output/ ...@@ -176,3 +176,4 @@ output/
wandb/ wandb/
swanlog/ swanlog/
generated_predictions.jsonl generated_predictions.jsonl
predictions_score.json
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
[![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors) [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
[![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml) [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-429-green)](https://scholar.google.com/scholar?cites=12620864006390196564) [![Citation](https://img.shields.io/badge/citation-476-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
...@@ -18,15 +18,27 @@ ...@@ -18,15 +18,27 @@
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) [![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
<h3 align="center"> ### 获得[亚马逊](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)、[英伟达](https://developer.nvidia.cn/rtx/ai-toolkit)、[阿里云](https://help.aliyun.com/zh/pai/use-cases/fine-tune-a-llama-3-model-with-llama-factory)等的应用。
使用零代码<a href="#快速开始">命令行</a><a href="#llama-board-可视化微调由-gradio-驱动">Web UI</a> 轻松微调百余种大模型
</h3>
<p align="center">
<picture>
<img alt="Github trend" src="https://trendshift.io/api/badge/repositories/4535">
</picture>
</p>
<div align="center" markdown="1">
### 赞助商 ❤️
<a href="https://warp.dev/llama-factory">
<img alt="Warp sponsorship" width="400" src="https://github.com/user-attachments/assets/ab8dd143-b0fd-4904-bdc5-dd7ecac94eae">
</a>
#### [Warp,面向开发者的智能终端](https://warp.dev/llama-factory)
[适用于 MacOS、Linux 和 Windows](https://warp.dev/llama-factory)
----
### 使用零代码[命令行](#快速开始)与 [Web UI](#llama-board-可视化微调由-gradio-驱动) 轻松微调百余种大模型
![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)
</div>
👋 加入我们的[微信群](assets/wechat.jpg)[NPU 用户群](assets/wechat_npu.jpg) 👋 加入我们的[微信群](assets/wechat.jpg)[NPU 用户群](assets/wechat_npu.jpg)
...@@ -43,9 +55,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc ...@@ -43,9 +55,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/ - **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/
- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing - **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **本地机器**:请见[如何使用](#如何使用) - **本地机器**:请见[如何使用](#如何使用)
- **PAI-DSW(免费试用)**[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) - **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **Amazon SageMaker**[博客](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
- **Easy Dataset**[数据蒸馏微调](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)
> [!NOTE] > [!NOTE]
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。 > 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
...@@ -53,7 +63,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc ...@@ -53,7 +63,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
## 目录 ## 目录
- [项目特色](#项目特色) - [项目特色](#项目特色)
- [性能指标](#性能指标) - [官方博客](#官方博客)
- [更新日志](#更新日志) - [更新日志](#更新日志)
- [模型](#模型) - [模型](#模型)
- [训练方法](#训练方法) - [训练方法](#训练方法)
...@@ -93,18 +103,17 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc ...@@ -93,18 +103,17 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 | | Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 | | Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 |
## 性能指标 ## 官方博客
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。 - [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
- [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
- [LLaMA Factory:微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
![benchmark](assets/benchmark.svg) <details><summary>全部博客</summary>
<details><summary>变量定义</summary> - [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
- **Training Speed**: 训练阶段每秒处理的样本数量。(批处理大小=4,截断长度=1024) - [LLaMA Factory:微调LLaMA3模型实现角色扮演](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)(中文)
- **Rouge Score**: [广告文案生成](https://aclanthology.org/D19-1321.pdf)任务验证集上的 Rouge-2 分数。(批处理大小=4,截断长度=1024)
- **GPU Memory**: 4 比特量化训练的 GPU 显存峰值。(批处理大小=1,截断长度=1024)
- 我们在 ChatGLM 的 P-Tuning 中采用 `pre_seq_len=128`,在 LLaMA Factory 的 LoRA 微调中采用 `lora_rank=32`
</details> </details>
...@@ -236,6 +245,9 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc ...@@ -236,6 +245,9 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
</details> </details>
> [!TIP]
> 如果您无法使用最新的功能,请尝试重新拉取代码并再次安装 LLaMA-Factory。
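For example, refreshing an existing clone could look like the sketch below (assuming the editable install shown in the installation section; the branch is whatever your clone tracks):

```bash
cd LLaMA-Factory
git pull
pip install -e ".[torch,metrics]" --no-build-isolation
```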
## 模型 ## 模型
| 模型名 | 参数量 | Template | | 模型名 | 参数量 | Template |
...@@ -246,17 +258,17 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc ...@@ -246,17 +258,17 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 | | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 | | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) | | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 | | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan | | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 | | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\* | 1B/2B/8B/14B/38B/78B | intern_vl | | [InternVL 2.5-3](https://huggingface.co/OpenGVLab) | 1B/2B/8B/14B/38B/78B | intern_vl |
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl | | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
...@@ -266,6 +278,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc ...@@ -266,6 +278,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava | | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next | | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video | | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 | | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v | | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral | | [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
...@@ -281,8 +294,9 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc ...@@ -281,8 +294,9 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen3 (MoE)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/235B | qwen3 | | [Qwen3 (MoE)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/235B | qwen3 |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio | | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2.5-Omni](https://huggingface.co/Qwen)\*\* | 7B | qwen2_omni | | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl | | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
| [Seed Coder](https://huggingface.co/ByteDance-Seed) | 8B | seed_coder |
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 | | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 | | [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
...@@ -405,6 +419,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc ...@@ -405,6 +419,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- [COIG-P (en&zh)](https://huggingface.co/datasets/m-a-p/COIG-P) - [COIG-P (en&zh)](https://huggingface.co/datasets/m-a-p/COIG-P)
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset) - [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback) - [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
- [RLAIF-V (en)](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs) - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
...@@ -426,6 +441,7 @@ huggingface-cli login ...@@ -426,6 +441,7 @@ huggingface-cli login
| ------------ | ------- | --------- | | ------------ | ------- | --------- |
| python | 3.9 | 3.10 | | python | 3.9 | 3.10 |
| torch | 2.0.0 | 2.6.0 | | torch | 2.0.0 | 2.6.0 |
| torchvision | 0.15.0 | 0.21.0 |
| transformers | 4.45.0 | 4.50.0 | | transformers | 4.45.0 | 4.50.0 |
| datasets | 2.16.0 | 3.2.0 | | datasets | 2.16.0 | 3.2.0 |
| accelerate | 0.34.0 | 1.2.1 | | accelerate | 0.34.0 | 1.2.1 |
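A quick, hedged way to compare your local environment against this table (only packages listed above are imported; adjust if some are not installed):

```bash
python -c "import torch, transformers, datasets, accelerate; print(torch.__version__, transformers.__version__, datasets.__version__, accelerate.__version__)"
```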
...@@ -463,13 +479,13 @@ huggingface-cli login ...@@ -463,13 +479,13 @@ huggingface-cli login
```bash ```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory cd LLaMA-Factory
pip install -e ".[torch,metrics]" pip install -e ".[torch,metrics]" --no-build-isolation
``` ```
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality 可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality
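For illustration only, several extras can be requested in a single install; the extra names below are taken from the list above, and the exact combination depends on your setup:

```bash
pip install -e ".[torch,metrics,deepspeed,bitsandbytes]" --no-build-isolation
```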
> [!TIP] > [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 > 遇到包冲突时,可使用 `pip install -e . --no-deps --no-build-isolation` 解决。
<details><summary>使用 <b>uv</b> 构建虚拟环境</summary> <details><summary>使用 <b>uv</b> 构建虚拟环境</summary>
...@@ -487,9 +503,22 @@ uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora ...@@ -487,9 +503,22 @@ uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora
</details> </details>
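As a hedged alternative sketch (not necessarily the project's recommended workflow), a uv-managed environment can also be created manually before running the CLI:

```bash
uv venv --python 3.10                 # create .venv
source .venv/bin/activate
uv pip install -e ".[torch,metrics]"  # editable install with the same extras as above
```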
<details><summary>Windows 用户指南</summary> <details><summary>Windows 用户指南</summary>
#### 安装 PyTorch
Windows 平台需要额外手动安装 GPU 版本的 PyTorch 依赖包,您可以参考[官方网站](https://pytorch.org/get-started/locally/)和以下命令安装并测试 PyTorch 是否正确安装。
```bash
pip uninstall torch torchvision torchaudio
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
python -c "import torch; print(torch.cuda.is_available())"
```
如果看到 `True` 则说明安装成功。
若遇到类似 `Can't pickle local object` 的报错,请设置 `dataloader_num_workers: 0`
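For instance, the workaround is a single key in the training YAML (illustrative snippet; all other keys omitted):

```yaml
# avoid multiprocessing pickling issues on Windows
dataloader_num_workers: 0
```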
#### 安装 BitsAndBytes #### 安装 BitsAndBytes
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.2, 请根据您的 CUDA 版本情况选择适合的[发布版本](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) 如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.2, 请根据您的 CUDA 版本情况选择适合的[发布版本](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels)
...@@ -579,7 +608,7 @@ pip install . ...@@ -579,7 +608,7 @@ pip install .
> [!NOTE] > [!NOTE]
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。 > 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。
您也可以使用 **[Easy Dataset](https://github.com/ConardLi/easy-dataset)** 构建用于微调的合成数据。 您也可以使用 **[Easy Dataset](https://github.com/ConardLi/easy-dataset)** **[GraphGen](https://github.com/open-sciencelab/GraphGen)** 构建用于微调的合成数据。
### 快速开始 ### 快速开始
...@@ -890,6 +919,7 @@ swanlab_run_name: test_run # 可选 ...@@ -890,6 +919,7 @@ swanlab_run_name: test_run # 可选
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357) 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357)
1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**:一个魔改后的代码库,通过 Ring Attention 支持长序列的 SFT 和 DPO 训练。 1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**:一个魔改后的代码库,通过 Ring Attention 支持长序列的 SFT 和 DPO 训练。
1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**:由 NovaSky AI 微调的低成本类 o1 长推理模型。 1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**:由 NovaSky AI 微调的低成本类 o1 长推理模型。
1. **[WeClone](https://github.com/xming521/WeClone)**:从聊天记录创造数字分身的一站式解决方案。
</details> </details>
......
[SVG bar chart (Matplotlib v3.7.1) — ChatGLM2-6B, 1×A100: ChatGLM P-Tuning vs. LLaMA-Factory — Training Speed 5.81 vs. 21.67, Rouge Score 7.20 vs. 7.36, GPU Memory (GB) 5.78 vs. 5.14]
assets/wechat.jpg (161 KB → 167 KB)
assets/wechat_npu.jpg (168 KB → 167 KB)
The [dataset_info.json](dataset_info.json) contains all available datasets. If you are using a custom dataset, please **make sure** to add a *dataset description* in `dataset_info.json` and specify `dataset: dataset_name` before training to use it. The [dataset_info.json](dataset_info.json) contains all available datasets. If you are using a custom dataset, please **make sure** to add a *dataset description* in `dataset_info.json` and specify `dataset: dataset_name` before training to use it.
The `dataset_info.json` file should be put in the `dataset_dir` directory. You can change `dataset_dir` to use another directory. The default value is `./data`.
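For instance, a minimal entry for a local alpaca-format file might look like the sketch below; the dataset name `my_dataset` and the file name are placeholders:

```json
"my_dataset": {
  "file_name": "my_dataset.json",
  "formatting": "alpaca"
}
```

The training config would then reference it with `dataset: my_dataset`.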
Currently we support datasets in **alpaca** and **sharegpt** format. Currently we support datasets in **alpaca** and **sharegpt** format.
```json ```json
...@@ -48,7 +50,9 @@ Currently we support datasets in **alpaca** and **sharegpt** format. ...@@ -48,7 +50,9 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
* [Example dataset](alpaca_en_demo.json) * [Example dataset](alpaca_en_demo.json)
In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the human prompt, then the human prompt would be `instruction\ninput`. The `output` column represents the model response. In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the user prompt, then the user prompt would be `instruction\ninput`. The `output` column represents the model response.
For reasoning models, if the dataset contains chain-of-thought (CoT), the CoT needs to be placed in the model responses, such as `<think>cot</think>output`.
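A hypothetical record illustrating this placement (the content itself is made up):

```json
{
  "instruction": "What is 12 + 35?",
  "input": "",
  "output": "<think>12 + 35 = 47</think>The answer is 47."
}
```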
The `system` column will be used as the system prompt if specified. The `system` column will be used as the system prompt if specified.
...@@ -57,13 +61,13 @@ The `history` column is a list consisting of string tuples representing prompt-r ...@@ -57,13 +61,13 @@ The `history` column is a list consisting of string tuples representing prompt-r
```json ```json
[ [
{ {
"instruction": "human instruction (required)", "instruction": "user instruction (required)",
"input": "human input (optional)", "input": "user input (optional)",
"output": "model response (required)", "output": "model response (required)",
"system": "system prompt (optional)", "system": "system prompt (optional)",
"history": [ "history": [
["human instruction in the first round (optional)", "model response in the first round (optional)"], ["user instruction in the first round (optional)", "model response in the first round (optional)"],
["human instruction in the second round (optional)", "model response in the second round (optional)"] ["user instruction in the second round (optional)", "model response in the second round (optional)"]
] ]
} }
] ]
...@@ -84,6 +88,11 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh ...@@ -84,6 +88,11 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
} }
``` ```
> [!TIP]
> If the model has reasoning capabilities but the dataset does not contain chain-of-thought (CoT), LLaMA-Factory will automatically add an empty CoT to the data. When `enable_thinking` is `True` (slow thinking), the empty CoT is appended to the model responses and included in the loss computation; otherwise (fast thinking), it is appended to the user prompts and excluded from the loss computation. Please keep the `enable_thinking` parameter consistent during training and inference.
>
> If you want to train samples that contain CoT in slow-thinking mode and samples without CoT in fast-thinking mode, you can set `enable_thinking` to `None`. However, this feature is relatively complicated and should be used with caution.
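As a rough sketch, the flag sits alongside the other arguments in a training YAML; every key except `enable_thinking` below is an illustrative placeholder:

```yaml
dataset: my_reasoning_dataset   # placeholder dataset name
template: qwen3
enable_thinking: true           # slow thinking: empty CoT goes into responses and counts toward the loss
```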
### Pre-training Dataset ### Pre-training Dataset
- [Example dataset](c4_demo.jsonl) - [Example dataset](c4_demo.jsonl)
...@@ -117,8 +126,8 @@ It requires a better response in `chosen` column and a worse response in `reject ...@@ -117,8 +126,8 @@ It requires a better response in `chosen` column and a worse response in `reject
```json ```json
[ [
{ {
"instruction": "human instruction (required)", "instruction": "user instruction (required)",
"input": "human input (optional)", "input": "user input (optional)",
"chosen": "chosen answer (required)", "chosen": "chosen answer (required)",
"rejected": "rejected answer (required)" "rejected": "rejected answer (required)"
} }
...@@ -172,7 +181,7 @@ Note that the human and observation should appear in odd positions, while gpt an ...@@ -172,7 +181,7 @@ Note that the human and observation should appear in odd positions, while gpt an
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "human instruction" "value": "user instruction"
}, },
{ {
"from": "function_call", "from": "function_call",
...@@ -223,7 +232,7 @@ Preference datasets in sharegpt format also require a better message in `chosen` ...@@ -223,7 +232,7 @@ Preference datasets in sharegpt format also require a better message in `chosen`
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "human instruction" "value": "user instruction"
}, },
{ {
"from": "gpt", "from": "gpt",
...@@ -231,7 +240,7 @@ Preference datasets in sharegpt format also require a better message in `chosen` ...@@ -231,7 +240,7 @@ Preference datasets in sharegpt format also require a better message in `chosen`
}, },
{ {
"from": "human", "from": "human",
"value": "human instruction" "value": "user instruction"
} }
], ],
"chosen": { "chosen": {
...@@ -273,7 +282,7 @@ KTO datasets require an extra `kto_tag` column containing the boolean human feedb ...@@ -273,7 +282,7 @@ KTO datasets require an extra `kto_tag` column containing the boolean human feedb
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "human instruction" "value": "user instruction"
}, },
{ {
"from": "gpt", "from": "gpt",
...@@ -312,7 +321,7 @@ The number of images should be identical to the `<image>` tokens in the conversa ...@@ -312,7 +321,7 @@ The number of images should be identical to the `<image>` tokens in the conversa
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "<image>human instruction" "value": "<image>user instruction"
}, },
{ {
"from": "gpt", "from": "gpt",
...@@ -353,7 +362,7 @@ The number of videos should be identical to the `<video>` tokens in the conversa ...@@ -353,7 +362,7 @@ The number of videos should be identical to the `<video>` tokens in the conversa
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "<video>human instruction" "value": "<video>user instruction"
}, },
{ {
"from": "gpt", "from": "gpt",
...@@ -394,7 +403,7 @@ The number of audios should be identical to the `<audio>` tokens in the conversa ...@@ -394,7 +403,7 @@ The number of audios should be identical to the `<audio>` tokens in the conversa
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "<audio>human instruction" "value": "<audio>user instruction"
}, },
{ {
"from": "gpt", "from": "gpt",
...@@ -435,7 +444,7 @@ The openai format is simply a special case of the sharegpt format, where the fir ...@@ -435,7 +444,7 @@ The openai format is simply a special case of the sharegpt format, where the fir
}, },
{ {
"role": "user", "role": "user",
"content": "human instruction" "content": "user instruction"
}, },
{ {
"role": "assistant", "role": "assistant",
......
[dataset_info.json](dataset_info.json) 包含了所有可用的数据集。如果您希望使用自定义数据集,请**务必**`dataset_info.json` 文件中添加*数据集描述*,并通过修改 `dataset: 数据集名称` 配置来使用数据集。 [dataset_info.json](dataset_info.json) 包含了所有可用的数据集。如果您希望使用自定义数据集,请**务必**`dataset_info.json` 文件中添加*数据集描述*,并通过修改 `dataset: 数据集名称` 配置来使用数据集。
其中 `dataset_info.json` 文件应放置在 `dataset_dir` 目录下。您可以通过修改 `dataset_dir` 参数来使用其他目录。默认值为 `./data`
目前我们支持 **alpaca** 格式和 **sharegpt** 格式的数据集。 目前我们支持 **alpaca** 格式和 **sharegpt** 格式的数据集。
```json ```json
...@@ -47,7 +49,9 @@ ...@@ -47,7 +49,9 @@
- [样例数据集](alpaca_zh_demo.json) - [样例数据集](alpaca_zh_demo.json)
在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为人类指令,即人类指令为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。 在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为提示词,即提示词为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。
对于推理类模型的微调,如果数据集包含思维链,则需要把思维链放在模型回答中,例如 `<think>cot</think>output`
如果指定,`system` 列对应的内容将被作为系统提示词。 如果指定,`system` 列对应的内容将被作为系统提示词。
...@@ -56,8 +60,8 @@ ...@@ -56,8 +60,8 @@
```json ```json
[ [
{ {
"instruction": "人类指令(必填)", "instruction": "用户指令(必填)",
"input": "人类输入(选填)", "input": "用户输入(选填)",
"output": "模型回答(必填)", "output": "模型回答(必填)",
"system": "系统提示词(选填)", "system": "系统提示词(选填)",
"history": [ "history": [
...@@ -83,6 +87,11 @@ ...@@ -83,6 +87,11 @@
} }
``` ```
> [!TIP]
> 如果模型本身具备推理能力,而数据集不包含思维链,LLaMA-Factory 会自动为数据添加空思维链。当 `enable_thinking` 为 `True` 时(慢思考),空思维链会添加到模型回答中并且计算损失,否则会添加到用户指令中并且不计算损失(快思考)。请在训练和推理时保持 `enable_thinking` 参数一致。
>
> 如果您希望训练包含思维链的数据时使用慢思考,训练不包含思维链的数据时使用快思考,可以设置 `enable_thinking` 为 `None`。但该功能较为复杂,请谨慎使用。
### 预训练数据集 ### 预训练数据集
- [样例数据集](c4_demo.jsonl) - [样例数据集](c4_demo.jsonl)
...@@ -116,8 +125,8 @@ ...@@ -116,8 +125,8 @@
```json ```json
[ [
{ {
"instruction": "人类指令(必填)", "instruction": "用户指令(必填)",
"input": "人类输入(选填)", "input": "用户输入(选填)",
"chosen": "优质回答(必填)", "chosen": "优质回答(必填)",
"rejected": "劣质回答(必填)" "rejected": "劣质回答(必填)"
} }
...@@ -171,7 +180,7 @@ KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#s ...@@ -171,7 +180,7 @@ KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#s
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "人类指令" "value": "用户指令"
}, },
{ {
"from": "function_call", "from": "function_call",
...@@ -222,7 +231,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的 ...@@ -222,7 +231,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "人类指令" "value": "用户指令"
}, },
{ {
"from": "gpt", "from": "gpt",
...@@ -230,7 +239,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的 ...@@ -230,7 +239,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
}, },
{ {
"from": "human", "from": "human",
"value": "人类指令" "value": "用户指令"
} }
], ],
"chosen": { "chosen": {
...@@ -272,7 +281,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人 ...@@ -272,7 +281,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "人类指令" "value": "用户指令"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -311,7 +320,7 @@ KTO datasets require an additional `kto_tag` column containing bool-typed human
    "conversations": [
      {
        "from": "human",
-        "value": "<image>human instruction"
+        "value": "<image><image>user instruction"
      },
      {
        "from": "gpt",
@@ -319,6 +328,7 @@ KTO datasets require an additional `kto_tag` column containing bool-typed human
      }
    ],
    "images": [
+      "image path (required)",
      "image path (required)"
    ]
  }
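Note that the number of `<image>` tokens in the conversation should match the number of entries in `images`. Registering such a multimodal dataset in `dataset_info.json` typically also needs `formatting: sharegpt` and an `images` column mapping; a rough sketch with placeholder names:

```json
{
  "my_mllm_dataset": {
    "file_name": "my_mllm_dataset.json",
    "formatting": "sharegpt",
    "columns": {
      "messages": "conversations",
      "images": "images"
    }
  }
}
```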
@@ -352,7 +362,7 @@ KTO datasets require an additional `kto_tag` column containing bool-typed human
    "conversations": [
      {
        "from": "human",
-        "value": "<video>human instruction"
+        "value": "<video><video>user instruction"
      },
      {
        "from": "gpt",
@@ -360,6 +370,7 @@ KTO datasets require an additional `kto_tag` column containing bool-typed human
      }
    ],
    "videos": [
+      "video path (required)",
      "video path (required)"
    ]
  }
@@ -393,7 +404,7 @@ KTO datasets require an additional `kto_tag` column containing bool-typed human
    "conversations": [
      {
        "from": "human",
-        "value": "<audio>human instruction"
+        "value": "<audio><audio>user instruction"
      },
      {
        "from": "gpt",
@@ -401,6 +412,7 @@ KTO datasets require an additional `kto_tag` column containing bool-typed human
      }
    ],
    "audios": [
+      "audio path (required)",
      "audio path (required)"
    ]
  }
@@ -435,7 +447,7 @@ The openai format is simply a special case of the sharegpt format, where the first message
    },
    {
      "role": "user",
-      "content": "human instruction"
+      "content": "user instruction"
    },
    {
      "role": "assistant",
...
@@ -559,6 +559,16 @@
      "images": "images"
    }
  },
+  "rlaif_v": {
+    "hf_hub_url": "openbmb/RLAIF-V-Dataset",
+    "ranking": true,
+    "columns": {
+      "prompt": "question",
+      "chosen": "chosen",
+      "rejected": "rejected",
+      "images": "image"
+    }
+  },
  "orca_pairs": {
    "hf_hub_url": "Intel/orca_dpo_pairs",
    "ranking": true,
...
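With the new `rlaif_v` entry registered above, a DPO training config can reference the dataset by name. A fragment of such a config might look like the sketch below (other required fields omitted; the template must match the vision-language model being trained):

```yaml
### method
stage: dpo
do_train: true
finetuning_type: lora

### dataset
dataset: rlaif_v
template: qwen2_vl
cutoff_len: 2048
```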
@@ -52,7 +52,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml

#### Multimodal Supervised Fine-Tuning

```bash
-llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml
```

#### DPO/ORPO/SimPO Training

@@ -64,7 +64,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml

#### Multimodal DPO/ORPO/SimPO Training

```bash
-llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml
```

#### Reward Modeling

@@ -168,7 +168,7 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500

#### Multimodal Supervised Fine-Tuning

```bash
-FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
```

### Merging LoRA Adapters and Quantization

@@ -195,10 +195,11 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml

### Inferring LoRA Fine-Tuned Models

-#### Batch Generation using vLLM Tensor Parallel
+#### Evaluation using vLLM's Multi-GPU Inference

```
-python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
+python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo
+python scripts/eval_bleu_rouge.py generated_predictions.jsonl
```

#### Use CLI ChatBox

@@ -281,9 +282,3 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml

```bash
bash examples/extras/fsdp_qlora/train.sh
```

-#### Computing BLEU and ROUGE Scores
-
-```bash
-llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
-```
@@ -52,7 +52,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml

#### Multimodal Supervised Fine-Tuning

```bash
-llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml
```

#### DPO/ORPO/SimPO Training

@@ -64,7 +64,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml

#### Multimodal DPO/ORPO/SimPO Training

```bash
-llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml
```

#### Reward Modeling

@@ -168,7 +168,7 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500

#### Multimodal Supervised Fine-Tuning

```bash
-FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
```

### Merging LoRA Adapters and Quantization

@@ -195,10 +195,11 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml

### Inferring LoRA Fine-Tuned Models

-#### Batch Inference Using vLLM Tensor Parallelism
+#### Evaluation Using vLLM Multi-GPU Inference

```
-python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
+python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo
+python scripts/eval_bleu_rouge.py generated_predictions.jsonl
```

#### Use CLI ChatBox

@@ -281,9 +282,3 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml

```bash
bash examples/extras/fsdp_qlora/train.sh
```

-#### Computing BLEU and ROUGE Scores
-
-```bash
-llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
-```
@@ -2,12 +2,12 @@

### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
-adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
+adapter_name_or_path: saves/qwen2_5vl-7b/lora/sft
template: qwen2_vl
trust_remote_code: true

### export
-export_dir: output/qwen2_vl_lora_sft
+export_dir: output/qwen2_5vl_lora_sft
export_size: 5
export_device: cpu  # choices: [cpu, auto]
export_legacy_format: false
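Assuming this merge config lives under `examples/merge_lora/` as `qwen2_5vl_lora_sft.yaml` (the exact file path is not shown in this diff), the adapter can be merged and exported with:

```bash
llamafactory-cli export examples/merge_lora/qwen2_5vl_lora_sft.yaml
```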
### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true

@@ -23,7 +23,7 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
-output_dir: saves/qwen2_vl-7b/full/sft
+output_dir: saves/qwen2_5vl-7b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
...
@@ -23,7 +23,7 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
-output_dir: saves/qwen2_vl-7b/lora/dpo
+output_dir: saves/qwen2_5vl-7b/lora/dpo
logging_steps: 10
save_steps: 500
plot_loss: true
...
@@ -21,7 +21,7 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
-output_dir: saves/qwen2_vl-7b/lora/sft
+output_dir: saves/qwen2_5vl-7b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
...
@@ -88,6 +88,14 @@ conflicts = [
        { extra = "torch-npu" },
        { extra = "vllm" },
    ],
+    [
+        { extra = "torch-npu" },
+        { extra = "sglang" },
+    ],
+    [
+        { extra = "vllm" },
+        { extra = "sglang" },
+    ],
    [
        { extra = "sglang" },
        { extra = "minicpm_v" },
...
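Because these extras are declared as conflicting, only one inference backend should be installed at a time. For example, using the extras defined in the project metadata:

```bash
# vLLM backend (must not be combined with the sglang or torch-npu extras)
pip install -e ".[torch,metrics,vllm]"

# or, alternatively, the SGLang backend
pip install -e ".[torch,metrics,sglang]"
```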
-transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
+transformers>=4.45.0,<=4.52.1,!=4.46.*,!=4.47.*,!=4.48.0,!=4.52.0
-datasets>=2.16.0,<=3.5.0
+datasets>=2.16.0,<=3.6.0
-accelerate>=0.34.0,<=1.6.0
+accelerate>=0.34.0,<=1.7.0
-peft>=0.14.0,<=0.15.1
+peft>=0.14.0,<=0.15.2
trl>=0.8.6,<=0.9.6
tokenizers>=0.19.0,<=0.21.1
-gradio>=4.38.0,<=5.25.0
+gradio>=4.38.0,<=5.29.1
scipy
einops
sentencepiece
...
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import gc
import json
from typing import Optional

import fire
+from tqdm import tqdm
from transformers import Seq2SeqTrainingArguments

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
@@ -47,10 +49,15 @@ def vllm_infer(
    max_new_tokens: int = 1024,
    repetition_penalty: float = 1.0,
    skip_special_tokens: bool = True,
+    default_system: Optional[str] = None,
+    enable_thinking: bool = True,
    seed: Optional[int] = None,
    pipeline_parallel_size: int = 1,
    image_max_pixels: int = 768 * 768,
    image_min_pixels: int = 32 * 32,
+    video_fps: float = 2.0,
+    video_maxlen: int = 128,
+    batch_size: int = 1024,
):
    r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
@@ -69,6 +76,8 @@ def vllm_infer(
        cutoff_len=cutoff_len,
        max_samples=max_samples,
        preprocessing_num_workers=16,
+        default_system=default_system,
+        enable_thinking=enable_thinking,
        vllm_config=vllm_config,
        temperature=temperature,
        top_p=top_p,
@@ -83,38 +92,28 @@
    tokenizer = tokenizer_module["tokenizer"]
    template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
    template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
-    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
-
-    inputs, prompts, labels = [], [], []
-    for sample in dataset_module["train_dataset"]:
-        if sample["images"]:
-            multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(
-                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )["images"]
-            }
-        elif sample["videos"]:
-            multi_modal_data = {
-                "video": template_obj.mm_plugin._regularize_videos(
-                    sample["videos"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )["videos"]
-            }
-        elif sample["audios"]:
-            audio_data = template_obj.mm_plugin._regularize_audios(
-                sample["audios"],
-                sampling_rate=16000,
-            )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
-
-        inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
-        labels.append(
-            tokenizer.decode(
-                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
-            )
-        )
+    engine_args = {
+        "model": model_args.model_name_or_path,
+        "trust_remote_code": True,
+        "dtype": model_args.infer_dtype,
+        "max_model_len": cutoff_len + max_new_tokens,
+        "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
+        "pipeline_parallel_size": pipeline_parallel_size,
+        "disable_log_stats": True,
+        "enable_lora": model_args.adapter_name_or_path is not None,
+    }
+    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
+        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
+
+    if isinstance(model_args.vllm_config, dict):
+        engine_args.update(model_args.vllm_config)
+
+    llm = LLM(**engine_args)
+
+    # load datasets
+    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
+    train_dataset = dataset_module["train_dataset"]

    sampling_params = SamplingParams(
        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
@@ -131,30 +130,68 @@
    else:
        lora_request = None

-    engine_args = {
-        "model": model_args.model_name_or_path,
-        "trust_remote_code": True,
-        "dtype": model_args.infer_dtype,
-        "max_model_len": cutoff_len + max_new_tokens,
-        "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
-        "pipeline_parallel_size": pipeline_parallel_size,
-        "disable_log_stats": True,
-        "enable_lora": model_args.adapter_name_or_path is not None,
-    }
-    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
-        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
-
-    if isinstance(model_args.vllm_config, dict):
-        engine_args.update(model_args.vllm_config)
-
-    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
-    preds = [result.outputs[0].text for result in results]
+    # Store all results in these lists
+    all_prompts, all_preds, all_labels = [], [], []
+
+    # Add batch process to avoid the issue of too many files opened
+    for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
+        vllm_inputs, prompts, labels = [], [], []
+        batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
+
+        for j in range(len(batch["input_ids"])):
+            if batch["images"][j] is not None:
+                image = batch["images"][j]
+                multi_modal_data = {
+                    "image": template_obj.mm_plugin._regularize_images(
+                        image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                    )["images"]
+                }
+            elif batch["videos"][j] is not None:
+                video = batch["videos"][j]
+                multi_modal_data = {
+                    "video": template_obj.mm_plugin._regularize_videos(
+                        video,
+                        image_max_pixels=image_max_pixels,
+                        image_min_pixels=image_min_pixels,
+                        video_fps=video_fps,
+                        video_maxlen=video_maxlen,
+                    )["videos"]
+                }
+            elif batch["audios"][j] is not None:
+                audio = batch["audios"][j]
+                audio_data = template_obj.mm_plugin._regularize_audios(
+                    audio,
+                    sampling_rate=16000,
+                )
+                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
+            else:
+                multi_modal_data = None
+
+            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
+            labels.append(
+                tokenizer.decode(
+                    list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
+                    skip_special_tokens=skip_special_tokens,
+                )
+            )
+
+        results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
+        preds = [result.outputs[0].text for result in results]
+
+        # Accumulate results
+        all_prompts.extend(prompts)
+        all_preds.extend(preds)
+        all_labels.extend(labels)
+        gc.collect()
+
+    # Write all results at once outside the loop
    with open(save_name, "w", encoding="utf-8") as f:
-        for text, pred, label in zip(prompts, preds, labels):
+        for text, pred, label in zip(all_prompts, all_preds, all_labels):
            f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")

    print("*" * 70)
-    print(f"{len(prompts)} generated results have been saved at {save_name}.")
+    print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
    print("*" * 70)
...
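With the parameters added above, a batched run could be launched roughly as follows; the model, dataset, and values are illustrative, and the flag names mirror the function signature in this script:

```bash
python scripts/vllm_infer.py \
  --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
  --template llama3 \
  --dataset alpaca_en_demo \
  --batch_size 512 \
  --enable_thinking False \
  --max_new_tokens 1024
```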
@@ -42,7 +42,7 @@ def get_console_scripts() -> list[str]:

extra_require = {
-    "torch": ["torch>=1.13.1"],
+    "torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
    "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
    "metrics": ["nltk", "jieba", "rouge-chinese"],
    "deepspeed": ["deepspeed>=0.10.0,<=0.16.5"],
@@ -50,10 +50,9 @@ extra_require = {
    "bitsandbytes": ["bitsandbytes>=0.39.0"],
    "hqq": ["hqq"],
    "eetq": ["eetq"],
-    "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
+    "gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
-    "awq": ["autoawq"],
    "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.8.4"],
+    "vllm": ["vllm>=0.4.3,<=0.8.5"],
    "sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
    "galore": ["galore-torch"],
    "apollo": ["apollo-torch"],
...
@@ -104,7 +104,6 @@ class HuggingfaceEngine(BaseEngine):
                messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
            )
            paired_messages = messages + [{"role": "assistant", "content": ""}]
-            system = system or generating_args["default_system"]
            prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
            prompt_ids, _ = template.mm_plugin.process_token_ids(
                prompt_ids,
@@ -117,7 +116,7 @@ class HuggingfaceEngine(BaseEngine):
            )
            prompt_length = len(prompt_ids)
            inputs = torch.tensor([prompt_ids], device=model.device)
-            attention_mask = torch.ones_like(inputs, dtype=torch.bool)
+            attention_mask = torch.ones_like(inputs, dtype=torch.long)
            do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
            temperature: Optional[float] = input_kwargs.pop("temperature", None)
...