"...composable_kernel.git" did not exist on "e823d518cb46ad61ddb3c70eac8529e0a58af1f8"
Commit b6edc328 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2229 canceled with stages
data filter=lfs diff=lfs merge=lfs -text
MIT License
Copyright (c) 2025 Xiaoxi Li
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---
license: other
license_name: qwen
license_link: https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/LICENSE
language:
- en
pipeline_tag: text-generation
base_model: Qwen/Qwen2.5-72B
tags:
- chat
---
# Qwen2.5-72B-Instruct
## Introduction
Qwen2.5 is the latest series of Qwen large language models. For Qwen2.5, we release a number of base language models and instruction-tuned language models ranging from 0.5 to 72 billion parameters. Qwen2.5 brings the following improvements upon Qwen2:
- Significantly **more knowledge** and greatly improved capabilities in **coding** and **mathematics**, thanks to our specialized expert models in these domains.
- Significant improvements in **instruction following**, **generating long texts** (over 8K tokens), **understanding structured data** (e.g., tables), and **generating structured outputs**, especially JSON. **More resilient to the diversity of system prompts**, enhancing role-play implementation and condition-setting for chatbots.
- **Long-context support** up to 128K tokens, with generation of up to 8K tokens.
- **Multilingual support** for over 29 languages, including Chinese, English, French, Spanish, Portuguese, German, Italian, Russian, Japanese, Korean, Vietnamese, Thai, Arabic, and more.
**This repo contains the instruction-tuned 72B Qwen2.5 model**, which has the following features:
- Type: Causal Language Models
- Training Stage: Pretraining & Post-training
- Architecture: transformers with RoPE, SwiGLU, RMSNorm, and Attention QKV bias
- Number of Parameters: 72.7B
- Number of Parameters (Non-Embedding): 70.0B
- Number of Layers: 80
- Number of Attention Heads (GQA): 64 for Q and 8 for KV
- Context Length: Full 131,072 tokens and generation 8192 tokens
- Please refer to [this section](#processing-long-texts) for detailed instructions on how to deploy Qwen2.5 for handling long texts.
For more details, please refer to our [blog](https://qwenlm.github.io/blog/qwen2.5/), [GitHub](https://github.com/QwenLM/Qwen2.5), and [Documentation](https://qwen.readthedocs.io/en/latest/).
## Requirements
The code for Qwen2.5 is included in the latest Hugging Face `transformers`, and we advise you to use the latest version of `transformers`.
With `transformers<4.37.0`, you will encounter the following error:
```
KeyError: 'qwen2'
```
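If you are unsure which version is installed, a quick check like the following helps; upgrade with `pip install -U transformers` if it reports anything older than 4.37.0:
```python
# Check that the installed transformers version knows the "qwen2" architecture.
import transformers

print(transformers.__version__)  # should be >= 4.37.0
```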
## Quickstart
The following code snippet shows how to use `apply_chat_template` to load the tokenizer and model and to generate content.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "Qwen/Qwen2.5-72B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt = "Give me a short introduction to large language model."
messages = [
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
generated_ids = model.generate(
**model_inputs,
max_new_tokens=512
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
```
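Printing `response` then shows the generated answer. If you prefer to see tokens as they are produced, `transformers` also offers a `TextStreamer` that can be passed to `generate` (an optional addition to the snippet above):
```python
from transformers import TextStreamer

# Optional: stream the completion to stdout as it is generated.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
_ = model.generate(**model_inputs, max_new_tokens=512, streamer=streamer)
```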
### Processing Long Texts
The current `config.json` is set for context length up to 32,768 tokens.
To handle extensive inputs exceeding 32,768 tokens, we utilize [YaRN](https://arxiv.org/abs/2309.00071), a technique for enhancing model length extrapolation, ensuring optimal performance on lengthy texts.
For supported frameworks, you could add the following to `config.json` to enable YaRN:
```json
{
...,
"rope_scaling": {
"factor": 4.0,
"original_max_position_embeddings": 32768,
"type": "yarn"
}
}
```
For deployment, we recommend using vLLM.
Please refer to our [Documentation](https://qwen.readthedocs.io/en/latest/deployment/vllm.html) for usage if you are not familiar with vLLM.
Presently, vLLM only supports static YaRN, which means the scaling factor remains constant regardless of input length, **potentially impacting performance on shorter texts**.
We advise adding the `rope_scaling` configuration only when processing long contexts is required.
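If you prefer to apply this change programmatically, a minimal sketch such as the following patches `config.json` in a locally downloaded checkpoint directory (the path is a placeholder; only add the entry when long contexts are actually needed):
```python
import json
from pathlib import Path

# Placeholder path to a locally downloaded checkpoint directory.
config_path = Path("Qwen2.5-72B-Instruct") / "config.json"
config = json.loads(config_path.read_text(encoding="utf-8"))

# YaRN settings from the JSON snippet above.
config["rope_scaling"] = {
    "factor": 4.0,
    "original_max_position_embeddings": 32768,
    "type": "yarn",
}
config_path.write_text(json.dumps(config, indent=2, ensure_ascii=False), encoding="utf-8")
```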
## Evaluation & Performance
Detailed evaluation results are reported in this [📑 blog](https://qwenlm.github.io/blog/qwen2.5/).
For requirements on GPU memory and the respective throughput, see results [here](https://qwen.readthedocs.io/en/latest/benchmark/speed_benchmark.html).
## Citation
If you find our work helpful, feel free to cite us.
```
@misc{qwen2.5,
title = {Qwen2.5: A Party of Foundation Models},
url = {https://qwenlm.github.io/blog/qwen2.5/},
author = {Qwen Team},
month = {September},
year = {2024}
}
@article{qwen2,
title={Qwen2 Technical Report},
author={An Yang and Baosong Yang and Binyuan Hui and Bo Zheng and Bowen Yu and Chang Zhou and Chengpeng Li and Chengyuan Li and Dayiheng Liu and Fei Huang and Guanting Dong and Haoran Wei and Huan Lin and Jialong Tang and Jialin Wang and Jian Yang and Jianhong Tu and Jianwei Zhang and Jianxin Ma and Jin Xu and Jingren Zhou and Jinze Bai and Jinzheng He and Junyang Lin and Kai Dang and Keming Lu and Keqin Chen and Kexin Yang and Mei Li and Mingfeng Xue and Na Ni and Pei Zhang and Peng Wang and Ru Peng and Rui Men and Ruize Gao and Runji Lin and Shijie Wang and Shuai Bai and Sinan Tan and Tianhang Zhu and Tianhao Li and Tianyu Liu and Wenbin Ge and Xiaodong Deng and Xiaohuan Zhou and Xingzhang Ren and Xinyu Zhang and Xipin Wei and Xuancheng Ren and Yang Fan and Yang Yao and Yichang Zhang and Yu Wan and Yunfei Chu and Yuqiong Liu and Zeyu Cui and Zhenru Zhang and Zhihao Fan},
journal={arXiv preprint arXiv:2407.10671},
year={2024}
}
```
# Search-o1
Search-o1 dynamically acquires and integrates external knowledge, endowing open-source models with CoT "slow thinking" ability without any additional training; it is a reasoning-oriented counterpart of o1.
## Paper
`Search-o1: Agentic Search-Enhanced Large Reasoning Models`
- https://arxiv.org/abs/2501.05366
## Model Architecture
This project uses Qwen2.5 as the example model in its experiments. Its architecture is similar to the Llama series: a minimalist decoder-only design derived from the basic Transformer, whose main body consists of attention (scaled dot products over Q, K, V) and an FFN (fully connected layers), with a final softmax that converts the outputs into probabilities. To normalize the data distribution and help training converge, an RMSNorm layer is added before the attention, the FFN, and the softmax.
<div align=center>
<img src="./doc/llama3.png"/>
</div>
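For illustration only, a minimal PyTorch sketch of such a pre-norm decoder block (RMSNorm, attention, FFN) might look as follows; it is not the actual Qwen2.5/Llama implementation, the dimensions are hypothetical, and RoPE, GQA, the SwiGLU gate, and KV caching are omitted for brevity:
```python
# Illustrative only: RMSNorm -> self-attention -> residual, RMSNorm -> FFN -> residual.
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each token vector by the inverse of its root-mean-square.
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class DecoderBlock(nn.Module):
    def __init__(self, dim: int = 512, n_heads: int = 8, hidden: int = 2048):
        super().__init__()
        self.attn_norm = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.ffn_norm = RMSNorm(dim)
        # Plain MLP for brevity; the real models use a SwiGLU gate here.
        self.ffn = nn.Sequential(nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Causal mask: position i may only attend to positions <= i.
        seq_len = x.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.attn_norm(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        x = x + attn_out
        return x + self.ffn(self.ffn_norm(x))


# Usage: stack such blocks, then add a final RMSNorm and an output projection + softmax.
block = DecoderBlock()
tokens = torch.randn(2, 16, 512)   # (batch, sequence, hidden)
print(block(tokens).shape)         # torch.Size([2, 16, 512])
```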
## Algorithm
By integrating an agentic retrieval-augmented generation mechanism with a Reason-in-Documents module, the framework can dynamically acquire and integrate external knowledge during inference while keeping the reasoning process coherent and logically consistent.
<div align=center>
<img src="./doc/algorithm.png"/>
</div>
## Environment Setup
```
mv Search-o1_pytorch Search-o1 # drop the framework-name suffix
```
### Docker (Option 1)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-py3.10-dtk24.04.3-ubuntu20.04
# Replace <your IMAGE ID> with the ID of the Docker image pulled above; for this image it is b272aae8ec72
docker run -it --shm-size=64G -v $PWD/Search-o1:/home/Search-o1 -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name searcho1 <your IMAGE ID> bash
cd /home/Search-o1
pip install -r requirements.txt
pip install whl/lmslim-0.1.2+das.dtk24043-cp310-cp310-linux_x86_64.whl # install lmslim==0.1.2
pip install whl/vllm-0.6.2+das.opt1.cd549d3.dtk24043-cp310-cp310-linux_x86_64.whl # install vllm==0.6.2
```
### Dockerfile (Option 2)
```
cd /home/Search-o1/docker
docker build --no-cache -t searcho1:latest .
docker run --shm-size=64G --name searcho1 -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video -v $PWD/../../Search-o1:/home/Search-o1 -it searcho1 bash
# If installing the environment through the Dockerfile takes a long time, comment out the pip installation inside it and install the Python libraries after the container starts: pip install -r requirements.txt
cd /home/Search-o1
pip install whl/lmslim-0.1.2+das.dtk24043-cp310-cp310-linux_x86_64.whl # install lmslim==0.1.2
pip install whl/vllm-0.6.2+das.opt1.cd549d3.dtk24043-cp310-cp310-linux_x86_64.whl # install vllm==0.6.2
```
### Anaconda (Option 3)
1. The DCU-specific deep-learning libraries required by this project can be downloaded and installed from the Guanghe Developer Community:
- https://developer.hpccube.com/tool/
```
DTK driver: dtk24.04.3
python:python3.10
torch:2.3.0
torchvision:0.18.1
torchaudio:2.1.2
triton:2.1.0
vllm:0.6.2
flash-attn:2.6.1
deepspeed:0.14.2
apex:1.3.0
xformers:0.0.25
transformers:4.48.0
```
`Tip: the versions of the DTK driver, python, torch, and the other DCU-related tools above must correspond to one another exactly.`
2. Install the other, non-special libraries according to requirements.txt:
```
cd /home/Search-o1
pip install -r requirements.txt
pip install whl/lmslim-0.1.2+das.dtk24043-cp310-cp310-linux_x86_64.whl # install lmslim==0.1.2
pip install whl/vllm-0.6.2+das.opt1.cd549d3.dtk24043-cp310-cp310-linux_x86_64.whl # install vllm==0.6.2
```
## Datasets
The project ships experimental mini datasets, `AIME` and `GPQA` (the questions in both are very difficult), that can be used directly. For preprocessing, refer to [`data_pre_process.ipynb`](./data/data_pre_process.ipynb); the datasets are converted into `*.json` format.
The full directory structure of the data is as follows:
```
/home/Search-o1/data
├── AIME
├── test.json
...
├── GPQA
├── diamond.json
...
```
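For a quick check of the expected format, the processed file can be inspected with a few lines of Python (a minimal sketch; the fields follow the sample record shown in the Results section below):
```python
import json

# Inspect the processed AIME split (path relative to /home/Search-o1).
with open("data/AIME/test.json", encoding="utf-8") as f:
    records = json.load(f)

print(len(records), "records")
print(sorted(records[0].keys()))  # ['Problem_ID', 'Question', 'Solution', 'answer', 'id']
```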
## Inference
### Single Node, Multiple GPUs
**Search-o1 (Ours)**
```bash
python scripts/run_search_o1.py \
--dataset_name aime \
--split test \
--max_search_limit 5 \
--max_turn 10 \
--top_k 10 \
--max_doc_len 3000 \
--use_jina True \
--model_path "YOUR_MODEL_PATH" \
--jina_api_key "YOUR_JINA_API_KEY" \
--bing_subscription_key "YOUR_BING_SUBSCRIPTION_KEY"
```
The command above is for reference only and cannot be run as-is: using the full `Search-o1` functionality requires purchasing API keys for the Jina and Bing search services. You may purchase these services yourself if needed; the extra cost is unrelated to this algorithm. This step does not include usable API keys for external search and only demonstrates the usage, so web-search functionality is unavailable:
```
sh searcho1_gen.sh # The project uses Qwen2.5-72B-Instruct as the example; small models are not recommended, as they perform worse than large ones.
# Note: the project's default bing_subscription_key is an invalid api_key, so no web data can be fetched; for reference only.
```
For more information, refer to the original project's [`README_origin`](./README_origin.md).
## Results
Since no valid API key is available, no web data can be fetched; the inference results below are reference examples to help readers understand how to use the project:
`Input:`
[`test.json`](./data/AIME/test.json)
```
{
"id": 0,
"Problem_ID": 60,
"Question": "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop.",
"Solution": "$\\frac{9}{s} + t = 4$ in hours and $\\frac{9}{s+2} + t = 2.4$ in hours.\nSubtracting the second equation from the first, we get, \n$\\frac{9}{s} - \\frac{9}{s+2} = 1.6$\nMultiplying by $(s)(s+2)$, we get \n$9s+18-9s=18=1.6s^{2} + 3.2s$\nMultiplying by 5/2 on both sides, we get\n$0 = 4s^{2} + 8s - 45$\nFactoring gives us \n$(2s-5)(2s+9) = 0$, of which the solution we want is $s=2.5$.\nSubstituting this back to the first equation, we can find that $t = 0.4$ hours.\nLastly, $s + \\frac{1}{2} = 3$ kilometers per hour, so\n$\\frac{9}{3} + 0.4 = 3.4$ hours, or $\\framebox{204}$ minutes\n-Failure.net\nThe amount of hours spent while walking on the first travel is $\\frac{240-t}{6}$. Thus, we have the equation $(240-t)(s) = 540$, and by the same logic, the second equation yields $(144-t)(s+2) = 540$. We have $240s-st = 540$, and $288+144s-2t-st = 540$. We subtract the two equations to get $96s+2t-288 = 0$, so we have $48s+t = 144$, so $t = 144-48s$, and now we have $(96+48s)(s) = 540$. The numerator of $s$ must evenly divide 540, however, $s$ must be less than 3. We can guess that $s = 2.5$. Now, $2.5+0.5 = 3$. Taking $\\frac{9}{3} = 3$, we find that it will take three hours for the 9 kilometers to be traveled. The t minutes spent at the coffeeshop can be written as $144-48(2.5)$, so t = 24. $180 + 24 = 204$. -sepehr2010",
"answer": "204"
},
...
```
See `outputs/runs.baselines/aime.qwen2.5-72b.search_o1/test.*.json`
`Output:`
```
{
"id": 0,
"Problem_ID": 60,
"Question": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nYou are a reasoning assistant with the ability to perform web searches to help you answer the user's question accurately. You have special tools:\n\n- To perform a search: write <|begin_search_query|> your query here <|end_search_query|>.\nThen, the system will search and analyze relevant web pages, then provide you with helpful information in the format <|begin_search_result|> ...search results... <|end_search_result|>.\n\nYou can repeat the search process multiple times if necessary. The maximum number of search attempts is limited to 5.\n\nOnce you have all the information you need, continue your reasoning.\n\nExample:\nQuestion: \"How do you compute the integral of e^(x^2) dx?\"\nAssistant thinking steps:\n- I might need to look up techniques for integrating e^(x^2).\n\nAssistant:\n<|begin_search_query|>methods to integrate e^(x^2)<|end_search_query|>\n\n(System returns processed information from relevant web pages)\n\nAssistant continues reasoning with the new information...\n\nRemember:\n- Use <|begin_search_query|> to request a web search and end with <|end_search_query|>.\n- When done searching, continue your reasoning.\n\nPlease answer the following math question. You should think step by step to solve it.\n\nProvide your final answer in the format \\boxed{YOUR_ANSWER}.\n\nQuestion:\nEvery morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop.\n\n<|im_end|>\n<|im_start|>assistant\n",
"Solution": "$\\frac{9}{s} + t = 4$ in hours and $\\frac{9}{s+2} + t = 2.4$ in hours.\nSubtracting the second equation from the first, we get, \n$\\frac{9}{s} - \\frac{9}{s+2} = 1.6$\nMultiplying by $(s)(s+2)$, we get \n$9s+18-9s=18=1.6s^{2} + 3.2s$\nMultiplying by 5/2 on both sides, we get\n$0 = 4s^{2} + 8s - 45$\nFactoring gives us \n$(2s-5)(2s+9) = 0$, of which the solution we want is $s=2.5$.\nSubstituting this back to the first equation, we can find that $t = 0.4$ hours.\nLastly, $s + \\frac{1}{2} = 3$ kilometers per hour, so\n$\\frac{9}{3} + 0.4 = 3.4$ hours, or $\\framebox{204}$ minutes\n-Failure.net\nThe amount of hours spent while walking on the first travel is $\\frac{240-t}{6}$. Thus, we have the equation $(240-t)(s) = 540$, and by the same logic, the second equation yields $(144-t)(s+2) = 540$. We have $240s-st = 540$, and $288+144s-2t-st = 540$. We subtract the two equations to get $96s+2t-288 = 0$, so we have $48s+t = 144$, so $t = 144-48s$, and now we have $(96+48s)(s) = 540$. The numerator of $s$ must evenly divide 540, however, $s$ must be less than 3. We can guess that $s = 2.5$. Now, $2.5+0.5 = 3$. Taking $\\frac{9}{3} = 3$, we find that it will take three hours for the 9 kilometers to be traveled. The t minutes spent at the coffeeshop can be written as $144-48(2.5)$, so t = 24. $180 + 24 = 204$. -sepehr2010",
"answer": "204",
"Output": "To solve the problem, we need to determine the value of \\( t \\) and then use it to find the total time Aya spends walking and in the coffee shop when she walks at \\( s + \\frac{1}{2} \\) kilometers per hour.\n\nFirst, let's set up the equations based on the information given:\n\n1. When Aya walks at a speed of \\( s \\) kilometers per hour, the walk takes her 4 hours, including \\( t \\) minutes spent in the coffee shop.\n2. When Aya walks at a speed of \\( s + 2 \\) kilometers per hour, the walk takes her 2 hours and 24 minutes, including \\( t \\) minutes spent in the coffee shop.\n\n### Step 1: Convert all times to hours\n- 4 hours is already in hours.\n- 2 hours and 24 minutes can be converted to hours: \\( 2 + \\frac{24}{60} = 2.4 \\) hours.\n\n### Step 2: Set up the equations\nLet \\( T_1 \\) be the time spent walking at speed \\( s \\) and \\( T_2 \\) be the time spent walking at speed \\( s + 2 \\).\n\nFrom the first condition:\n\\[ T_1 + \\frac{t}{60} = 4 \\]\nSince \\( T_1 = \\frac{9}{s} \\):\n\\[ \\frac{9}{s} + \\frac{t}{60} = 4 \\]\n\nFrom the second condition:\n\\[ T_2 + \\frac{t}{60} = 2.4 \\]\nSince \\( T_2 = \\frac{9}{s + 2} \\):\n\\[ \\frac{9}{s + 2} + \\frac{t}{60} = 2.4 \\]\n\n### Step 3: Solve the system of equations\nWe have:\n\\[ \\frac{9}{s} + \\frac{t}{60} = 4 \\]\n\\[ \\frac{9}{s + 2} + \\frac{t}{60} = 2.4 \\]\n\nSubtract the second equation from the first:\n\\[ \\left( \\frac{9}{s} + \\frac{t}{60} \\right) - \\left( \\frac{9}{s + 2} + \\frac{t}{60} \\right) = 4 - 2.4 \\]\n\\[ \\frac{9}{s} - \\frac{9}{s + 2} = 1.6 \\]\n\n### Step 4: Simplify the equation\n\\[ 9 \\left( \\frac{1}{s} - \\frac{1}{s + 2} \\right) = 1.6 \\]\n\\[ 9 \\left( \\frac{(s + 2) - s}{s(s + 2)} \\right) = 1.6 \\]\n\\[ 9 \\left( \\frac{2}{s(s + 2)} \\right) = 1.6 \\]\n\\[ \\frac{18}{s(s + 2)} = 1.6 \\]\n\\[ 18 = 1.6s(s + 2) \\]\n\\[ 18 = 1.6s^2 + 3.2s \\]\n\\[ 1.6s^2 + 3.2s - 18 = 0 \\]\n\n### Step 5: Solve the quadratic equation\nDivide the entire equation by 1.6:\n\\[ s^2 + 2s - 11.25 = 0 \\]\n\nUse the quadratic formula \\( s = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a} \\):\n\\[ s = \\frac{-2 \\pm \\sqrt{2^2 - 4 \\cdot 1 \\cdot (-11.25)}}{2 \\cdot 1} \\]\n\\[ s = \\frac{-2 \\pm \\sqrt{4 + 45}}{2} \\]\n\\[ s = \\frac{-2 \\pm \\sqrt{49}}{2} \\]\n\\[ s = \\frac{-2 \\pm 7}{2} \\]\n\nThis gives us two solutions:\n\\[ s = \\frac{5}{2} = 2.5 \\]\n\\[ s = \\frac{-9}{2} = -4.5 \\]\n\nSince speed cannot be negative, we have:\n\\[ s = 2.5 \\]\n\n### Step 6: Find \\( t \\)\nSubstitute \\( s = 2.5 \\) back into one of the original equations:\n\\[ \\frac{9}{2.5} + \\frac{t}{60} = 4 \\]\n\\[ 3.6 + \\frac{t}{60} = 4 \\]\n\\[ \\frac{t}{60} = 0.4 \\]\n\\[ t = 0.4 \\times 60 \\]\n\\[ t = 24 \\]\n\n### Step 7: Calculate the total time when \\( s = 2.5 + 0.5 = 3 \\)\n\\[ T = \\frac{9}{3} + \\frac{24}{60} \\]\n\\[ T = 3 + 0.4 \\]\n\\[ T = 3.4 \\]\n\nConvert 3.4 hours to minutes:\n\\[ 3.4 \\times 60 = 204 \\]\n\nThus, the total time Aya spends walking and in the coffee shop when she walks at \\( s + \\frac{1}{2} \\) kilometers per hour is \\(\\boxed{204}\\) minutes.",
"Pred_Answer": "204",
"Metrics": {
"is_valid_answer": true,
"acc": 1,
"em": 1,
"f1": 1.0,
"math_equal": true
}
},
...
```
### Precision
Precision on DCU is consistent with GPU; inference framework: PyTorch.
## Application Scenarios
### Algorithm Category
`Conversational QA`
### Key Application Industries
`Manufacturing, media, finance, energy, healthcare, smart home, education`
## Pretrained Weights
Pretrained-weight fast download center: [SCNet AIModels](http://113.200.138.88:18080/aimodels). The pretrained weights used in this project can be downloaded through the fast channel: [Qwen/Qwen2.5-72B-Instruct](http://113.200.138.88:18080/aimodels/qwen/Qwen2.5-72B-Instruct.git)
The Hugging Face download address is: [Qwen/Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct)
## Source Repository & Issue Feedback
- http://developer.sourcefind.cn/codes/modelzoo/search-o1_pytorch.git
## References
- https://github.com/sunnynexus/Search-o1.git
<h2 align="center"> <a href="https://arxiv.org/abs/2501.05366">🔍 Search-o1: Agentic Search-Enhanced</br> Large Reasoning Models</a></h2>
<div align="center">
[![Homepage](https://img.shields.io/badge/Homepage-Search--o1-red)](https://search-o1.github.io/)
[![Paper](https://img.shields.io/badge/Paper-arXiv-b5212f.svg?logo=arxiv)](https://arxiv.org/abs/2501.05366)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg)](https://opensource.org/licenses/MIT)
[![Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg)](https://www.python.org/downloads/release/python-390/)
[![X (formerly Twitter) URL](https://img.shields.io/twitter/url?url=https%3A%2F%2Fx.com%2FKevin_GuoweiXu%2Fstatus%2F1858338565463421244)](https://x.com/_akhaliq/status/1877584951840764166?t=fnbTblnqhiPtAyYr1PHbbw&s=19)
</div>
<!-- <div align="center">
<span style="display:inline-block; margin-right: 10px;">
<a href="https://paperswithcode.com/sota/on-gpqa?p=search-o1-agentic-search-enhanced-large">
<img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/search-o1-agentic-search-enhanced-large/on-gpqa" alt="GPQA Badge">
</a>
</span>
<span style="display:inline-block; margin-right: 10px;">
<a href="https://paperswithcode.com/sota/mathematical-reasoning-on-aime24?p=search-o1-agentic-search-enhanced-large">
<img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/search-o1-agentic-search-enhanced-large/mathematical-reasoning-on-aime24" alt="AIME24 Badge">
</a>
</span>
<span style="display:inline-block; margin-right: 10px;">
<a href="https://paperswithcode.com/sota/mathematical-reasoning-on-amc23?p=search-o1-agentic-search-enhanced-large">
<img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/search-o1-agentic-search-enhanced-large/mathematical-reasoning-on-amc23" alt="AMC23 Badge">
</a>
</span>
</div> -->
<div align="center">
<span style="display:inline-block; margin-right: 10px;">
<a href="https://paperswithcode.com/sota/mathematical-reasoning-on-aime24?p=search-o1-agentic-search-enhanced-large">
<img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/search-o1-agentic-search-enhanced-large/mathematical-reasoning-on-aime24" alt="AIME24 Badge">
</a>
</span>
<span style="display:inline-block; margin-right: 10px;">
<a href="https://paperswithcode.com/sota/mathematical-reasoning-on-amc23?p=search-o1-agentic-search-enhanced-large">
<img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/search-o1-agentic-search-enhanced-large/mathematical-reasoning-on-amc23" alt="AMC23 Badge">
</a>
</span>
<span style="display:inline-block; margin-right: 10px;">
<a href="https://paperswithcode.com/sota/on-gpqa?p=search-o1-agentic-search-enhanced-large">
<img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/search-o1-agentic-search-enhanced-large/on-gpqa" alt="GPQA Badge">
</a>
</span>
</div>
<h5 align="center"> If you like our project, please give us a star ⭐ on GitHub for the latest update.</h5>
## 📣 Latest News
- **01/10/2025**: A brief introduction to Search-o1 can be found on platforms like [X](https://x.com/_akhaliq/status/1877584951840764166?t=fnbTblnqhiPtAyYr1PHbbw&s=19), [Zhihu](https://zhuanlan.zhihu.com/p/17527068532), and [WeChat](https://mp.weixin.qq.com/s/gqnGyMM_KYYwDbHyWkIIuw).
- **01/10/2025**: The paper for Search-o1 is available. You can access it on [arxiv](https://arxiv.org/abs/2501.05366) and [HF-paper](https://huggingface.co/papers/2501.05366).
- **01/06/2025**: The homepage for Search-o1 is available. You can access it [here](https://search-o1.github.io/).
- **01/05/2025**: The code for Search-o1 has been released. You can now apply Search-o1 to enhance your large reasoning models.
## 💡 Overview
Large Reasoning Models (LRMs) like OpenAI's o1 have showcased remarkable stepwise reasoning capabilities through reinforcement learning. Despite their strengths, these models often encounter knowledge insufficiencies during prolonged reasoning processes, resulting in frequent uncertainties and potential errors, as shown in the following figure.
<p align="center">
<img src="figures/uncertainty.jpg" width="95%" />
</p>
### ✨ Method
To overcome these challenges, we present **Search-o1**, a framework that augments LRMs with an **agentic Retrieval-Augmented Generation (RAG)** mechanism and a **Reason-in-Documents** module for deep analysis and integration of retrieved documents into the reasoning chain.
- **Agentic Search Workflow**: Integrates an agentic search process into the reasoning workflow, allowing the model itself to dynamically retrieve external knowledge whenever it encounters uncertain information.
- **Reason-in-Documents Module**: Seamlessly integrates the retrieved information, reducing noise and maintaining a coherent reasoning chain.
![Model Comparison](figures/compare.jpg)
### ✨ Inference Process
Search-o1 incorporates a batch generation mechanism with interleaved search. Reasoning sequences are initialized by combining the task instructions with the input questions. The model then generates tokens for all sequences simultaneously, detecting search queries and retrieving relevant documents in batches. These documents are refined and seamlessly integrated back into the reasoning chains, and the process iterates until all sequences are completed and final answers are produced.
![Inference](figures/algorithm.jpg)
This approach enhances the reliability and accuracy of LRMs, enabling them to handle complex reasoning tasks more effectively by addressing knowledge gaps in real-time.
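For intuition, a simplified Python sketch of this interleaved loop is shown below. It is not the project's actual implementation: `generate_fn`, `search_fn`, and `refine_fn` are hypothetical callables standing in for the generator, the web search, and the Reason-in-Documents module.
```python
# Illustrative sketch of the interleaved batch-generation/search loop (hypothetical helpers).
def search_o1_batch(questions, instruction, generate_fn, search_fn, refine_fn,
                    max_search_limit=5, max_turn=10, top_k=10):
    """generate_fn(text) -> (continuation, query_or_None)
    search_fn(query, top_k) -> list of retrieved documents
    refine_fn(reasoning_so_far, docs) -> refined evidence string (Reason-in-Documents)."""
    sequences = [{"text": instruction + q, "finished": False, "searches": 0} for q in questions]

    for _ in range(max_turn):
        active = [s for s in sequences if not s["finished"]]
        if not active:
            break
        for seq in active:
            # Generate until the model either emits a search query or stops on its own.
            continuation, query = generate_fn(seq["text"])
            seq["text"] += continuation
            if query is None or seq["searches"] >= max_search_limit:
                seq["finished"] = True
                continue
            seq["searches"] += 1
            docs = search_fn(query, top_k)
            refined = refine_fn(seq["text"], docs)
            # Feed the refined evidence back into the reasoning chain.
            seq["text"] += f"\n<|begin_search_result|>{refined}<|end_search_result|>\n"

    return [s["text"] for s in sequences]
```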
## 🔧 Installation
### 1. Environment Setup
```bash
# Create conda environment
conda create -n search_o1 python=3.9
conda activate search_o1
# Install requirements
cd Search-o1
pip install -r requirements.txt
```
## 🏃 Quick Start
### Data Preparation
Use the code provided in `data/data_pre_process.ipynb` to preprocess each dataset into our standardized JSON format. The datasets we utilize are categorized into two types:
- **Challenging Reasoning Tasks:**
- **PhD-level Science QA:** GPQA
- **Math Benchmarks:** MATH500, AMC2023, AIME2024
- **Code Benchmark:** LiveCodeBench
- **Open-domain QA Tasks:**
- **Single-hop QA:** NQ, TriviaQA
- **Multi-hop QA:** HotpotQA, 2WikiMultihopQA, MuSiQue, Bamboogle
To preprocess the datasets, follow these steps:
1. Open the Jupyter notebook `data/data_pre_process.ipynb`.
2. For each dataset, run the corresponding preprocessing cells to convert the raw data into the unified JSON format.
3. The processed datasets will be saved in the `data/` directory.
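As a quick sanity check (a minimal sketch, assuming you processed the GPQA diamond split into the default output path), you can confirm that a processed file follows the unified format:
```python
import json

# Adjust the path to whichever split you preprocessed.
with open("data/GPQA/diamond.json", encoding="utf-8") as f:
    examples = json.load(f)

print(f"{len(examples)} examples")
print(sorted(examples[0].keys()))
# GPQA records include 'Question' (with the shuffled (A)-(D) choices appended) and 'Correct Choice'.
```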
### Model Inference
You can run different inference modes using the provided scripts. Below are examples of how to execute each mode:
1. **Direct Reasoning (Direct Generation)**
```bash
python scripts/run_direct_gen.py \
--dataset_name gpqa \
--split diamond \
--model_path "YOUR_MODEL_PATH"
```
2. **Naive Retrieval-Augmented Generation (RAG)**
```bash
python scripts/run_naive_rag.py \
--dataset_name gpqa \
--split diamond \
--use_jina True \
--model_path "YOUR_MODEL_PATH" \
--jina_api_key "YOUR_JINA_API_KEY" \
--bing_subscription_key "YOUR_BING_SUBSCRIPTION_KEY"
```
3. **RAG with Agentic Search**
```bash
python scripts/run_rag_agent.py \
--dataset_name gpqa \
--split diamond \
--max_search_limit 5 \
--max_url_fetch 5 \
--max_turn 10 \
--top_k 10 \
--use_jina True \
--model_path "YOUR_MODEL_PATH" \
--jina_api_key "YOUR_JINA_API_KEY" \
--bing_subscription_key "YOUR_BING_SUBSCRIPTION_KEY"
```
4. **Search-o1 (Ours)**
```bash
python scripts/run_search_o1.py \
--dataset_name aime \
--split test \
--max_search_limit 5 \
--max_turn 10 \
--top_k 10 \
--max_doc_len 3000 \
--use_jina True \
--model_path "YOUR_MODEL_PATH" \
--jina_api_key "YOUR_JINA_API_KEY" \
--bing_subscription_key "YOUR_BING_SUBSCRIPTION_KEY"
```
**Parameters Explanation:**
- `--dataset_name`: Name of the dataset to use (e.g., gpqa, aime).
- `--split`: Data split to run (e.g., train, test, diamond).
- `--model_path`: Path to the pre-trained LRM model.
- `--bing_subscription_key`: Your Bing Search API subscription key.
- `--max_search_limit`: Maximum number of search queries per reasoning session.
- `--max_url_fetch`: Maximum number of URLs to fetch per search.
- `--max_turn`: Maximum number of reasoning turns.
- `--top_k`: Number of top documents to retrieve.
- `--max_doc_len`: Maximum length of each retrieved document.
- `--use_jina`: Whether to use Jina for document processing.
- `--jina_api_key`: Your Jina API subscription key for URL content fetching.
Ensure you replace `"YOUR_MODEL_PATH"` with your actual model path, and replace `"YOUR_BING_SUBSCRIPTION_KEY"` and `"YOUR_JINA_API_KEY"` with your Bing Search and Jina API keys, respectively.
### Evaluation
Our model inference scripts will automatically save the model's input and output texts for evaluation. However, for methods with retrieval, since the model has not been trained to use the retrieved text effectively, it often fails to provide a final answer. We apply a backoff strategy to use the direct generation result when the retrieval-based methods do not provide a final answer for a given data point.
To use this backoff strategy, you need to provide the path to the direct generation results in the `scripts/evaluate.py` file, and then use the following command to get the backoff results for retrieval-based methods:
```bash
python scripts/evaluate.py \
--output_path outputs/... \
--apply_backoff
```
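For intuition, a simplified sketch of the backoff idea is shown below. It is not the actual logic in `scripts/evaluate.py`; the record fields (`id`, `Pred_Answer`, `Metrics.is_valid_answer`) follow the per-example output format shown earlier, and the file names are placeholders.
```python
import json

def apply_backoff(retrieval_results, direct_results):
    """Fall back to the direct-generation prediction whenever a retrieval-based
    run produced no valid final answer (simplified sketch, not scripts/evaluate.py)."""
    direct_by_id = {r["id"]: r for r in direct_results}
    merged = []
    for r in retrieval_results:
        if not r.get("Metrics", {}).get("is_valid_answer", False) and r["id"] in direct_by_id:
            r = {**r, "Pred_Answer": direct_by_id[r["id"]].get("Pred_Answer")}
        merged.append(r)
    return merged

# Placeholder file names for illustration only.
with open("retrieval_run.json", encoding="utf-8") as f_r, open("direct_run.json", encoding="utf-8") as f_d:
    merged = apply_backoff(json.load(f_r), json.load(f_d))
print(len(merged), "examples after backoff")
```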
## 📄 Citation
If you find this work helpful, please cite our paper:
```bibtex
@article{search-o1,
title={Search-o1: Agentic Search-Enhanced Large Reasoning Models},
author={Xiaoxi Li and
Guanting Dong and
Jiajie Jin and
Yuyao Zhang and
Yujia Zhou and
Yutao Zhu and
Peitian Zhang and
Zhicheng Dou},
journal={CoRR},
volume={abs/2501.05366},
year={2025},
url={https://arxiv.org/abs/2501.05366},
eprinttype={arXiv},
eprint={2501.05366}
}
```
## 📄 License
This project is released under the [MIT License](LICENSE).
## 📞 Contact
For any questions or feedback, please reach out to us at [xiaoxi_li@ruc.edu.cn](mailto:xiaoxi_li@ruc.edu.cn).
---
© 2025 Search-o1 Team. All rights reserved.
---
license: cc-by-4.0
viewer: true
extra_gated_prompt: >-
You agree to NOT reveal examples from this dataset in plain text or images
online, to reduce the risk of leakage into foundation model training corpora.
extra_gated_fields:
I accept these terms: checkbox
configs:
- config_name: gpqa_extended
data_files: gpqa_extended.csv
- config_name: gpqa_main
data_files: gpqa_main.csv
- config_name: gpqa_diamond
data_files: gpqa_diamond.csv
- config_name: gpqa_experts
data_files: gpqa_experts.csv
task_categories:
- question-answering
- text-generation
language:
- en
tags:
- open-domain-qa
- open-book-qa
- multiple-choice-qa
pretty_name: GPQA
size_categories:
- n<1K
---
# Dataset Card for GPQA
<!-- Provide a quick summary of the dataset. -->
GPQA is a multiple-choice, Q&A dataset of very hard questions written and validated by experts in biology, physics, and chemistry. When attempting questions out of their own domain (e.g., a physicist answers a chemistry question), these experts get only 34% accuracy, despite spending >30m with full access to Google.
We request that you **do not reveal examples from this dataset in plain text or images online**, to reduce the risk of leakage into foundation model training corpora.
## Dataset Details
### Dataset Description
<!-- Provide a longer summary of what this dataset is. -->
We present GPQA, a challenging dataset of 448 multiple-choice questions written by domain experts in biology, physics, and chemistry. We ensure that the questions are high-quality and extremely difficult: experts who have or are pursuing PhDs in the corresponding domains reach 65% accuracy (74% when discounting clear mistakes the experts identified in retrospect), while highly skilled non-expert validators only reach 34% accuracy, despite spending on average over 30 minutes with unrestricted access to the web (i.e., the questions are "Google-proof"). The questions are also difficult for state-of-the-art AI systems, with our strongest GPT-4 based baseline achieving 39% accuracy. If we are to use future AI systems to help us answer very hard questions, for example, when developing new scientific knowledge, we need to develop scalable oversight methods that enable humans to supervise their outputs, which may be difficult even if the supervisors are themselves skilled and knowledgeable. The difficulty of GPQA both for skilled non-experts and frontier AI systems should enable realistic scalable oversight experiments, which we hope can help devise ways for human experts to reliably get truthful information from AI systems that surpass human capabilities.
- **Curated by:** David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
- **License:** CC BY 4.0
### Dataset Sources
<!-- Provide the basic links for the dataset. -->
- **Repository:** https://github.com/idavidrein/gpqa
- **Paper:** https://arxiv.org/abs/2311.12022
## Uses
The dataset is primarily intended to be used for scalable oversight experiments, although it can also be used for more general LLM capabilities benchmarking.
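For benchmarking, the dataset can be loaded with the Hugging Face `datasets` library. The sketch below is illustrative: the repository id (`Idavidrein/gpqa`) and the split name are assumptions, the config names follow this card, and access requires accepting the gating terms with an authenticated Hugging Face account.
```python
from datasets import load_dataset

# Assumes you have accepted the access terms on the Hub and are logged in
# (e.g. via `huggingface-cli login`). Repository id and split name are assumptions;
# configs follow this card (gpqa_main, gpqa_diamond, gpqa_extended, gpqa_experts).
gpqa = load_dataset("Idavidrein/gpqa", "gpqa_diamond")

print(gpqa)                          # splits and sizes
print(gpqa["train"].column_names)    # includes Question, Correct Answer, Incorrect Answer 1-3, ...
```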
## Dataset Card Contact
David Rein: idavidrein@gmail.com
---
Submit corrections to examples in GPQA via this form: https://forms.gle/iTY4zMETNsPhJq8R9
---
GPQA (c) by Irving David Rein
GPQA is licensed under a Creative Commons Attribution 4.0 International License.
You should have received a copy of the license along with this
work. If not, see <https://creativecommons.org/licenses/by/4.0/>.
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Data preprocess for GPQA\n",
"import csv\n",
"import json\n",
"import random\n",
"from tqdm import tqdm\n",
"\n",
"# Paths to data\n",
"data_path = './GPQA/original_data/gpqa_extended.csv'\n",
"output_path = './GPQA/extended.json'\n",
"\n",
"# Define the keys we want to keep\n",
"keys_to_keep = [\n",
" 'id',\n",
" 'Question',\n",
" 'Subdomain',\n",
" 'High-level domain',\n",
" 'Correct Answer',\n",
" 'Incorrect Answer 1',\n",
" 'Incorrect Answer 2',\n",
" 'Incorrect Answer 3'\n",
"]\n",
"\n",
"filtered_data = []\n",
"with open(data_path, mode='r', encoding='utf-8') as csv_file:\n",
" csv_reader = csv.DictReader(csv_file)\n",
" for idx, row in enumerate(tqdm(csv_reader), 0):\n",
" # Add id field\n",
" row['id'] = idx\n",
" # Create new dictionary with only desired keys\n",
" filtered_row = {key: row[key] for key in keys_to_keep}\n",
"\n",
" # Extract answers and shuffle them\n",
" answers = [\n",
" ('Correct Answer', filtered_row['Correct Answer']),\n",
" ('Incorrect Answer 1', filtered_row['Incorrect Answer 1']),\n",
" ('Incorrect Answer 2', filtered_row['Incorrect Answer 2']),\n",
" ('Incorrect Answer 3', filtered_row['Incorrect Answer 3'])\n",
" ]\n",
" random.shuffle(answers)\n",
"\n",
" # Assign new choices A, B, C, D in order and determine the correct choice\n",
" choices = ['A', 'B', 'C', 'D']\n",
" formatted_answers = []\n",
" correct_choice = None\n",
" for i, (label, answer) in enumerate(answers):\n",
" choice = choices[i]\n",
" formatted_answers.append((choice, answer))\n",
" if label == 'Correct Answer':\n",
" correct_choice = choice\n",
"\n",
" # Update the Question field\n",
" formatted_choices = \"\\n\".join([f\"({choice}) {answer}\" for choice, answer in formatted_answers])\n",
" filtered_row['Question'] = f\"{filtered_row['Question']} Choices:\\n{formatted_choices}\\n\"\n",
"\n",
" # Add the Correct Choice field\n",
" filtered_row['Correct Choice'] = correct_choice\n",
"\n",
" # Append the updated row to filtered_data\n",
" filtered_data.append(filtered_row)\n",
"\n",
"# Write the updated data to JSON\n",
"with open(output_path, mode='w', encoding='utf-8') as json_file:\n",
" json.dump(filtered_data, json_file, indent=4, ensure_ascii=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Data preprocess for MATH500\n",
"import csv\n",
"import json\n",
"from tqdm import tqdm\n",
"\n",
"test_path = './MATH500/original_data/test.jsonl'\n",
"output_path = './MATH500/test.json'\n",
"\n",
"data_list = []\n",
"with open(test_path, 'r') as file:\n",
" for id, line in enumerate(file.readlines()):\n",
" line = json.loads(line)\n",
" data_list.append({\n",
" 'id': id, \n",
" 'Question': line['problem'],\n",
" 'solution': line['solution'],\n",
" 'answer': line['answer'],\n",
" 'subject': line['subject'],\n",
" 'level': line['level'],\n",
" 'unique_id': line['unique_id'],\n",
" })\n",
"\n",
"# Write the updated data to JSON\n",
"with open(output_path, mode='w', encoding='utf-8') as json_file:\n",
" json.dump(data_list, json_file, indent=4, ensure_ascii=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Data preprocess for AIME\n",
"import csv\n",
"import json\n",
"from tqdm import tqdm\n",
"\n",
"test_path = './AIME/original_data/aime_2024.json'\n",
"output_path = './AIME/2024.json'\n",
"\n",
"data_list = []\n",
"with open(test_path, 'r') as file:\n",
" data = json.load(file)\n",
" for id, line in enumerate(tqdm(data)):\n",
" data_list.append({\n",
" 'id': id, \n",
" 'Problem_ID': line['ID'],\n",
" 'Question': line['Problem'],\n",
" 'Solution': line['Solution'],\n",
" 'answer': str(line['Answer']),\n",
" })\n",
"\n",
"# Write the updated data to JSON\n",
"with open(output_path, mode='w', encoding='utf-8') as json_file:\n",
" json.dump(data_list, json_file, indent=4, ensure_ascii=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Data preprocess for AMC\n",
"import csv\n",
"import json\n",
"from tqdm import tqdm\n",
"\n",
"test_path = './AMC/original_data/amc_2022_2023.json'\n",
"output_path = './AMC/test.json'\n",
"\n",
"data_list = []\n",
"with open(test_path, 'r') as file:\n",
" data = json.load(file)\n",
" id = 0\n",
" for line in tqdm(data):\n",
" if '2023' not in line['url']:\n",
" continue\n",
" data_list.append({\n",
" 'id': id, \n",
" 'Question': line['problem'],\n",
" 'answer': str(int(line['answer'])),\n",
" 'url': line['url'],\n",
" })\n",
" id += 1\n",
"\n",
"# Write the updated data to JSON\n",
"with open(output_path, mode='w', encoding='utf-8') as json_file:\n",
" json.dump(data_list, json_file, indent=4, ensure_ascii=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Data preprocess for LiveCodeBench\n",
"import json\n",
"from tqdm import tqdm\n",
"from datetime import datetime\n",
"\n",
"def is_valid_date(date_str):\n",
" \"\"\"\n",
" Check if the given date string is within the range from August 1, 2024, to November 30, 2024.\n",
"\n",
" Args:\n",
" date_str (str): The date string in the format \"%Y-%m-%dT%H:%M:%S\".\n",
"\n",
" Returns:\n",
" bool: True if the date is within the specified range, False otherwise.\n",
" \"\"\"\n",
" try:\n",
" # Parse the date string into a datetime object\n",
" date = datetime.strptime(date_str, \"%Y-%m-%dT%H:%M:%S\")\n",
" except ValueError:\n",
" # If the date string is not in the expected format, consider it invalid\n",
" return False\n",
"\n",
" # Define the start and end dates for the valid range\n",
" start_date = datetime(2024, 8, 1)\n",
" end_date = datetime(2024, 11, 30)\n",
"\n",
" # Check if the date falls within the valid range\n",
" return start_date <= date <= end_date\n",
"\n",
"# Define the paths to the input JSONL files\n",
"test_paths = [\n",
" './LiveCodeBench/test.jsonl',\n",
" './LiveCodeBench/test2.jsonl',\n",
" './LiveCodeBench/test3.jsonl',\n",
" './LiveCodeBench/test4.jsonl'\n",
"]\n",
"\n",
"# Define the path to the output JSON file\n",
"output_path = './LiveCodeBench/test.json'\n",
"\n",
"data_list = []\n",
"seen_questions = set() # To track unique questions based on 'question_content'\n",
"current_id = 0 # To assign unique IDs across all files\n",
"\n",
"for test_path in test_paths:\n",
" try:\n",
" with open(test_path, 'r', encoding='utf-8') as file:\n",
" # Use tqdm to show progress; total can be estimated if needed\n",
" for line in tqdm(file, desc=f'Processing {test_path}'):\n",
" try:\n",
" # Parse the JSON line\n",
" line_data = json.loads(line)\n",
" except json.JSONDecodeError:\n",
" # Skip lines that are not valid JSON\n",
" continue\n",
"\n",
" # Check if the 'contest_date' field exists and is valid\n",
" contest_date = line_data.get('contest_date')\n",
" if not contest_date or not is_valid_date(contest_date):\n",
" continue\n",
"\n",
" # Get the question content to check for duplicates\n",
" question_content = line_data.get('question_content')\n",
" if not question_content:\n",
" continue # Skip if 'question_content' is missing\n",
"\n",
" if question_content in seen_questions:\n",
" continue # Duplicate question; skip\n",
"\n",
" # Add the question to the seen set\n",
" seen_questions.add(question_content)\n",
"\n",
" # Append the question data to the list\n",
" data_list.append({\n",
" 'id': current_id,\n",
" 'Question': question_content,\n",
" 'question_title': line_data.get('question_title', ''),\n",
" 'contest_date': contest_date,\n",
" 'difficulty': line_data.get('difficulty', ''),\n",
" 'public_test_cases': line_data.get('public_test_cases', [])\n",
" })\n",
"\n",
" current_id += 1 # Increment the unique ID\n",
"\n",
" except FileNotFoundError:\n",
" print(f\"File not found: {test_path}\")\n",
" except Exception as e:\n",
" print(f\"An error occurred while processing {test_path}: {e}\")\n",
"\n",
"# Write the aggregated and deduplicated data to the output JSON file\n",
"try:\n",
" with open(output_path, mode='w', encoding='utf-8') as json_file:\n",
" json.dump(data_list, json_file, indent=4, ensure_ascii=False)\n",
" print(f\"Data successfully written to {output_path}\")\n",
"except Exception as e:\n",
" print(f\"Failed to write data to {output_path}: {e}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Data preprocess for FlashRAG ODQA datasets\n",
"import csv\n",
"import json\n",
"from tqdm import tqdm\n",
"\n",
"dataset_name = 'bamboogle'\n",
"split = 'test'\n",
"data_num = 200\n",
"\n",
"test_path = f'./FlashRAG_datasets/{dataset_name}/{split}.jsonl'\n",
"output_path = f'./QA_Datasets/{dataset_name}.json'\n",
"\n",
"data_list = []\n",
"with open(test_path, 'r') as file:\n",
" for id, line in enumerate(tqdm(file.readlines())):\n",
" line = json.loads(line)\n",
" data_list.append({\n",
" 'id': id, \n",
" 'Question': line['question'],\n",
" 'answer': line[\"golden_answers\"],\n",
" })\n",
" if len(data_list) >= data_num:\n",
" break\n",
"\n",
"# Write the updated data to JSON\n",
"with open(output_path, mode='w', encoding='utf-8') as json_file:\n",
" json.dump(data_list, json_file, indent=4, ensure_ascii=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Data preprocess for FlashRAG ODQA datasets (All)\n",
"import csv\n",
"import json\n",
"from tqdm import tqdm\n",
"\n",
"dataset_name = 'musique'\n",
"split = 'dev'\n",
"data_num = 100000\n",
"\n",
"test_path = f'./FlashRAG_datasets/{dataset_name}/{split}.jsonl'\n",
"output_path = f'./QA_Datasets/{dataset_name}.json'\n",
"\n",
"data_list = []\n",
"with open(test_path, 'r') as file:\n",
" for id, line in enumerate(tqdm(file.readlines())):\n",
" line = json.loads(line)\n",
" data_list.append({\n",
" 'id': id, \n",
" 'Question': line['question'],\n",
" 'answer': line[\"golden_answers\"],\n",
" })\n",
" if len(data_list) >= data_num:\n",
" break\n",
"\n",
"# Write the updated data to JSON\n",
"with open(output_path, mode='w', encoding='utf-8') as json_file:\n",
" json.dump(data_list, json_file, indent=4, ensure_ascii=False)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Convert Parquet to JSON (AIME)
import pandas as pd
'''
# Specify the Parquet file path
# link: https://huggingface.co/datasets/AI-MO/aimo-validation-aime
parquet_file = "./train-00000-of-00001.parquet"
# Use pandas to read the Parquet file
df = pd.read_parquet(parquet_file)
# Filter the DataFrame to keep only rows where '2024_AIME' appears in the 'url' column
filtered_df = df[df['url'].str.contains('2024_AIME', na=False)]
# Print the first few rows of the filtered DataFrame to confirm
print(filtered_df.head())
# Export to a JSON file with indentation
json_file = "./aime_2024.json"
filtered_df.to_json(json_file, orient='records', force_ascii=False, indent=4)
print(f"Filtered data has been saved to {json_file}")
'''
# Data preprocess for AIME
import csv
import json
from tqdm import tqdm
test_path = './aime_2024.json'
output_path = './test.json'
data_list = []
with open(test_path, 'r', encoding='utf-8') as file:
data = json.load(file)
for id, line in enumerate(tqdm(data)):
data_list.append({
'id': id,
'Question': line['problem'],
'Solution': line['solution'],
'answer': str(int(line['answer'])),
})
# Write the updated data to JSON
with open(output_path, mode='w', encoding='utf-8') as json_file:
json.dump(data_list, json_file, indent=4, ensure_ascii=False)
# Data preprocess for GPQA
import csv
import json
import random
from tqdm import tqdm
# Paths to data
data_path = './GPQA/gpqa_extended.csv'
output_path = './GPQA/diamond.json'
# Define the keys we want to keep
keys_to_keep = [
'id',
'Question',
'Subdomain',
'High-level domain',
'Correct Answer',
'Incorrect Answer 1',
'Incorrect Answer 2',
'Incorrect Answer 3'
]
filtered_data = []
with open(data_path, mode='r', encoding='utf-8') as csv_file:
csv_reader = csv.DictReader(csv_file)
for idx, row in enumerate(tqdm(csv_reader), 0):
# Add id field
row['id'] = idx
# Create new dictionary with only desired keys
filtered_row = {key: row[key] for key in keys_to_keep}
# Extract answers and shuffle them
answers = [
('Correct Answer', filtered_row['Correct Answer']),
('Incorrect Answer 1', filtered_row['Incorrect Answer 1']),
('Incorrect Answer 2', filtered_row['Incorrect Answer 2']),
('Incorrect Answer 3', filtered_row['Incorrect Answer 3'])
]
random.shuffle(answers)
# Assign new choices A, B, C, D in order and determine the correct choice
choices = ['A', 'B', 'C', 'D']
formatted_answers = []
correct_choice = None
for i, (label, answer) in enumerate(answers):
choice = choices[i]
formatted_answers.append((choice, answer))
if label == 'Correct Answer':
correct_choice = choice
# Update the Question field
formatted_choices = "\n".join([f"({choice}) {answer}" for choice, answer in formatted_answers])
filtered_row['Question'] = f"{filtered_row['Question']} Choices:\n{formatted_choices}\n"
# Add the Correct Choice field
filtered_row['Correct Choice'] = correct_choice
# Append the updated row to filtered_data
filtered_data.append(filtered_row)
# Write the updated data to JSON
with open(output_path, mode='w', encoding='utf-8') as json_file:
json.dump(filtered_data, json_file, indent=4, ensure_ascii=False)
python scripts/run_direct_gen.py \
--dataset_name gpqa \
--split diamond \
--model_path "Qwen2.5-72B-Instruct"
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-py3.10-dtk24.04.3-ubuntu20.04
ENV DEBIAN_FRONTEND=noninteractive
# RUN yum update && yum install -y git cmake wget build-essential
# RUN source /opt/dtk-24.04.3/env.sh
# Install pip dependencies
COPY requirements.txt requirements.txt
RUN pip3 install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com