# HuatuoGPT-o1 ## 论文 `HuatuoGPT-o1, Towards Medical Complex Reasoning with LLMs` * https://arxiv.org/pdf/2412.18925 ## 模型结构 该算法共有两种模型,分别是LLama3.1和Qwen2.5,两者都是decoder-only结构。 ![alt text](readme_imgs/arch.png) ## 算法原理 通过对相应的模型进行微调获取HuatuoGPT-o1,主要包括两阶段,分别是学习复杂推理及加强复杂推理。 stage1: 通过基于策略的搜索构建复杂的推理轨迹,并由验证器的反馈(正确或错误)进行引导。首先,大语言模型(LLM)初始化一个思维链(CoT)。如果验证器拒绝了当前的 CoT,模型将通过应用从以下策略中采样的方法扩展 CoT:回溯、探索新路径、验证和修正,直到提供正确答案。 stage2: 在掌握复杂推理技能后,强化学习(RL)进一步优化这一能力。具体来说,验证器提供的稀疏奖励通过近端策略优化算法(PPO)引导模型进行自我改进。 ![alt text](readme_imgs/alg.png) ## 环境配置 ### Docker(方法一) docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10 docker run --shm-size 50g --network=host --name=huatuo --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash pip install -r requirements.txt pip uninstall vllm pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl ### Dockerfile(方法二) docker build -t : . docker run --shm-size 50g --network=host --name=huatuo --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash pip install -r requirements.txt pip uninstall vllm pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl ### Anaconda (方法三) 1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/ DTK驱动:dtk24.04.3 python:python3.10 torch: 2.3.0 Tips:以上dtk驱动、python、torch等DCU相关工具版本需要严格一一对应 2、其它非特殊库参照requirements.txt安装 pip install -r requirements.txt pip uninstall vllm pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl ## 数据集 Medical Verifiable Problems: [huggingface](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-verifiable-problem) | [SCNet高速下载通道](http://113.200.138.88:18080/aidatasets/freedomintelligence/medical-o1-verifiable-problem) SFT Data in Stage 1: [huggingface](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) | [SCNet高速下载通道](http://113.200.138.88:18080/aidatasets/freedomintelligence/medical-o1-reasoning-SFT) 除此之外,还可从选择题构建可验证的问题 ```bash python construct_verifiable_medical_problems.py --data_path data/demo_data.json --filter_data --model_name gpt-4o --api_key [your api key] ``` 为SFT搜索复杂推理路径 ```bash python search_for_complex_reasoning_path.py --data_path data/demo_data.json --efficient_search True --max_search_attempts 1 --max_search_depth 2 --model_name gpt-4o --api_key [your api key] ``` ## 训练 ### SFT ```bash accelerate launch --config_file ./configs/deepspeed_zero3.yaml \ --num_processes 8 \ --num_machines 1 \ --machine_rank 0 \ --deepspeed_multinode_launcher standard SFT_stage1.py \ --model_path [本地路径/huggingface模型名称] \ --data_path [本地路径/huggingface模型名称] ``` 注意:deepspeed代码需要进行简单的修改后才可使用,具体参考 https://github.com/microsoft/DeepSpeed/pull/5461/files ### RL ```bash accelerate launch \ --num_processes 8 \ --num_machines 1 \ --machine_rank 0 \ --config_file ./configs/deepspeed_zero3.yaml \ --deepspeed_multinode_launcher standard RL_stage2.py \ --model_name_or_path [FreedomIntelligence/HuatuoGPT-o1-8B | 本地模型地址] \ --reward_model_path [FreedomIntelligence/medical_o1_verifier_3B | 本地模型地址] \ --value_model_path [meta-llama/Llama-3.2-3B-Instruct | 本地模型地址] \ --dataset_name [FreedomIntelligence/medical-o1-verifiable-problem | 本地数据地址(json文件)] \ --response_length 1300 \ --temperature 0.5 \ --local_rollout_forward_batch_size 8 \ --num_ppo_epochs 3 \ --num_mini_batches 1 \ --total_episodes 20000 \ --per_device_train_batch_size 1 \ --gradient_accumulation_steps 16 \ --bf16 True \ --output_dir ./ckpts \ --save_strategy steps \ --save_step 20 \ --save_total_limit 1 \ --eval_strategy steps \ --eval_steps 20 \ --kl_coef 0.03 \ --learning_rate 5e-7 \ --warmup_ratio 0.05 \ --gradient_checkpointing True \ --dataloader_num_workers 4 \ --run_name ppo_medical_o1_8B \ --num_sample_generations -1 \ --report_to wandb ``` ## 推理 1、hf ```bash python inferences/simple_inference.py ``` 2、VLLM ```bash python inferences/vllm_offline.py --model_path /path/to/weight ``` ## result ![alt text](readme_imgs/result.png) ### 精度 与Nvidia GPU精度一致。 ## 应用场景 ### 算法类别 `对话问答` ### 热点应用行业 `医疗,电商,教育,广媒` ## 预训练权重 | | Backbone | Supported Languages | Link | | -------------------- | ------------ | ----- | --------------------------------------------------------------------- | | **HuatuoGPT-o1-8B** | LLaMA-3.1-8B | English | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-8B) \| [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-8B) | | **HuatuoGPT-o1-70B** | LLaMA-3.1-70B | English | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-70B) \| [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-70B) | | **HuatuoGPT-o1-7B** | Qwen2.5-7B | English & Chinese | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-7B) \| [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-7B) | | **HuatuoGPT-o1-72B** | Qwen2.5-72B | English & Chinese | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-72B) \| [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-72B) | | **Medical O1 Verifier 3B**| - | - | [HF Link](https://huggingface.co/FreedomIntelligence/medical_o1_verifier_3B) \| [SCNet高速下载通道](http://113.200.138.88:18080/aimodels/freedomintelligence/medical_o1_verifier_3B)| | **Llama-3.2-3B-Instruct**| - | - | [HF Link](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)| ## 源码仓库及问题反馈 * https://developer.sourcefind.cn/codes/modelzoo/huatuogpt-o1 ## 参考资料 * https://github.com/FreedomIntelligence/HuatuoGPT-o1