# HuatuoGPT-o1

## Paper

`HuatuoGPT-o1, Towards Medical Complex Reasoning with LLMs`

* https://arxiv.org/pdf/2412.18925

## Model Architecture

HuatuoGPT-o1 is built on two backbone families, Llama 3.1 and Qwen2.5, both of which are decoder-only models.

![alt text](readme_imgs/arch.png)

## Algorithm

HuatuoGPT-o1 is obtained by fine-tuning the corresponding base model in two stages: learning complex reasoning, then enhancing it.

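Stage 1: Complex reasoning trajectories are built with a strategy-based search guided by verifier feedback (correct or incorrect). The large language model (LLM) first initializes a chain of thought (CoT). If the verifier rejects the current CoT, the model extends it by applying a strategy sampled from backtracking, exploring new paths, verification, and correction, and repeats until it produces a correct answer.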
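The Stage 1 loop can be sketched roughly as follows. This is a minimal illustration, not the repository's implementation: `init_cot`, `extend_cot`, and `verify` are hypothetical callables standing in for the LLM and the verifier, and `max_depth` only loosely mirrors the `--max_search_depth` flag of the search script.

```python
import random
from typing import Callable, List, Optional

# Strategy names follow the description above; the identifiers are illustrative.
STRATEGIES = ["backtracking", "exploring_new_paths", "verification", "correction"]


def search_reasoning_path(
    question: str,
    init_cot: Callable[[str], str],
    extend_cot: Callable[[str, List[str], str], str],
    verify: Callable[[str], bool],
    max_depth: int = 2,
    rng: Optional[random.Random] = None,
) -> Optional[List[str]]:
    """Return the CoT segments ending in an accepted answer, or None on failure."""
    rng = rng or random.Random()
    trajectory = [init_cot(question)]          # initial chain of thought
    if verify(trajectory[-1]):
        return trajectory
    for _ in range(max_depth):
        strategy = rng.choice(STRATEGIES)      # sample one search strategy
        trajectory.append(extend_cot(question, trajectory, strategy))
        if verify(trajectory[-1]):             # verifier accepts: stop searching
            return trajectory
    return None                                # search budget exhausted


# Toy stand-ins (hypothetical): the first attempt is wrong, any extension fixes it.
def toy_init(question: str) -> str:
    return "Answer: B"


def toy_extend(question: str, trajectory: List[str], strategy: str) -> str:
    return f"[{strategy}] Answer: A"


def toy_verify(cot: str) -> bool:
    return cot.endswith("Answer: A")


path = search_reasoning_path(
    "Which option?", toy_init, toy_extend, toy_verify, rng=random.Random(0)
)
```

With the toy stand-ins, `path` holds the rejected first attempt followed by one accepted extension; a trajectory that is never accepted yields `None` and would simply be dropped from the SFT data.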

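Stage 2: Once the model has acquired complex reasoning skills, reinforcement learning (RL) refines them further. Specifically, sparse rewards from the verifier guide the model's self-improvement through Proximal Policy Optimization (PPO).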
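To make "sparse reward" concrete, the sketch below assigns one scalar per completed response and subtracts a KL penalty against a reference model. The reward values are placeholders chosen for illustration and are not taken from this README; the default `kl_coef` matches the `--kl_coef 0.03` value used in the RL command, and the function names are hypothetical.

```python
from typing import Sequence


def sparse_reward(
    verifier_says_correct: bool,
    has_final_answer: bool,
    r_correct: float = 1.0,   # placeholder value (assumption)
    r_wrong: float = 0.1,     # placeholder value (assumption)
    r_missing: float = 0.0,   # placeholder value (assumption)
) -> float:
    """One scalar per response: the reward arrives only at the end of generation."""
    if not has_final_answer:
        return r_missing
    return r_correct if verifier_says_correct else r_wrong


def kl_penalized_reward(
    reward: float,
    policy_logprobs: Sequence[float],
    ref_logprobs: Sequence[float],
    kl_coef: float = 0.03,
) -> float:
    """Subtract a sample-based KL estimate (sum of per-token log-ratio) from the reward."""
    kl_estimate = sum(p - r for p, r in zip(policy_logprobs, ref_logprobs))
    return reward - kl_coef * kl_estimate


example = kl_penalized_reward(
    sparse_reward(verifier_says_correct=True, has_final_answer=True),
    policy_logprobs=[-1.0, -2.0],
    ref_logprobs=[-1.5, -2.0],
)
```

The KL term keeps the policy close to the reference model, so the sparse verifier signal does not push the policy into degenerate outputs.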

![alt text](readme_imgs/alg.png)


## Environment Setup

### Docker (Option 1)
    
    docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10

    docker run --shm-size 50g --network=host --name=huatuo --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash

    pip install -r requirements.txt

    pip uninstall vllm

    pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

    pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

### Dockerfile (Option 2)

    docker build -t <IMAGE_NAME>:<TAG> .

    docker run --shm-size 50g --network=host --name=huatuo --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash

    pip install -r requirements.txt

    pip uninstall vllm

    pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

    pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

### Anaconda (Option 3)

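Step 1: The DCU-specific deep learning libraries this project requires can be downloaded from the Guanghe (光合) developer community:
https://developer.hpccube.com/tool/

    DTK driver: dtk24.04.3
    Python: python3.10
    torch: 2.3.0

Tip: The versions of the DTK driver, Python, torch, and other DCU-related tools listed above must match each other exactly.

Step 2: Install the remaining general-purpose libraries from requirements.txt: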

    pip install -r requirements.txt

    pip uninstall vllm

    pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

    pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl
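Because the versions must match exactly, a quick sanity check after installation can catch a mismatched interpreter or torch build early. The helper below is a sketch, not part of the repository: `REQUIRED`, `version_matches`, and `check_environment` are hypothetical names, and the comparison assumes plain dotted version strings with an optional `+local` build tag.

```python
import sys
from typing import Dict

# Pinned versions copied from the list above.
REQUIRED = {"python": "3.10", "torch": "2.3.0"}


def version_matches(found: str, required: str) -> bool:
    """True when `found` begins with the required dotted components.

    A local build tag such as '+das.opt1' is ignored before comparing.
    """
    base = found.split("+", 1)[0]
    required_parts = required.split(".")
    return base.split(".")[: len(required_parts)] == required_parts


def current_python() -> str:
    """Return the running interpreter's version as 'major.minor'."""
    return f"{sys.version_info.major}.{sys.version_info.minor}"


def check_environment(found: Dict[str, str]) -> Dict[str, bool]:
    """Compare each found version against REQUIRED; missing entries fail."""
    return {
        name: version_matches(found.get(name, ""), required)
        for name, required in REQUIRED.items()
    }


# Typical use inside the prepared environment (torch must be importable there):
#   import torch
#   print(check_environment({"python": current_python(), "torch": torch.__version__}))
```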

## Datasets

Medical Verifiable Problems: [Hugging Face](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-verifiable-problem) | [SCNet fast download](http://113.200.138.88:18080/aidatasets/freedomintelligence/medical-o1-verifiable-problem)

SFT Data in Stage 1: [Hugging Face](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) | [SCNet fast download](http://113.200.138.88:18080/aidatasets/freedomintelligence/medical-o1-reasoning-SFT)

In addition, verifiable problems can be constructed from multiple-choice questions:

```bash
python construct_verifiable_medical_problems.py --data_path  data/demo_data.json --filter_data --model_name gpt-4o --api_key [your api key]
```

Search for complex reasoning paths to build the SFT data:

```bash
python search_for_complex_reasoning_path.py --data_path  data/demo_data.json --efficient_search True  --max_search_attempts 1 --max_search_depth 2 --model_name gpt-4o --api_key [your api key]
```

## Training

### SFT

```bash
accelerate launch --config_file ./configs/deepspeed_zero3.yaml \
    --num_processes 8  \
    --num_machines 1 \
    --machine_rank 0 \
    --deepspeed_multinode_launcher standard SFT_stage1.py \
    --model_path [local path or Hugging Face model name] \
    --data_path [local path or Hugging Face dataset name]
```

Note: DeepSpeed needs a small code change before it can be used here; see https://github.com/microsoft/DeepSpeed/pull/5461/files for details.

### RL

```bash
accelerate launch \
--num_processes 8 \
--num_machines 1 \
--machine_rank 0 \
--config_file ./configs/deepspeed_zero3.yaml \
--deepspeed_multinode_launcher standard RL_stage2.py \
--model_name_or_path [FreedomIntelligence/HuatuoGPT-o1-8B | local model path] \
--reward_model_path [FreedomIntelligence/medical_o1_verifier_3B | local model path] \
--value_model_path [meta-llama/Llama-3.2-3B-Instruct | local model path] \
--dataset_name [FreedomIntelligence/medical-o1-verifiable-problem | local data path (JSON file)] \
--response_length 1300 \
--temperature 0.5 \
--local_rollout_forward_batch_size 8 \
--num_ppo_epochs 3 \
--num_mini_batches 1 \
--total_episodes 20000 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--bf16 True \
--output_dir ./ckpts \
--save_strategy steps \
--save_step 20 \
--save_total_limit 1 \
--eval_strategy steps \
--eval_steps 20 \
--kl_coef 0.03 \
--learning_rate 5e-7 \
--warmup_ratio 0.05 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--run_name ppo_medical_o1_8B \
--num_sample_generations -1 \
--report_to wandb
```
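The batch-related flags in the RL command interact, so it helps to work out how many PPO updates the run performs. The arithmetic below assumes that `RL_stage2.py` follows TRL-style PPO accounting (global batch = per-device batch × gradient accumulation × mini-batches × processes), which this README does not state; `ppo_schedule` is a hypothetical helper.

```python
import math
from typing import Dict


def ppo_schedule(
    num_processes: int,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_mini_batches: int,
    total_episodes: int,
) -> Dict[str, int]:
    """Derive batch sizes and update count under TRL-style accounting (assumption)."""
    local_batch_size = (
        per_device_train_batch_size * gradient_accumulation_steps * num_mini_batches
    )
    global_batch_size = local_batch_size * num_processes
    num_updates = math.ceil(total_episodes / global_batch_size)
    return {
        "local_batch_size": local_batch_size,
        "global_batch_size": global_batch_size,
        "num_updates": num_updates,
    }


# Values copied from the command above.
schedule = ppo_schedule(
    num_processes=8,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_mini_batches=1,
    total_episodes=20000,
)
# schedule == {"local_batch_size": 16, "global_batch_size": 128, "num_updates": 157}
```

Under this assumption each update consumes 128 rollouts, so saving and evaluating every 20 steps produces a checkpoint roughly every 2,560 episodes.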

## Inference

1. Hugging Face Transformers

```bash
python inferences/simple_inference.py
```

2. vLLM

```bash
python inferences/vllm_offline.py --model_path /path/to/weight
```
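Whichever backend is used, the generated text usually needs to be split into the reasoning trace and the final answer. The sketch below assumes the model emits a `## Thinking` section followed by a `## Final Response` section; that output format is an assumption not stated in this README, and `split_response` is a hypothetical helper.

```python
import re
from typing import Tuple

# Assumed section markers; adjust them if the model's output format differs.
_SECTIONS = re.compile(r"## Thinking\s*(.*?)\s*## Final Response\s*(.*)", re.DOTALL)


def split_response(text: str) -> Tuple[str, str]:
    """Return (thinking, final_response); thinking is empty if markers are absent."""
    match = _SECTIONS.search(text)
    if match is None:
        return "", text.strip()
    return match.group(1).strip(), match.group(2).strip()


sample = "## Thinking\n\nFever plus cough suggests infection.\n\n## Final Response\n\nSee a doctor."
thinking, answer = split_response(sample)
```

Returning an empty reasoning string instead of raising keeps the helper usable on outputs from base models that do not follow the format.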

## Results

![alt text](readme_imgs/result.png)

### Accuracy

Accuracy on DCU is consistent with that on NVIDIA GPUs.

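## Application Scenarios

### Algorithm Category

`Conversational question answering`

### Key Application Industries

`Healthcare, e-commerce, education, broadcast media`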

## Pretrained Weights

| Model                | Backbone      | Supported Languages | Link |
| -------------------- | ------------- | ------------------- | ---- |
| **HuatuoGPT-o1-8B**  | LLaMA-3.1-8B  | English             | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-8B) \| [SCNet fast download](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-8B) |
| **HuatuoGPT-o1-70B** | LLaMA-3.1-70B | English             | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-70B) \| [SCNet fast download](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-70B) |
| **HuatuoGPT-o1-7B**  | Qwen2.5-7B    | English & Chinese   | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-7B) \| [SCNet fast download](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-7B) |
| **HuatuoGPT-o1-72B** | Qwen2.5-72B   | English & Chinese   | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-72B) \| [SCNet fast download](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-72B) |
| **Medical O1 Verifier 3B**| - | - | [HF Link](https://huggingface.co/FreedomIntelligence/medical_o1_verifier_3B) \| [SCNet fast download](http://113.200.138.88:18080/aimodels/freedomintelligence/medical_o1_verifier_3B)|
| **Llama-3.2-3B-Instruct**| - | - | [HF Link](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)|

## Source Repository and Issue Reporting

* https://developer.sourcefind.cn/codes/modelzoo/huatuogpt-o1

## References

* https://github.com/FreedomIntelligence/HuatuoGPT-o1