# HuatuoGPT-o1

## 论文

`HuatuoGPT-o1, Towards Medical Complex Reasoning with LLMs`

* https://arxiv.org/pdf/2412.18925

## 模型结构

该算法共有两种模型，分别是LLama3.1和Qwen2.5，两者都是decoder-only结构。

![alt text](readme_imgs/arch.png)

## 算法原理

通过对相应的模型进行微调获取HuatuoGPT-o1，主要包括两阶段，分别是学习复杂推理及加强复杂推理。

stage1: 通过基于策略的搜索构建复杂的推理轨迹，并由验证器的反馈（正确或错误）进行引导。首先，大语言模型（LLM）初始化一个思维链（CoT）。如果验证器拒绝了当前的 CoT，模型将通过应用从以下策略中采样的方法扩展 CoT：回溯、探索新路径、验证和修正，直到提供正确答案。

stage2: 在掌握复杂推理技能后，强化学习（RL）进一步优化这一能力。具体来说，验证器提供的稀疏奖励通过近端策略优化算法（PPO）引导模型进行自我改进。

![alt text](readme_imgs/alg.png)


## 环境配置

### Docker（方法一）
    
    docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10

    docker run --shm-size 50g --network=host --name=huatuo --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash

    pip install -r requirements.txt

    pip uninstall vllm

    pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

    pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

### Dockerfile（方法二）

    docker build -t <IMAGE_NAME>:<TAG> .

    docker run --shm-size 50g --network=host --name=huatuo --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash

    pip install -r requirements.txt

    pip uninstall vllm

    pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

    pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

### Anaconda (方法三)

1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装：
https://developer.hpccube.com/tool/

    DTK驱动：dtk24.04.3
    python：python3.10
    torch: 2.3.0

Tips：以上dtk驱动、python、torch等DCU相关工具版本需要严格一一对应

2、其它非特殊库参照requirements.txt安装

    pip install -r requirements.txt

    pip uninstall vllm

    pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

    pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl

## 数据集

Medical Verifiable Problems: [huggingface](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-verifiable-problem) | [SCNet高速下载通道](http://113.200.138.88:18080/aidatasets/freedomintelligence/medical-o1-verifiable-problem)

SFT Data in Stage 1: [huggingface](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) | [SCNet高速下载通道](http://113.200.138.88:18080/aidatasets/freedomintelligence/medical-o1-reasoning-SFT)

除此之外，还可从选择题构建可验证的问题

```bash
python construct_verifiable_medical_problems.py --data_path  data/demo_data.json --filter_data --model_name gpt-4o --api_key [your api key]
```

为SFT搜索复杂推理路径

```bash
python search_for_complex_reasoning_path.py --data_path  data/demo_data.json --efficient_search True  --max_search_attempts 1 --max_search_depth 2 --model_name gpt-4o --api_key [your api key]
```

## 训练

### SFT

```bash
accelerate launch --config_file ./configs/deepspeed_zero3.yaml \
    --num_processes 8  \
    --num_machines 1 \
    --machine_rank 0 \
    --deepspeed_multinode_launcher standard SFT_stage1.py \
    --model_path [本地路径/huggingface模型名称] \
    --data_path [本地路径/huggingface模型名称] 
```

注意：deepspeed代码需要进行简单的修改后才可使用，具体参考 https://github.com/microsoft/DeepSpeed/pull/5461/files

### RL

```bash
accelerate launch \
--num_processes 8 \
--num_machines 1 \
--machine_rank 0 \
--config_file ./configs/deepspeed_zero3.yaml \
--deepspeed_multinode_launcher standard RL_stage2.py \
--model_name_or_path [FreedomIntelligence/HuatuoGPT-o1-8B | 本地模型地址] \
--reward_model_path [FreedomIntelligence/medical_o1_verifier_3B | 本地模型地址] \
--value_model_path [meta-llama/Llama-3.2-3B-Instruct | 本地模型地址] \
--dataset_name  [FreedomIntelligence/medical-o1-verifiable-problem | 本地数据地址(json文件)] \
--response_length 1300 \
--temperature 0.5 \
--local_rollout_forward_batch_size 8 \
--num_ppo_epochs 3 \
--num_mini_batches 1 \
--total_episodes 20000 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--bf16 True \
--output_dir ./ckpts \
--save_strategy steps \
--save_step 20 \
--save_total_limit 1 \
--eval_strategy steps \
--eval_steps 20 \
--kl_coef 0.03 \
--learning_rate 5e-7 \
--warmup_ratio 0.05 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--run_name ppo_medical_o1_8B \
--num_sample_generations -1 \
--report_to wandb
```

## 推理

1、hf

```bash
python inferences/simple_inference.py
```

2、VLLM

```bash
python inferences/vllm_offline.py --model_path /path/to/weight
```

## result

![alt text](readme_imgs/result.png)

### 精度

与Nvidia GPU精度一致。

## 应用场景

### 算法类别

`对话问答`

### 热点应用行业

`医疗,电商,教育,广媒`

## 预训练权重

|                      | Backbone     | Supported Languages | Link                                                                  |
| -------------------- | ------------ | ----- | --------------------------------------------------------------------- |
| **HuatuoGPT-o1-8B**  | LLaMA-3.1-8B  | English    | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-8B) \| [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-8B) |
| **HuatuoGPT-o1-70B** | LLaMA-3.1-70B | English    | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-70B) \| [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-70B) |
| **HuatuoGPT-o1-7B**  | Qwen2.5-7B   | English & Chinese | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-7B) \| [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-7B) |
| **HuatuoGPT-o1-72B** | Qwen2.5-72B  | English & Chinese | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-72B) \| [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-72B) |
| **Medical O1 Verifier 3B**| - | - | [HF Link](https://huggingface.co/FreedomIntelligence/medical_o1_verifier_3B) \| [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/freedomintelligence/medical_o1_verifier_3B)|
| **Llama-3.2-3B-Instruct**| - | - | [HF Link](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)|

## 源码仓库及问题反馈

* https://developer.sourcefind.cn/codes/modelzoo/huatuogpt-o1

## 参考资料

* https://github.com/FreedomIntelligence/HuatuoGPT-o1
