Commit 9bce6a82 authored by mashun1

huatuogpt-o1

FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10
# HuatuoGPT-o1
## Paper
`HuatuoGPT-o1, Towards Medical Complex Reasoning with LLMs`
* https://arxiv.org/pdf/2412.18925
## Model Architecture
HuatuoGPT-o1 is released on two backbone families, LLaMA-3.1 and Qwen2.5, both of which use a decoder-only architecture.
![alt text](readme_imgs/arch.png)
## Algorithm
HuatuoGPT-o1 is obtained by fine-tuning the corresponding base models in two stages: first learning complex reasoning, then reinforcing it.

Stage 1: Complex reasoning trajectories are constructed through strategy-based search, guided by verifier feedback (correct or incorrect). The LLM first initializes a chain of thought (CoT); whenever the verifier rejects the current CoT, the model extends it by applying a strategy sampled from backtracking, exploring new paths, verification, and correction, until a correct answer is produced. A minimal illustrative sketch of this loop is shown after the figure below.

Stage 2: Once the complex reasoning skill has been learned, reinforcement learning (RL) refines it further. Specifically, sparse rewards provided by the verifier guide the model's self-improvement via Proximal Policy Optimization (PPO).
![alt text](readme_imgs/alg.png)
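The Stage 1 loop can be summarized with the minimal sketch below. It only illustrates the idea and is not the repository's implementation (`search_for_complex_reasoning_path.py` contains the real logic); `llm_generate` and `verify` are hypothetical placeholders for the proposer LLM (e.g. GPT-4o) and the verifier.
```python
import random

# Hypothetical placeholders: in the real pipeline an LLM (e.g. GPT-4o) proposes
# and extends CoTs, and a verifier judges them against the ground-truth answer.
def llm_generate(prompt: str) -> str:
    ...  # call the LLM of your choice here

def verify(cot: str, ground_truth: str) -> bool:
    ...  # return True if the CoT ends with the correct answer

STRATEGIES = ["backtracking", "exploring new paths", "verification", "correction"]

def search_reasoning_path(question: str, ground_truth: str, max_depth: int = 2):
    """Extend a chain of thought with sampled strategies until the verifier accepts it."""
    cot = llm_generate(f"Think step by step and answer:\n{question}")
    for _ in range(max_depth):
        if verify(cot, ground_truth):
            return cot  # accepted trajectory: keep it as Stage 1 SFT data
        strategy = random.choice(STRATEGIES)
        cot = llm_generate(
            f"Question: {question}\nCurrent reasoning:\n{cot}\n"
            f"The current answer is not correct. Continue the reasoning by {strategy}."
        )
    return None  # no valid trajectory found within the search budget
```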
## Environment Setup
### Docker (Option 1)
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10
docker run --shm-size 50g --network=host --name=huatuo --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
pip uninstall vllm
pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl
pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl
### Dockerfile (Option 2)
docker build -t <IMAGE_NAME>:<TAG> .
docker run --shm-size 50g --network=host --name=huatuo --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded from the DCU developer community (光合开发者社区):
https://developer.hpccube.com/tool/
DTK driver: dtk24.04.3
python: python3.10
torch: 2.3.0
Tips: the DTK driver, python, and torch versions listed above must correspond to each other exactly.
2. The remaining, non-DCU-specific dependencies are installed from requirements.txt:
pip install -r requirements.txt
pip uninstall vllm
pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl
pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl
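After installing the dependencies with any of the three options above, a quick sanity check like the one below can confirm the DCU builds are importable; the expected version strings in the comments are assumptions based on the wheels listed above.
```python
import torch
import vllm

# Expect the DTK 24.04.3 builds: torch 2.3.0 and vllm 0.6.2+das.opt1.
print("torch:", torch.__version__)
print("vllm:", vllm.__version__)
# The DCU stack exposes devices through the CUDA/HIP-compatible API,
# so this should report True inside a correctly configured container.
print("device available:", torch.cuda.is_available())
```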
## Datasets
Medical Verifiable Problems: [huggingface](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-verifiable-problem) | [SCNet download]()
SFT Data in Stage 1: [huggingface](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) | [SCNet download]()
In addition, verifiable problems can be constructed from multiple-choice questions (a short loading sketch for the released datasets follows at the end of this section):
```bash
python construct_verifiable_medical_problems.py --data_path data/demo_data.json --filter_data --model_name gpt-4o --api_key [your api key]
```
Search complex reasoning paths for SFT:
```bash
python search_for_complex_reasoning_path.py --data_path data/demo_data.json --efficient_search True --max_search_attempts 1 --max_search_depth 2 --model_name gpt-4o --api_key [your api key]
```
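The released datasets can also be pulled directly from the Hugging Face Hub. The sketch below assumes the `datasets` library is installed and that the SFT dataset exposes an `en` configuration; adjust the config name if the Hub layout differs.
```python
from datasets import load_dataset

# Verifiable problems used in Stage 2 RL: fields include
# "Open-ended Verifiable Question" and "Ground-True Answer".
verifiable = load_dataset("FreedomIntelligence/medical-o1-verifiable-problem")

# Stage 1 SFT data with fields Question, Complex_CoT and Response;
# the "en" config name is an assumption about the Hub layout.
sft = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")

print(verifiable)
print(sft["train"][0].keys())
```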
## Training
### SFT
```bash
accelerate launch --config_file ./configs/deepspeed_zero3.yaml \
--num_processes 8 \
--num_machines 1 \
--machine_rank 0 \
--deepspeed_multinode_launcher standard SFT_stage1.py \
--model_path [local path | Hugging Face model name] \
--data_path [local path | Hugging Face dataset name]
```
Note: the DeepSpeed code needs a small modification before it can be used; see https://github.com/microsoft/DeepSpeed/pull/5461/files for details.
### RL
```bash
accelerate launch \
--num_processes 8 \
--num_machines 1 \
--machine_rank 0 \
--config_file ./configs/deepspeed_zero3.yaml \
--deepspeed_multinode_launcher standard RL_stage2.py \
--model_name_or_path [FreedomIntelligence/HuatuoGPT-o1-8B | local model path] \
--reward_model_path [FreedomIntelligence/medical_o1_verifier_3B | local model path] \
--value_model_path [meta-llama/Llama-3.2-3B-Instruct | local model path] \
--dataset_name [FreedomIntelligence/medical-o1-verifiable-problem | local dataset path] \
--response_length 1300 \
--temperature 0.5 \
--local_rollout_forward_batch_size 8 \
--num_ppo_epochs 3 \
--num_mini_batches 1 \
--total_episodes 20000 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--bf16 True \
--output_dir ./ckpts \
--save_strategy steps \
--save_steps 20 \
--save_total_limit 1 \
--eval_strategy steps \
--eval_steps 20 \
--kl_coef 0.03 \
--learning_rate 5e-7 \
--warmup_ratio 0.05 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--run_name ppo_medical_o1_8B \
--num_sample_generations -1 \
--report_to wandb
```
## Inference
1. Hugging Face Transformers
```bash
python inferences/simple_inference.py
```
2. vLLM
```bash
python inferences/vllm_offline.py --model_path /path/to/weight
```
## Results
![alt text](readme_imgs/result.png)
### Accuracy
Accuracy is consistent with results obtained on NVIDIA GPUs.
## Application Scenarios
### Algorithm Category
`Conversational QA`
### Key Application Industries
`Healthcare, e-commerce, education, broadcast media`
## Pretrained Weights
| | Backbone | Supported Languages | Link |
| -------------------- | ------------ | ----- | --------------------------------------------------------------------- |
| **HuatuoGPT-o1-8B**  | LLaMA-3.1-8B | English | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-8B) \| [SCNet download](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-8B) |
| **HuatuoGPT-o1-70B** | LLaMA-3.1-70B | English | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-70B) \| [SCNet download](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-70B) |
| **HuatuoGPT-o1-7B**  | Qwen2.5-7B | English & Chinese | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-7B) \| [SCNet download](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-7B) |
| **HuatuoGPT-o1-72B** | Qwen2.5-72B | English & Chinese | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-72B) \| [SCNet download](http://113.200.138.88:18080/aimodels/freedomintelligence/HuatuoGPT-o1-72B) |
## Source Repository & Issue Feedback
## References
* https://github.com/FreedomIntelligence/HuatuoGPT-o1
# HuatuoGPT-o1, Towards Medical Complex Reasoning with LLMs
<div align="center">
<h3>
HuatuoGPT-o1
</h3>
</div>
<p align="center">
📃 <a href="https://arxiv.org/pdf/2412.18925" target="_blank">Paper</a> |🤗 <a href="https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-7B" target="_blank">HuatuoGPT-o1-7B</a> |🤗 <a href="https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-8B" target="_blank">HuatuoGPT-o1-8B</a> | 🤗 <a href="https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-70B" target="_blank">HuatuoGPT-o1-70B</a> | 📚 <a href="https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT" target="_blank">Data</a>
</p>
## ⚡ Introduction
Hello! Welcome to the repository for [HuatuoGPT-o1](https://arxiv.org/pdf/2412.18925)!
<div align=center>
<img src="assets/pic1.jpg" width = "90%" alt="HuatuoGPT-o1" align=center/>
</div>
**HuatuoGPT-o1** is a medical LLM designed for advanced medical reasoning. It can identify mistakes, explore alternative strategies, and refine its answers. By leveraging verifiable medical problems and a specialized medical verifier, it advances reasoning through:
- Using the verifier to guide the search for a complex reasoning trajectory for fine-tuning LLMs.
- Applying reinforcement learning (PPO) with verifier-based rewards to enhance complex reasoning further.
We open-sourced our models, data, and code here.
## 👨‍⚕️ Model
- **Model Access**
| | Backbone | Supported Languages | Link |
| -------------------- | ------------ | ----- | --------------------------------------------------------------------- |
| **HuatuoGPT-o1-8B** | LLaMA-3.1-8B | English | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-8B) |
| **HuatuoGPT-o1-70B** | LLaMA-3.1-70B | English | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-70B) |
| **HuatuoGPT-o1-7B** | Qwen2.5-7B | English & Chinese | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-7B) |
| **HuatuoGPT-o1-72B** | Qwen2.5-72B | English & Chinese | [HF Link](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-72B) |
- **Deploy**
HuatuoGPT-o1 can be used just like `Llama-3.1-8B-Instruct`. You can deploy it with tools like [vllm](https://github.com/vllm-project/vllm) or [Sglang](https://github.com/sgl-project/sglang), or perform direct inference:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("FreedomIntelligence/HuatuoGPT-o1-8B",torch_dtype="auto",device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("FreedomIntelligence/HuatuoGPT-o1-8B")
input_text = "How to stop a cough?"
messages = [{"role": "user", "content": input_text}]
inputs = tokenizer(tokenizer.apply_chat_template(messages, tokenize=False,add_generation_prompt=True
), return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=2048)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
HuatuoGPT-o1 adopts a *thinks-before-it-answers* approach, with outputs formatted as:
```
## Thinking
[Reasoning process]
## Final Response
[Output]
```
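Because the two sections are delimited by fixed markdown headers, the reasoning and the final answer can be separated with plain string handling. The helper below is an illustrative sketch, not part of the repository.
```python
def split_thinking_and_response(decoded_output: str):
    """Split a HuatuoGPT-o1 generation into (reasoning, final answer)."""
    marker = "## Final Response"
    if marker in decoded_output:
        thinking, response = decoded_output.split(marker, 1)
        thinking = thinking.split("## Thinking", 1)[-1].strip()
        return thinking, response.strip()
    # Fallback: no marker found, treat the whole output as the answer.
    return "", decoded_output.strip()

# Example with the decoded text from the generation snippet above:
# thinking, answer = split_thinking_and_response(decoded_text)
```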
## 📚 Data
- **Data Access**
| Data | Description | Link |
| -------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------------------------------------------- |
| Medical Verifiable Problems | Open-ended medical problems sourced from challenging medical exams, paired with ground-truth answers. | [Link](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-verifiable-problem) |
| SFT Data in Stage 1 | Fine-tuning data generated using GPT-4o, including complex chains of thought (**Complex CoT**) and output (**Response**). | [Link](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) |
- **Data Construction**
We provide scripts to construct verifiable problems and to search for complex reasoning paths.
**1. Constructing Verifiable Problems from Multi-choice Questions.**
```bash
python construct_verifiable_medical_problems.py --data_path data/demo_data.json --filter_data --model_name gpt-4o --api_key [your api key]
```
**2. Searching Complex Reasoning Paths for SFT**
```bash
python search_for_complex_reasoning_path.py --data_path data/demo_data.json --efficient_search True --max_search_attempts 1 --max_search_depth 2 --model_name gpt-4o --api_key [your api key]
```
## 🚀 Training
- **Stage 1: Supervised Fine-Tuning (SFT)**
Fine-tune the model on an 8-GPU setup:
```bash
accelerate launch --config_file ./configs/deepspeed_zero3.yaml \
--num_processes 8 \
--num_machines 1 \
--machine_rank 0 \
--deepspeed_multinode_launcher standard SFT_stage1.py \
--model_path [meta-llama/Llama-3.1-8B-Instruct] \
--data_path [FreedomIntelligence/medical-o1-reasoning-SFT]
```
- **Stage 2: Reinforcement Learning (RL)**
We provide a simple PPO script using the [trl](https://github.com/huggingface/trl) library. Below is an example of training an 8B model with PPO on an 8-GPU A100 machine. Ensure you first download our [medical verifier](https://huggingface.co/FreedomIntelligence/medical_o1_verifier_3B) as the reward model; a hedged sketch of how such a verifier can score an answer is shown after the command.
```bash
accelerate launch \
--num_processes 8 \
--num_machines 1 \
--machine_rank 0 \
--config_file ./configs/deepspeed_zero3.yaml \
--deepspeed_multinode_launcher standard RL_stage2.py \
--model_name_or_path [FreedomIntelligence/HuatuoGPT-o1-8B] \
--reward_model_path [FreedomIntelligence/medical_o1_verifier_3B] \
--value_model_path [meta-llama/Llama-3.2-3B-Instruct] \
--dataset_name [FreedomIntelligence/medical-o1-verifiable-problem] \
--response_length 1300 \
--temperature 0.5 \
--local_rollout_forward_batch_size 8 \
--num_ppo_epochs 3 \
--num_mini_batches 1 \
--total_episodes 20000 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--bf16 True \
--output_dir ./ckpts \
--save_strategy steps \
--save_steps 20 \
--save_total_limit 1 \
--eval_strategy steps \
--eval_steps 20 \
--kl_coef 0.03 \
--learning_rate 5e-7 \
--warmup_ratio 0.05 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--run_name ppo_medical_o1_8B \
--num_sample_generations -1 \
--report_to wandb
```
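For intuition, here is a hedged sketch of how a two-label sequence-classification verifier (it is loaded with `num_labels=2` in `RL_stage2.py`) can turn a generated answer into the sparse reward described above. The input template and the meaning of the label indices are assumptions; the actual formatting lives inside the custom PPO trainer (`ppo_utils/ppo_trainer_medo1.py`).
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

verifier_name = "FreedomIntelligence/medical_o1_verifier_3B"
verifier_tokenizer = AutoTokenizer.from_pretrained(verifier_name)
verifier = AutoModelForSequenceClassification.from_pretrained(verifier_name, num_labels=2)

def sparse_reward(question: str, model_answer: str, ground_truth: str) -> float:
    # Assumed input format; the real trainer may concatenate these fields differently.
    text = (f"Question: {question}\nModel answer: {model_answer}\n"
            f"Reference answer: {ground_truth}")
    inputs = verifier_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = verifier(**inputs).logits  # shape (1, 2)
    # Assume index 1 means "correct": reward 1.0 if it wins, else 0.0.
    return float(logits.argmax(dim=-1).item() == 1)
```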
## 🧐 Evaluation
1. You first need to install [Sglang](https://github.com/sgl-project/sglang). After installation, deploy the model you want to test using Sglang with the following command:
```bash
log_num=0
model_name="FreedomIntelligence/HuatuoGPT-o1-8B" # Path to the model you are deploying
port=28${log_num}35
CUDA_VISIBLE_DEVICES=0 python -m sglang.launch_server --model-path $model_name --port $port --mem-fraction-static 0.8 --dp 1 --tp 1 > sglang${log_num}.log 2>&1 &
```
2. Wait for the model to be deployed. After deployment, you can run the following code for evaluation. We use prompts that allow the model to respond freely. We find that the extracted results are consistently reliable and broadly cover the intended scope. You can also set the `--strict_prompt` option to use stricter prompts for more precise answer extraction.
```bash
python evaluation/eval.py --model_name $model_name --eval_file evaluation/data/eval_data.json --port $port
```
3. After completing the evaluation, run the following code to stop the Sglang service and release GPU memory.
```bash
bash evaluation/kill_sglang_server.sh
```
The evaluation code above can be used to test most models supported by Sglang.
## 🩺 HuatuoGPT Series
Explore our HuatuoGPT series:
- [**HuatuoGPT**](https://github.com/FreedomIntelligence/HuatuoGPT): Taming Language Models to Be a Doctor
- [**HuatuoGPT-II**](https://github.com/FreedomIntelligence/HuatuoGPT-II): One-stage Training for Medical Adaptation of LLMs
- [**HuatuoGPT-Vision**](https://github.com/FreedomIntelligence/HuatuoGPT-Vision): Injecting Medical Visual Knowledge into Multimodal LLMs at Scale
- [**CoD (Chain-of-Diagnosis)**](https://github.com/FreedomIntelligence/Chain-of-Diagnosis): Towards an Interpretable Medical Agent using Chain of Diagnosis
- [**HuatuoGPT-o1**](https://github.com/FreedomIntelligence/HuatuoGPT-o1): Towards Medical Complex Reasoning with LLMs
## 📖 Citation
```
@misc{chen2024huatuogpto1medicalcomplexreasoning,
title={HuatuoGPT-o1, Towards Medical Complex Reasoning with LLMs},
author={Junying Chen and Zhenyang Cai and Ke Ji and Xidong Wang and Wanlong Liu and Rongsheng Wang and Jianye Hou and Benyou Wang},
year={2024},
eprint={2412.18925},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2412.18925},
}
```
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=FreedomIntelligence/HuatuoGPT-o1&type=Date)](https://star-history.com/#FreedomIntelligence/HuatuoGPT-o1&Date)
import os
import warnings
from dataclasses import dataclass
import wandb
import torch
from datasets import load_dataset,load_from_disk
from transformers import AutoModelForSequenceClassification, AutoTokenizer,PreTrainedTokenizerBase
import json,random
from trl import (
ModelConfig,
ScriptArguments
)
from ppo_utils.ppo_config_medo1 import PPOConfig
from ppo_utils.ppo_trainer_medo1 import PPOTrainer
os.environ["WANDB_MODE"] = "offline"
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser
)
class ppo_dataset(torch.utils.data.Dataset):
def __init__(self, data, tokenizer, max_length = 1000,debug = 0):
self.tokenizer = tokenizer
self.data = data
self.max_length = max_length
newdata = []
for da in self.data:
if len(da['Open-ended Verifiable Question']) > 0 and len(da['Ground-True Answer']) > 0:
newdata.append({'question':da['Open-ended Verifiable Question'],'answer':da['Ground-True Answer']})
print(len(self.data),' -> ',len(newdata))
self.data = newdata
self.debug = debug
def __getitem__(self, index):
return self.data[index]
def get_prompt(self,da):
message = [{"role": "user", "content": da['question']}]
prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
input_token = self.tokenizer(
prompt,
padding=False,
truncation=False,
add_special_tokens=False,
)
da['input_ids'] = input_token["input_ids"]
return da
def collate_fn(self, batch):
data = [ self.get_prompt(da) for da in batch]
input_ids = [item["input_ids"] for item in data]
question = [item["question"] for item in data]
answer = [item["answer"] for item in data]
max_len = max(len(x) for x in input_ids)
max_len = min(max_len,self.max_length)
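# Left-pad each prompt with pad tokens so every prompt in the batch ends at the same position before generation.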
input_ids = [ [self.tokenizer.pad_token_id]*(max_len-len(item)) + item[:max_len] for item in input_ids]
if self.debug > 0:
print('[input_ids]',self.tokenizer.decode(input_ids[-1]))
print('[question]',question[-1])
print('[answer]',answer[-1])
self.debug -= 1
return {
"input_ids": torch.LongTensor(input_ids),
"question": question,
"answer": answer
}
def __len__(self):
return len(self.data)
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_into_dataclasses()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
output_dir = training_args.output_dir
run_name = training_args.run_name
if run_name not in output_dir:
output_dir = os.path.join(output_dir,run_name)
training_args.output_dir = output_dir
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
reward_model = AutoModelForSequenceClassification.from_pretrained(
training_args.reward_model_path, attn_implementation="flash_attention_2",num_labels=2
)
value_model = AutoModelForSequenceClassification.from_pretrained(
training_args.value_model_path, trust_remote_code=model_config.trust_remote_code, attn_implementation="flash_attention_2",num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path,attn_implementation="flash_attention_2")
policy = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path,attn_implementation="flash_attention_2")
reward_tokenizer = AutoTokenizer.from_pretrained(training_args.reward_model_path)
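# Llama-3-style tokenizers ship no pad token; reuse <|end_of_text|> for padding so it stays distinct from the EOS token used to stop generation.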
if '<|eot_id|>' in tokenizer.vocab:
assert '<|end_of_text|>' in tokenizer.vocab
tokenizer.pad_token = '<|end_of_text|>'
tokenizer.pad_token_id = tokenizer.encode('<|end_of_text|>',add_special_tokens=False)[0]
assert tokenizer.pad_token_id != tokenizer.eos_token_id
training_args.stop_token_id = tokenizer.eos_token_id
eval_ratio = 0.1
eval_max_num = 200
with open(script_args.dataset_name) as f:
data = json.load(f)
random.shuffle(data)
eval_num = min(int(len(data) * eval_ratio),eval_max_num)
train_dataset = ppo_dataset(data[eval_num:],tokenizer, debug = 1)
eval_dataset = ppo_dataset(data[:eval_num],tokenizer)
trainer = PPOTrainer(
config=training_args,
processing_class=tokenizer,
reward_processing_class = reward_tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
value_model=value_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator = train_dataset.collate_fn
)
trainer.train()
import os
import json
import torch
import logging
import argparse
from tqdm import tqdm
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
import wandb
from accelerate import Accelerator
from transformers import set_seed, get_cosine_schedule_with_warmup
import shutil
import json
import traceback
from jinja2 import Template
from transformers import AutoModelForCausalLM, AutoTokenizer
os.umask(0)
logger = logging.getLogger(__name__)
logging.basicConfig(level='INFO')
class Train_dataset(torch.utils.data.Dataset):
def __init__(self, config, tokenizer):
self.config = config
self.tokenizer = tokenizer
with open(config.data_path) as f:
self.data = json.load(f)
newdata = []
for da in self.data:
newdata.append(da)
print('Filtered:', len(self.data), '->', len(newdata))
self.data = newdata
self.max_seq_len = self.config.max_seq_len
self.debug = 0
# If training from a base LLM (no chat template), use the llama3-instruct template
chat_template_llama3 = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
if not tokenizer.chat_template:
tokenizer.chat_template = chat_template_llama3
self.template = Template(tokenizer.chat_template)
def __getitem__(self, index):
return self.data[index]
def get_response(self,da):
temp = '## Thinking\n\n{}\n\n## Final Response\n\n{}'
return temp.format(da['Complex_CoT'],da['Response'])
def get_prompt(self,da):
q = da['Question']
a = self.get_response(da)
assert q is not None and a is not None, f'q:{q} a:{a}'
input = self.template.render(messages=[{"role": "user", "content": q},{"role": "assistant", "content": a}],bos_token=self.tokenizer.bos_token,add_generation_prompt=False)
input_ids = self.tokenizer.encode(input,add_special_tokens= False)
query = self.template.render(messages=[{"role": "user", "content": q}],bos_token=self.tokenizer.bos_token,add_generation_prompt=True)
query_ids = self.tokenizer.encode(query,add_special_tokens= False)
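# Mask the prompt (query) tokens with -100 so the loss is computed only on the assistant response.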
labels = [-100]*len(query_ids) + input_ids[len(query_ids):]
assert len(labels) == len(input_ids)
return {"input_ids": input_ids[-self.max_seq_len:], "labels": labels[-self.max_seq_len:]}
def collate_fn(self, batch):
data = [ self.get_prompt(da) for da in batch]
input_ids = [item["input_ids"] for item in data]
labels = [item["labels"] for item in data]
max_len = max(len(x) for x in input_ids)
max_len = min(max_len,self.max_seq_len)
input_ids = [ item[:max_len] + [self.tokenizer.eos_token_id]*(max_len-len(item)) for item in input_ids]
labels = [ item[:max_len] + [-100]*(max_len-len(item)) for item in labels]
if self.debug < 3:
print('input_ids',self.tokenizer.decode(input_ids[-1]))
print('labels',self.tokenizer.decode([0 if x == -100 else x for x in labels[-1]]))
self.debug += 1
return {
"input_ids": torch.LongTensor(input_ids),
"labels": torch.LongTensor(labels),
}
def __len__(self):
return len(self.data)
class SFTMetric:
def __init__(self, device):
self.n_step = 0
self.right = torch.Tensor([0]).to(device=device)
self.total = torch.Tensor([0]).to(device=device)
self.total_loss = torch.Tensor([0]).to(device=device)
self.world_size = dist.get_world_size()
def __call__(self, logits, labels, loss):
return self.update(logits, labels, loss)
def update(self, logits, labels, loss):
self.n_step += 1
with torch.no_grad():
shift_preds = logits[..., :-1, :].argmax(dim=-1)
shift_labels = labels[..., 1:]
self.right += (shift_preds == shift_labels).masked_fill(shift_labels.eq(-100), 0).sum().item()
self.total += (shift_labels != -100).sum().item()
self.total_loss += loss.item()
def get_metric(self, reset=True):
dist.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM)
dist.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM)
dist.all_reduce(self.total_loss, op=torch.distributed.ReduceOp.SUM)
acc = (self.right / self.total).item()
loss = self.total_loss.item() / (self.world_size * self.n_step)
if reset:
self.n_step = 0
self.right.fill_(0)
self.total.fill_(0)
self.total_loss.fill_(0)
return acc, loss
def train(args):
accelerator = Accelerator(mixed_precision='bf16', gradient_accumulation_steps=args.gradient_accumulation_steps)
if accelerator.is_main_process:
wandb.init(project = args.experiment_name, config=args, dir=args.log_dir, mode="offline")
accelerator.print(f'args:\n{args}')
accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_bsz_per_gpu
accelerator.state.deepspeed_plugin.deepspeed_config['train_batch_size'] = args.train_bsz_per_gpu*dist.get_world_size()*accelerator.gradient_accumulation_steps
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
# open gradient checkpointing
model.gradient_checkpointing_enable()
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
train_dataset = Train_dataset(args, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_bsz_per_gpu, shuffle=True, drop_last=True, collate_fn=train_dataset.collate_fn)
num_training_steps = int(len(train_dataloader) * (args.n_epochs)) // accelerator.gradient_accumulation_steps // dist.get_world_size()
lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_rates * num_training_steps), num_training_steps=num_training_steps)
accelerator.print(f'gradient_accumulation_steps:{accelerator.gradient_accumulation_steps} data_path:{args.data_path} lr:{args.learning_rate} num_training_steps:{num_training_steps}')
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
start_epoch = 0
start_step = 0
global_step = 0
metric = SFTMetric(device=torch.cuda.current_device())
def save_checkpoint(epoch, step, global_step):
save_dir = os.path.join(args.output_dir, f"checkpoint-{epoch}-{global_step}")
if accelerator.is_main_process:
checkpoint_files = os.listdir(args.output_dir)
checkpoint_files = [file for file in checkpoint_files if file.startswith("checkpoint-")]
num_checkpoints = len(checkpoint_files)
if args.max_ckpts>0:
if num_checkpoints >= args.max_ckpts:
checkpoint_files.sort(key=lambda x: os.path.getctime(os.path.join(args.output_dir, x)))
oldest_checkpoint = checkpoint_files[0]
shutil.rmtree(os.path.join(args.output_dir, oldest_checkpoint))
os.makedirs(save_dir, exist_ok=True)
output_dir = os.path.join(save_dir, 'tfmr')
if accelerator.state.deepspeed_plugin.zero_stage!=3:
model.save_pretrained(output_dir,state_dict=accelerator.get_state_dict(model))
tokenizer.save_pretrained(output_dir)
copy_files = []
for item in os.listdir(args.model_path):
if os.path.exists(os.path.join(output_dir,item)):
continue
if item.startswith("pytorch_model") and item.endswith(".bin"):
continue
if item.endswith(".index.json") or item.endswith(".safetensors"):
continue
s = os.path.join(args.model_path, item)
if os.path.isfile(s):
shutil.copy(s, os.path.join(output_dir,item))
copy_files.append(item)
print(f'huggingface model save in {output_dir}, copy file:{copy_files}')
if accelerator.state.deepspeed_plugin.zero_stage==3:
unwrap_model = accelerator.unwrap_model(model)
unwrap_model.save_pretrained(os.path.join(save_dir, f'tfmr'),is_main_process=accelerator.is_main_process,save_function=accelerator.save,state_dict=accelerator.get_state_dict(model))
accelerator.wait_for_everyone()
accelerator.save({"epoch": epoch, "step": step, "global_step": global_step}, os.path.join(save_dir, "training_state.pt"))
accelerator.print(f'checkpoint checkpoint-{epoch}-{global_step} is saved...')
accelerator.print(accelerator.deepspeed_config)
model.train()
for epoch in range(start_epoch, args.n_epochs):
train_dataloader_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader)) if accelerator.is_main_process else enumerate(train_dataloader)
for batch_cnt, batch in train_dataloader_iterator:
if epoch==start_epoch and batch_cnt<start_step:
continue
if batch_cnt == 1 and epoch == 0:
torch.cuda.empty_cache()
input_ids=batch['input_ids']
labels=batch['labels']
output = model(input_ids=input_ids, labels=labels, return_dict=True,use_cache=False)
loss = output.loss
metric(output.logits, labels, loss)
acc, train_loss = metric.get_metric()
accelerator.backward(loss)
if (global_step+1) % accelerator.gradient_accumulation_steps == 0:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
global_step += 1
if accelerator.is_main_process:
train_dataloader_iterator.set_postfix(epoch=epoch, current_step=batch_cnt, total_step=len(train_dataloader), skip=accelerator.optimizer_step_was_skipped, loss=round(train_loss, 3), acc=round(acc, 3), length=len(input_ids[0]), lr=lr_scheduler.get_last_lr()[0])
if global_step % 3 == 0 and accelerator.is_main_process:
wandb.log({
'skip': int(accelerator.optimizer_step_was_skipped),
'loss': train_loss,
'acc': acc,
'lr': lr_scheduler.get_last_lr()[0]
}, step=global_step)
accelerator.wait_for_everyone()
save_checkpoint(epoch, batch_cnt, global_step)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Args of sft')
# Experiment Args
parser.add_argument('--experiment_name', type=str,default='sft_stage1')
# Model Args
parser.add_argument('--model_path', required=True, type=str)
# Data Args
parser.add_argument('--data_path', required=True, type=str)
# Training Args
parser.add_argument('--output_dir', default='./ckpts', type=str)
parser.add_argument('--max_ckpts', default=2, type=int)
parser.add_argument('--log_dir', default='./train_logs', type=str)
parser.add_argument('--max_seq_len', default=8192, type=int)
parser.add_argument('--gradient_checkpointing', action='store_true')
parser.add_argument('--gradient_accumulation_steps', default=8, type=int)
parser.add_argument('--train_bsz_per_gpu', default=2, type=int)
parser.add_argument('--weight_decay', default=0.1, type=float)
parser.add_argument('--learning_rate', default=5e-6, type=float)
parser.add_argument('--warmup_rates', default=0.05, type=float)
parser.add_argument('--n_epochs', default=3, type=int)
# Other Args
parser.add_argument('--seed', default=42, type=int)
args = parser.parse_args()
args.log_dir = os.path.join(args.log_dir,args.experiment_name)
args.output_dir = os.path.join(args.output_dir,args.experiment_name)
os.makedirs(args.log_dir, exist_ok=True)
os.makedirs(args.output_dir, exist_ok=True)
set_seed(args.seed)
train(args)
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
import os
import random
import json
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from retrying import retry
import argparse
import traceback
import re
import requests
class GPT:
def __init__(self, model_name, api_url, api_key):
self.model_name = model_name
self.api_url = api_url
self.api_key = api_key
print(f"Using model: {self.model_name}")
def call(self, content, additional_args={}):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
payload = {
"model": self.model_name,
"messages": [{'role': 'user', 'content': content}],
**additional_args,
}
response = requests.post(self.api_url, headers=headers, json=payload)
response_data = response.json()
if 'error' in response_data:
raise ValueError(f"API Error: {response_data}")
return response_data['choices'][0]['message']['content']
@retry(wait_fixed=3000, stop_max_attempt_number=3)
def retry_call(self, content, additional_args={"max_tokens": 8192}):
return self.call(content, additional_args)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str, required=True, help="Path to the input JSON data file.")
parser.add_argument("--filter_data", action='store_true', help="Enable filtering of questions with LLMs.")
parser.add_argument("--model_name", type=str, default="gpt-4", help="Name of the GPT model to use.")
parser.add_argument("--api_key", type=str, required=True, help="OpenAI API key.")
parser.add_argument("--api_url", type=str, default="https://api.openai.com/v1/chat/completions", help="OpenAI API URL.")
parser.add_argument("--num_process", type=int, default=10, help="Number of parallel processes.")
parser.add_argument("--limit_num", type=int, help="Limit the number of processed items.")
return parser.parse_args()
def extract_bracket_content(text):
# Extract content between the first '{' and the last '}'
match = re.search(r'\{.*\}', text, re.DOTALL)
return match.group(0) if match else None
def parse_gpt_response(response):
try:
if not response.startswith('{'):
response = extract_bracket_content(response)
parsed_data = json.loads(response.replace('\n', ''))
assert len(parsed_data) == 2, "Response JSON should contain exactly two keys."
assert isinstance(parsed_data["Open-ended Verifiable Question"], str), "Open-ended Question must be a string."
assert isinstance(parsed_data["Ground-True Answer"], str), "Ground-True Answer must be a string."
return True, parsed_data
except Exception as e:
print(f"Error parsing GPT response: {e}")
return False, None
def process_single_item(item, gpt_instance, save_directory, filter_prompt, reformat_prompt, filter_enabled):
try:
max_retries = 2
save_path = os.path.join(save_directory, f"{item['process_id']}.json")
# Generate options string for the question
item['options_str'] = '\n'.join([f"{key}. {value}" for key, value in item['options'].items()])
question_text = f"{item['question']}\n{item['options_str']}"
# Filter questions if enabled
if filter_enabled:
filter_query = filter_prompt.format(question_text, item['answer'])
item['gpt_filter_query'] = filter_query
response = gpt_instance.retry_call(filter_query)
item['gpt_filter_response'] = response
if 'pass' not in response.lower():
with open(save_path, 'w', encoding='utf-8') as file:
json.dump(item, file, ensure_ascii=False, indent=2)
return 1
# Reformat questions into open-ended format
reformat_query = reformat_prompt.format(question_text, item['answer'])
item['gpt_reformat_query'] = reformat_query
for _ in range(max_retries):
response = gpt_instance.retry_call(reformat_query)
item['gpt_reformat_response'] = response
valid, parsed_data = parse_gpt_response(response)
if valid:
item["Open-ended Verifiable Question"] = parsed_data["Open-ended Verifiable Question"]
item["Ground-True Answer"] = parsed_data["Ground-True Answer"]
break
with open(save_path, 'w', encoding='utf-8') as file:
json.dump(item, file, ensure_ascii=False, indent=2)
except Exception as e:
print(f"Error processing item {item['process_id']}: {e}")
return 1
def merge_saved_files(directory):
_, _, filenames = next(os.walk(directory))
json_files = [f for f in filenames if f.endswith('.json')]
merged_data = []
for file in json_files:
try:
with open(os.path.join(directory, file), 'r', encoding='utf-8') as f:
data = json.load(f)
assert 'Open-ended Verifiable Question' in data or 'gpt_filter_response' in data or 'gpt4_response_filter' in data
merged_data.append(data)
except Exception as e:
# traceback.print_exc()
print(f"Error merging file {file}: {e}")
return merged_data
def deduplicate_data(data, processed_data):
processed_ids = {item['process_id'] for item in processed_data}
return [item for item in data if item['process_id'] not in processed_ids]
def main():
args = parse_arguments()
# Load input data
with open(args.data_path, 'r') as file:
input_data = json.load(file)
# Assign unique process IDs to each item
for idx, item in enumerate(input_data, start=1):
item['process_id'] = idx
if args.limit_num:
input_data = input_data[:args.limit_num]
print(f"Loaded {len(input_data)} items.")
# Define task and save directory
task_name = os.path.splitext(os.path.basename(args.data_path))[0]
save_directory = os.path.join('output_data', task_name)
os.makedirs(save_directory, exist_ok=True)
gpt_instance = GPT(model_name=args.model_name, api_url=args.api_url, api_key=args.api_key)
filter_prompt = """<Multiple-choice Question>
{}
Correct Answer: {}
</Multiple-choice Question>
You are an expert in filtering and evaluating multiple-choice questions for advanced reasoning tasks. Your job is to evaluate a given question and determine whether it meets the following criteria:
1. **Depth of Reasoning:** The question should require deeper reasoning. If the question appears too simple, mark it as "Too Simple".
2. **Unambiguous Correct Answer:** The question must have a unique and unambiguous correct answer. If the question asks for "incorrect options" or allows for multiple correct answers, mark it as "Ambiguous Answer".
3. **Open-Ended Reformulation Feasibility:** The question should be suitable for reformatting into an open-ended format. If the question cannot be easily reformulated into an open-ended problem and a clear ground-truth answer, mark it as "Not Reformulatable".
For each question, provide one of the following evaluations:
- "Pass" (The question meets all the criteria.)
- "Too Simple"
- "Ambiguous Answer"
- "Not Reformulatable" """
reformat_prompt = """I will provide you with a multiple-choice question, and your task is to rewrite it into an open-ended question, along with a Ground-True Answer. The requirements are:
1. The question must be specific, targeting the point being tested in the original multiple-choice question. Ensure it is open-ended, meaning no options are provided, but there must be a definitive Ground-True Answer.
2. Based on the correct answer from the original question, provide a concise Ground-True Answer. The answer should allow for precise matching to determine whether the model's response is correct.
Here is the multiple-choice question for you to rewrite:
<Multiple-choice Question>
{}
Correct Answer: {}
</Multiple-choice Question>
Please output the result in the following JSON format:
```json
{{
"Open-ended Verifiable Question": "...",
"Ground-True Answer": "..."
}}
```"""
# Merge previously processed files
processed_data = merge_saved_files(save_directory)
print(f"Previously processed items: {len(processed_data)}")
input_data = deduplicate_data(input_data, processed_data)
print(f"Items remaining for processing: {len(input_data)}")
# Process data using a thread pool
with ThreadPoolExecutor(max_workers=args.num_process) as executor:
list(tqdm(executor.map(lambda item: process_single_item(item, gpt_instance, save_directory, filter_prompt, reformat_prompt, args.filter_data), input_data), total=len(input_data), desc="Processing Items", unit="item"))
# Merge and save final output
final_data = merge_saved_files(save_directory)
output_path = f"{task_name}_final_{len(final_data)}.json"
print(f"Processed {len(final_data)} items. Saving to {output_path}")
with open(output_path, 'w', encoding='utf-8') as file:
json.dump(final_data, file, ensure_ascii=False, indent=2)
if __name__ == '__main__':
main()
import argparse
from tqdm import tqdm
import openai
from jinja2 import Template
import os
import json
from transformers import AutoTokenizer
from scorer import get_results
def postprocess_output(pred):
pred = pred.replace("</s>", "")
if len(pred) > 0 and pred[0] == " ":
pred = pred[1:]
return pred
def load_file(input_fp):
with open(input_fp, 'r') as f:
data = json.load(f)
input_data = []
if isinstance(data, list):
data = {'normal': data}
for k,v in data.items():
for da in v:
da['source'] = k
input_data.extend(v)
return input_data
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str,
default="meta-llama/Llama-2-7b-chat-hf")
parser.add_argument('--eval_file', type=str, required=True)
parser.add_argument('--max_new_tokens', type=int, default=2000)
parser.add_argument('--max_tokens', type=int, default=-1)
parser.add_argument('--use_chat_template',type=bool, default=True)
parser.add_argument('--strict_prompt', action="store_true")
parser.add_argument('--task', type=str,default='api')
parser.add_argument('--port', type=int, default=30000)
parser.add_argument('--batch_size', type=int, default=1024)
args = parser.parse_args()
print(f"Using local API server at port {args.port}")
client = openai.Client(
base_url=f"http://127.0.0.1:{args.port}/v1", api_key="EMPTY")
if args.use_chat_template:
tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True, padding_side='left')
template = Template(tokenizer.chat_template)
def call_model(prompts, model, max_new_tokens=50, print_example =False):
temperature = 0.5
if print_example:
print("Example:")
print(prompts[1])
preds = []
if args.use_chat_template:
prompts = [template.render(messages=[{"role": "user", "content": prom}],bos_token= tokenizer.bos_token,add_generation_prompt=True) for prom in prompts]
if args.max_tokens > 0:
new_prompts = []
for prompt in prompts:
input_ids = tokenizer.encode(prompt,add_special_tokens= False)
if len(input_ids) > args.max_tokens:
input_ids = input_ids[:args.max_tokens]
new_prompts.append(tokenizer.decode(input_ids))
else:
new_prompts.append(prompt[-args.max_tokens:])
prompts = new_prompts
response = client.completions.create(
model="default",
prompt=prompts,
temperature=temperature, top_p=0.9, max_tokens=max_new_tokens
)
preds = [x.text for x in response.choices]
postprocessed_preds = [postprocess_output(pred) for pred in preds]
return postprocessed_preds, preds
input_data = load_file(args.eval_file)
model = None
final_results = []
if args.strict_prompt:
query_prompt = "Please answer the following multiple-choice questions. Please answer the following multiple-choice questions, ensuring your response concludes with the correct option in the format: 'The answer is A.'.\n{question}\n{option_str}"
else:
query_prompt = "Please answer the following multiple-choice question:\n{question}\n{option_str}"
for idx in tqdm(range(len(input_data) // args.batch_size + 1)):
batch = input_data[idx*args.batch_size:(idx+1)*args.batch_size]
if len(batch) == 0:
break
for item in batch:
item['option_str'] = '\n'.join([ f'{op}. {ans}' for op,ans in item['options'].items()])
item["input_str"] = query_prompt.format_map(item)
processed_batch = [ item["input_str"] for item in batch]
if idx == 0:
print_example = True
else:
print_example = False
preds, _ = call_model(
processed_batch, model=model, max_new_tokens=args.max_new_tokens, print_example=print_example)
for j, item in enumerate(batch):
pred = preds[j]
if len(pred) == 0:
continue
item["output"] = pred
final_results.append(item)
task_name = os.path.split(args.model_name)[-1]
task_name = task_name + os.path.basename(args.eval_file).replace('.json','') + f'_{args.task}' + ('_strict-prompt' if args.strict_prompt else '')
save_path = f'{task_name}.json'
with open(save_path,'w') as fw:
json.dump(final_results,fw,ensure_ascii=False,indent=2)
# get results
get_results(save_path)
if __name__ == "__main__":
main()
pkill -f sglang
pkill -f multiprocessing.spawn
#%%
from collections import defaultdict
import re
import json
import difflib
import os
def str_similarity(str1, str2):
seq = difflib.SequenceMatcher(None, str1, str2)
return seq.ratio()
def find_most_similar_index(str_list, target_str):
"""
Given a list of strings and a target string, returns the index of the most similar string in the list.
"""
# Initialize variables to keep track of the most similar string and its index
most_similar_str = None
most_similar_index = None
highest_similarity = 0
# Iterate through each string in the list
for i, candidate in enumerate(str_list):
# Calculate the similarity between the current string and the target string
similarity = str_similarity(candidate, target_str)
# If the current string is more similar than the previous most similar string, update the variables
if similarity >= highest_similarity:
most_similar_str = candidate
most_similar_index = i
highest_similarity = similarity
return most_similar_index
def match_choice(text,options):
# For HuatuoGPT-o1
if '## Final Response\n\n' in text:
text = text.split('## Final Response\n\n')[-1]
# for strict prompt
matches = list(re.finditer(r"(answer is\s*?)([A-N])", text, re.S))
if matches:
ans_first = matches[0].group(2)
ans_last = matches[-1].group(2)
return [ans_first,ans_last],1
# non strict
match_options = 'ABCDEFGHIJKLMN'[:len(options)]
matches = list(re.finditer(r"([\u4e00-\u9fff]|is |是|项|\*|\W|\ |\(|为|^|'|\"|#)(?![aA] )(["+match_options+r"])(\W|[\u4e00-\u9fff]|$)", text, re.S))
if matches:
ans_first = matches[0].group(2)
ans_last = matches[-1].group(2)
return [ans_first,ans_last],1
text = text.lower()
opsindex = [(opt,text.rindex(options[opt].lower())) for opt in options if options[opt].lower() in text]
if len(opsindex) > 0:
ans_last = sorted(opsindex,key=lambda x:x[1],reverse=True)[0][0]
opsindex = [(opt,text.index(options[opt].lower())) for opt in options if options[opt].lower() in text]
ans_first = sorted(opsindex,key=lambda x:x[1],reverse=True)[0][0]
return [ans_first,ans_last],2
else:
oplabels = [x for x in options]
opans = [options[x].lower() for x in options]
ansindex = find_most_similar_index(opans,text.lower())
return [oplabels[ansindex],oplabels[ansindex]],3
def match(prediction, ground_truth):
for gt in ground_truth:
matchres = re.search(r"(\W|^)("+re.escape(gt)+r")(\W|$)",prediction.lower(),re.S)
if matchres:
return 1
return 0
def score(data,ignore_miss= False):
res = {}
wrong_data = []
cor_data = []
for da in data:
if 'source' not in da:
da['source'] = 'unknown'
if da['source'] not in res:
res[da['source']] = [0,0,0,0]
output = da['output']
ans,ans_type = match_choice(output,da['options'])
if ignore_miss and ans_type!= 1:
continue
da['ans'] = ans
da['ans_type'] = ans_type
if ans[0].lower() == da['answer_idx'].lower():
res[da['source']][1] += 1
cor_data.append(da)
else:
wrong_data.append(da)
if ans[1].lower() == da['answer_idx'].lower():
res[da['source']][3] += 1
res[da['source']][2] += 1
for k in res:
head_match_score = res[k][1] / res[k][2]
tail_match_score = res[k][3] / res[k][2]
if head_match_score > tail_match_score:
res[k][0] = head_match_score
else:
res[k][0] = tail_match_score
return res,wrong_data,cor_data
def get_results(res_path):
with open(res_path) as f:
data = json.load(f)
res,wrong_data,cor_data = score(data)
print(f"*{os.path.basename(res_path)}*")
print(json.dumps(res,indent=4))
# save results
with open('result_' + os.path.basename(res_path),'w') as fw:
json.dump(res,fw,ensure_ascii=False,indent=2)
# if __name__ == "__main__":
# get_results('output_file_path')
from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "/home/modelzoo/HuatuoGPT-o1/weights/HuatuoGPT-o1-7B-Qwen"
model = AutoModelForCausalLM.from_pretrained(model_path,torch_dtype="auto",device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
input_text = "孩子咳嗽老不好怎么办?"
messages = [{"role": "user", "content": input_text}]
inputs = tokenizer(tokenizer.apply_chat_template(messages, tokenize=False,add_generation_prompt=True
), return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=2048)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
from vllm import LLM, SamplingParams
def inference(model_path):
messages = [
{"role": "user", "content": "孩子咳嗽老不好怎么办?"}
]
sampling_params = SamplingParams(temperature=0.1,
top_p=0.95,
max_tokens=512)
llm = LLM(model=model_path)
outputs = llm.chat(messages, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--model_path", type=str)
args = parser.parse_args()
inference(args.model_path)
# Unique model identifier
modelCode=1223
# Model name
modelName=HuatuoGPT-o1_pytorch
# Model description
modelDescription=Medical-domain large language model
# Application scenarios
appScenario=training,inference,conversational QA,healthcare,e-commerce,education,media
# Framework type
frameType=Pytorch
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from typing import Optional
from trl.trainer.utils import OnPolicyConfig
@dataclass
class PPOConfig(OnPolicyConfig):
r"""
Configuration class for the [`PPOTrainer`].
Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
Name of this experiment.
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
Path to the reward model.
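value_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
Path to the value model.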
model_adapter_name (`Optional[str]`, *optional*, defaults to `None`):
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
ref_adapter_name (`Optional[str]`, *optional*, defaults to `None`):
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
num_ppo_epochs (`int`, *optional*, defaults to `4`):
Number of epochs to train.
whiten_rewards (`bool`, *optional*, defaults to `False`):
Whether to whiten the rewards.
kl_coef (`float`, *optional*, defaults to `0.05`):
KL coefficient.
cliprange (`float`, *optional*, defaults to `0.2`):
Clip range.
vf_coef (`float`, *optional*, defaults to `0.1`):
Value function coefficient.
cliprange_value (`float`, *optional*, defaults to `0.2`):
Clip range for the value function.
gamma (`float`, *optional*, defaults to `1.0`):
Discount factor.
lam (`float`, *optional*, defaults to `0.95`):
Lambda value for GAE.
"""
exp_name: str = os.path.basename(__file__)[: -len(".py")]
reward_model_path: str = "EleutherAI/pythia-160m"
value_model_path: str = "EleutherAI/pythia-160m"
model_adapter_name: Optional[str] = None
ref_adapter_name: Optional[str] = None
num_ppo_epochs: int = 4
whiten_rewards: bool = False
kl_coef: float = 0.05
cliprange: float = 0.2
vf_coef: float = 0.1
cliprange_value: float = 0.2
gamma: float = 1.0
lam: float = 0.95