"...git@developer.sourcefind.cn:modelzoo/co-detr_mmcv.git" did not exist on "4a95060217c387e697d83f5d797d5a75aab2ec9f"
Commit b6edc328 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2229 canceled with stages
# run_rag_agent.py
import os
import json
import time
import re
import requests
from tqdm import tqdm
import numpy as np
import torch
from typing import Optional, Tuple, List, Dict
import argparse
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from bing_search import bing_web_search, extract_relevant_info, fetch_page_content
from evaluate import run_evaluation
from prompts import (
get_singleqa_rag_agent_instruction,
get_multiqa_rag_agent_instruction,
get_gpqa_rag_agent_instruction,
get_math_rag_agent_instruction,
get_code_rag_agent_instruction,
get_task_instruction_openqa,
get_task_instruction_math,
get_task_instruction_multi_choice,
get_task_instruction_code,
)
# Define special symbols
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
END_SEARCH_QUERY = "<|end_search_query|>"
BEGIN_SEARCH_RESULT = "<|begin_search_result|>"
END_SEARCH_RESULT = "<|end_search_result|>"
BEGIN_URL = "<|begin_url|>"
END_URL = "<|end_url|>"
BEGIN_FULL_PAGE = "<|begin_full_page|>"
END_FULL_PAGE = "<|end_full_page|>"
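# These sentinel tags delimit the agent's tool calls (search queries and URL fetch
# requests) and the tool results appended back into the prompt; generation stops on the
# closing tags via the `stop` list in run_generation() below.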
def parse_args():
parser = argparse.ArgumentParser(description="Run RAG Agent for various datasets and models.")
# Dataset and split configuration
parser.add_argument(
'--dataset_name',
type=str,
required=True,
choices=['gpqa', 'math500', 'aime', 'amc', 'livecode', 'nq', 'triviaqa', 'hotpotqa', '2wiki', 'musique', 'bamboogle', 'medmcqa', 'pubhealth'],
help="Name of the dataset to use."
)
parser.add_argument(
'--split',
type=str,
required=True,
choices=['test', 'diamond', 'main', 'extended'],
help="Dataset split to use."
)
parser.add_argument(
'--subset_num',
type=int,
default=-1,
help="Number of examples to process. Defaults to all if not specified."
)
# RAG Agent configuration
parser.add_argument(
'--max_search_limit',
type=int,
default=5,
help="Maximum number of searches per question."
)
parser.add_argument(
'--max_url_fetch',
type=int,
default=5,
help="Maximum number of URL fetches per question."
)
parser.add_argument(
'--max_turn',
type=int,
default=10,
help="Maximum number of turns."
)
parser.add_argument(
'--top_k',
type=int,
default=10,
help="Maximum number of search documents to return."
)
# Model configuration
parser.add_argument(
'--model_path',
type=str,
required=True,
help="Path to the pre-trained model."
)
    parser.add_argument(
        '--use_jina',
        type=lambda x: str(x).lower() in ('true', '1', 'yes'),  # plain type=bool would treat any non-empty string, including "False", as True
        default=True,
        help="Whether to use the Jina API for document fetching."
    )
    parser.add_argument(
        '--jina_api_key',
        type=str,
        default='None',
        help="Your Jina API key for fetching URL content."
    )
# Sampling parameters
parser.add_argument(
'--temperature',
type=float,
default=0.7,
help="Sampling temperature."
)
parser.add_argument(
'--top_p',
type=float,
default=0.8,
help="Top-p sampling parameter."
)
parser.add_argument(
'--top_k_sampling',
type=int,
default=20,
help="Top-k sampling parameter."
)
parser.add_argument(
'--repetition_penalty',
type=float,
default=None,
help="Repetition penalty. If not set, defaults based on the model."
)
parser.add_argument(
'--max_tokens',
type=int,
default=32768,
help="Maximum number of tokens to generate. If not set, defaults based on the model and dataset."
)
# Bing API Configuration
parser.add_argument(
'--bing_subscription_key',
type=str,
required=True,
help="Bing Search API subscription key."
)
parser.add_argument(
'--bing_endpoint',
type=str,
default="https://api.bing.microsoft.com/v7.0/search",
help="Bing Search API endpoint."
)
return parser.parse_args()
def main():
args = parse_args()
# Extract arguments
dataset_name = args.dataset_name
split = args.split
subset_num = args.subset_num
MAX_SEARCH_LIMIT = args.max_search_limit
MAX_URL_FETCH = args.max_url_fetch
MAX_TURN = args.max_turn
top_k = args.top_k
model_path = args.model_path
temperature = args.temperature
top_p = args.top_p
top_k_sampling = args.top_k_sampling
repetition_penalty = args.repetition_penalty
max_tokens = args.max_tokens
bing_subscription_key = args.bing_subscription_key
bing_endpoint = args.bing_endpoint
use_jina = args.use_jina
jina_api_key = args.jina_api_key
# Adjust parameters based on dataset
if dataset_name in ['nq', 'triviaqa', 'hotpotqa', 'musique', 'bamboogle', '2wiki', 'medmcqa', 'pubhealth']:
MAX_SEARCH_LIMIT = 5
MAX_URL_FETCH = 5
MAX_TURN = 10
top_k = 5
# Set default repetition_penalty if not provided
if repetition_penalty is None:
repetition_penalty = 1.05 if 'qwq' in model_path.lower() else 1.0
if args.jina_api_key == 'None':
jina_api_key = None
# Data paths based on dataset
if dataset_name == 'livecode':
data_path = f'./data/LiveCodeBench/{split}.json'
elif dataset_name in ['math500', 'gpqa', 'aime', 'amc']:
data_path = f'./data/{dataset_name.upper()}/{split}.json'
else:
data_path = f'./data/QA_Datasets/{dataset_name}.json'
# ---------------------- Caching Mechanism ----------------------
# Define cache directories and file paths
cache_dir = './cache'
search_cache_path = os.path.join(cache_dir, 'search_cache.json')
url_cache_path = os.path.join(cache_dir, 'url_cache.json')
# Ensure cache directory exists
os.makedirs(cache_dir, exist_ok=True)
# Load existing caches or initialize empty dictionaries
if os.path.exists(search_cache_path):
with open(search_cache_path, 'r', encoding='utf-8') as f:
search_cache = json.load(f)
else:
search_cache = {}
if os.path.exists(url_cache_path):
with open(url_cache_path, 'r', encoding='utf-8') as f:
url_cache = json.load(f)
else:
url_cache = {}
# Define function to save caches
def save_caches():
with open(search_cache_path, 'w', encoding='utf-8') as f:
json.dump(search_cache, f, ensure_ascii=False, indent=2)
with open(url_cache_path, 'w', encoding='utf-8') as f:
json.dump(url_cache, f, ensure_ascii=False, indent=2)
# ---------------------- Model Loading ----------------------
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
# Define output directory based on model and dataset
if 'qwq' in model_path.lower():
if dataset_name in ['math500', 'gpqa', 'aime', 'amc', 'livecode']:
output_dir = f'./outputs/{dataset_name}.qwq.rag_agent'
else:
output_dir = f'./outputs/runs.qa/{dataset_name}.qwq.rag_agent'
else:
model_short_name = model_path.split('/')[-1].lower().replace('-instruct', '')
output_dir = f'./outputs/runs.baselines/{dataset_name}.{model_short_name}.rag_agent'
os.makedirs(output_dir, exist_ok=True)
# Initialize the LLM
llm = LLM(
model=model_path,
tensor_parallel_size=torch.cuda.device_count(),
gpu_memory_utilization=0.95,
)
# ---------------------- Data Loading ----------------------
with open(data_path, 'r', encoding='utf-8') as json_file:
filtered_data = json.load(json_file)
# ---------------------- Prepare Input ----------------------
input_list = []
for item in filtered_data:
question = item['Question']
if dataset_name in ['nq', 'triviaqa', 'hotpotqa', 'musique', 'bamboogle', '2wiki']:
if dataset_name in ['nq', 'triviaqa']:
instruction = get_singleqa_rag_agent_instruction(MAX_SEARCH_LIMIT, MAX_URL_FETCH)
elif dataset_name in ['hotpotqa', 'musique', 'bamboogle', '2wiki']:
instruction = get_multiqa_rag_agent_instruction(MAX_SEARCH_LIMIT, MAX_URL_FETCH)
if 'qwq' in model_path.lower():
user_prompt = get_task_instruction_openqa(question, model_name='qwq')
else:
user_prompt = get_task_instruction_openqa(question)
elif dataset_name in ['math500', 'aime', 'amc']:
instruction = get_math_rag_agent_instruction(MAX_SEARCH_LIMIT, MAX_URL_FETCH)
if 'qwq' in model_path.lower():
user_prompt = get_task_instruction_math(question, model_name='qwq')
else:
user_prompt = get_task_instruction_math(question)
elif dataset_name == 'gpqa':
instruction = get_gpqa_rag_agent_instruction(MAX_SEARCH_LIMIT, MAX_URL_FETCH)
if 'qwq' in model_path.lower():
user_prompt = get_task_instruction_multi_choice(question, model_name='qwq')
elif 'llama' in model_path.lower():
user_prompt = get_task_instruction_multi_choice(question, model_name='llama')
else:
user_prompt = get_task_instruction_multi_choice(question)
elif dataset_name == 'livecode':
instruction = get_code_rag_agent_instruction(MAX_SEARCH_LIMIT, MAX_URL_FETCH)
question_title = item.get('question_title', '')
if 'qwq' in model_path.lower():
user_prompt = get_task_instruction_code(question, question_title=question_title, model_name='qwq')
else:
user_prompt = get_task_instruction_code(question)
        else:
            instruction = ""  # Avoid a NameError below if the dataset name is not matched
            user_prompt = ""  # Default to empty if dataset not matched
prompt = [{"role": "user", "content": instruction + user_prompt}]
prompt = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
input_list.append(prompt)
if subset_num != -1:
input_list = input_list[:subset_num]
filtered_data = filtered_data[:subset_num]
# Initialize active sequences with search and URL fetch counters
active_sequences = [{
'item': item,
'prompt': prompt,
'output': '',
'finished': False,
'history': [],
'pending_operations': [], # Queue of operations to execute
'executed_search_queries': set(),
'executed_url_fetches': set(),
'search_count': 0 # Search counter
} for item, prompt in zip(filtered_data, input_list)]
# ---------------------- Set Max Tokens ----------------------
if 'qwq' in model_path.lower():
if dataset_name in ['aime', 'amc', 'livecode']:
max_tokens = 32768
else:
max_tokens = 20480
else:
max_tokens = 8192
# ---------------------- Generation Function ----------------------
def run_generation(sequences, max_tokens):
"""
Run LLM generation on provided sequences.
"""
prompts = [s['prompt'] for s in sequences]
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k_sampling,
repetition_penalty=repetition_penalty,
stop=[END_SEARCH_QUERY, END_URL, tokenizer.eos_token],
include_stop_str_in_output=True,
)
output_list = llm.generate(prompts, sampling_params=sampling_params)
return output_list
# Function to extract text between two tags
def extract_between(text: str, start_tag: str, end_tag: str) -> Optional[str]:
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
matches = re.findall(pattern, text, flags=re.DOTALL)
if matches:
return matches[-1].strip()
return None
start_time = time.time()
turn = 0
# Main loop until all sequences are completed
while True:
# Separate sequences with pending operations and those needing generation
sequences_with_pending_ops = [seq for seq in active_sequences if not seq['finished'] and seq['pending_operations']]
sequences_needing_generation = [seq for seq in active_sequences if not seq['finished'] and not seq['pending_operations']]
# First, handle pending operations
if sequences_with_pending_ops:
print(f"{len(sequences_with_pending_ops)} sequences have pending operations. Executing...")
for seq in sequences_with_pending_ops:
# Execute the next pending operation
operation = seq['pending_operations'].pop(0) # FIFO
op_type = operation['type']
content = operation['content']
if op_type == 'search':
query = content
if query in search_cache:
results = search_cache[query]
print(f"Using cached search results for query: {query}")
else:
try:
# Execute search and cache results
results = bing_web_search(query, bing_subscription_key, bing_endpoint, market='en-US', language='en')
search_cache[query] = results
print(f"Executed and cached search for query: {query}")
except Exception as e:
print(f"Error during search query '{query}': {e}")
search_cache[query] = {}
results = {}
relevant_info = extract_relevant_info(results)[:top_k]
search_result_str = json.dumps(relevant_info, ensure_ascii=False, indent=2)
# Append search results to the prompt
append_text = f"\n{BEGIN_SEARCH_RESULT}\n{search_result_str}\n{END_SEARCH_RESULT}\n"
seq['prompt'] += append_text
seq['output'] += append_text
# Update history
seq['history'].append(append_text)
# Increment search count
seq['search_count'] += 1
elif op_type == 'fetch_url':
urls = content
# Calculate remaining URL fetches
remaining_fetches = MAX_URL_FETCH - len(seq['executed_url_fetches'])
if remaining_fetches <= 0:
                        # Reached URL fetch limit: append a limit message and skip this fetch operation
limit_message = f"\n{BEGIN_FULL_PAGE}\nThe maximum number of URL fetches has been reached. You are not allowed to fetch more URLs.\n{END_FULL_PAGE}\n"
seq['prompt'] += limit_message
seq['output'] += limit_message
seq['history'].append(limit_message)
print("Reached URL fetch limit. Sequence marked as finished.")
continue
# Split and clean URLs
urls_to_fetch = [u.strip() for u in urls.split(",")]
# Filter already fetched URLs
urls_to_fetch = [u for u in urls_to_fetch if u not in seq['executed_url_fetches']]
# Limit the number of URLs to fetch
urls_to_fetch = urls_to_fetch[:remaining_fetches]
if not urls_to_fetch:
print("All requested URLs have been fetched or exceeded the limit.")
continue
# Batch fetch page content, considering cache
urls_to_fetch_filtered = [u for u in urls_to_fetch if u not in url_cache]
cached_urls = [u for u in urls_to_fetch if u in url_cache]
fetched_contents = []
# Use cached URL content
for url in cached_urls:
content = url_cache[url]
print(f"Using cached URL content for URL: {url}")
fetched_contents.append((url, content))
# Batch fetch uncached URLs
if urls_to_fetch_filtered:
try:
# Batch pass uncached URLs
contents = fetch_page_content(urls_to_fetch_filtered, use_jina=use_jina, jina_api_key=jina_api_key)
for url, content in contents.items():
url_cache[url] = content
print(f"Fetched and cached URL content for URL: {url}")
fetched_contents.append((url, content))
except Exception as e:
for url in urls_to_fetch_filtered:
content = f"Error fetching URL: {e}"
url_cache[url] = content
fetched_contents.append((url, content))
print(f"Error fetching URL '{url}': {e}")
# Update fetched URLs
for url, _ in fetched_contents:
seq['executed_url_fetches'].add(url)
# Construct full page content string
fetched_pages = dict(fetched_contents)
full_page_str = json.dumps(fetched_pages, ensure_ascii=False, indent=2)
# Append full page content to the prompt
append_text = f"\n{BEGIN_FULL_PAGE}\n{full_page_str}\n{END_FULL_PAGE}\n"
seq['prompt'] += append_text
seq['output'] += append_text
# Update history
seq['history'].append(append_text)
print(f"Fetched and cached {len(fetched_contents)} URLs.")
# Check if URL fetch limit is reached
if len(seq['executed_url_fetches']) >= MAX_URL_FETCH:
limit_message = f"\n{BEGIN_FULL_PAGE}\nThe maximum number of URL fetches has been reached. You are not allowed to fetch more URLs.\n{END_FULL_PAGE}\n"
seq['prompt'] += limit_message
seq['output'] += limit_message
seq['history'].append(limit_message)
print("Reached URL fetch limit. Sequence marked as finished.")
# Continue to the next iteration if there are pending operations
if sequences_with_pending_ops:
continue # Process operations first
# Handle sequences needing generation
if sequences_needing_generation:
turn += 1
print(f"Turn {turn}: {len(sequences_needing_generation)} sequences need generation. Generating with LLM...")
outputs = run_generation(sequences_needing_generation, max_tokens)
print("Generation complete. Processing outputs...")
# Process each generated output
for seq, out in zip(sequences_needing_generation, outputs):
text = out.outputs[0].text
seq['history'].append(text)
# Append generated text to prompt and output
seq['prompt'] += text
seq['output'] += text
# Check if the generated content contains search queries or URL fetch requests
search_query = extract_between(text, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
url_fetch = extract_between(text, BEGIN_URL, END_URL)
if search_query:
# Check if search limit is not exceeded
if seq['search_count'] < MAX_SEARCH_LIMIT:
# Check if this search query has not been executed
if search_query not in seq['executed_search_queries']:
# Add search operation to pending queue
seq['pending_operations'].append({'type': 'search', 'content': search_query})
seq['executed_search_queries'].add(search_query)
print(f"Added pending search operation for query: {search_query}")
else:
print(f"Search query already executed: {search_query}")
else:
# Add limit message if search limit is exceeded
limit_message = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
seq['prompt'] += limit_message
seq['output'] += limit_message
seq['history'].append(limit_message)
print(f"Search limit exceeded for query: {search_query}")
if url_fetch:
# Check if URL fetch limit is not exceeded
if len(seq['executed_url_fetches']) < MAX_URL_FETCH:
# Split and check if URLs have already been fetched
urls = [u.strip() for u in url_fetch.split(",")]
urls_to_fetch = [u for u in urls if u not in seq['executed_url_fetches']]
if urls_to_fetch:
# Add URL fetch operation to pending queue
seq['pending_operations'].append({'type': 'fetch_url', 'content': ', '.join(urls_to_fetch)})
print(f"Added pending URL fetch operation for URLs: {urls_to_fetch}")
else:
print(f"All requested URLs have been fetched or exceeded the limit: {urls}")
else:
# Add limit message if URL fetch limit is exceeded
limit_message = f"\n{BEGIN_FULL_PAGE}\nThe maximum number of URL fetches has been reached. You are not allowed to fetch more URLs.\n{END_FULL_PAGE}\n"
seq['prompt'] += limit_message
seq['output'] += limit_message
seq['history'].append(limit_message)
print("URL fetch limit exceeded.")
# If no new operations are added, mark sequence as finished
if not search_query and not url_fetch:
seq['finished'] = True
print("Sequence marked as finished.")
# Check if all sequences are finished
unfinished = [seq for seq in active_sequences if not seq['finished']]
if not unfinished:
break
else:
if turn >= MAX_TURN:
print(f"Exceeded maximum number of turns ({MAX_TURN}). Stopping.")
break
# Optionally, implement a delay or other logic to prevent infinite loops
pass
total_time = time.time() - start_time
# Collect all outputs
output_list = [seq['output'] for seq in active_sequences]
# ---------------------- Evaluation ----------------------
print("Evaluating generated answers...")
run_evaluation(
filtered_data=filtered_data,
input_list=input_list,
output_list=output_list,
dataset_name=dataset_name,
output_dir=output_dir,
total_time=total_time,
split=split,
)
# ---------------------- Update Search and URL Cache ----------------------
print('Updating Search and URL Cache...')
    # Re-load caches from disk (another run may have updated them), merge, then save
if os.path.exists(search_cache_path):
with open(search_cache_path, 'r', encoding='utf-8') as f:
search_cache_new = json.load(f)
else:
search_cache_new = {}
if os.path.exists(url_cache_path):
with open(url_cache_path, 'r', encoding='utf-8') as f:
url_cache_new = json.load(f)
else:
url_cache_new = {}
search_cache.update(search_cache_new)
url_cache.update(url_cache_new)
save_caches()
print("Process completed.")
if __name__ == "__main__":
main()
# run_search_o1.py
import os
import json
import time
import re
from tqdm import tqdm
import numpy as np
import torch
import string
from typing import Optional, Tuple, List, Dict
import argparse
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from bing_search import (
bing_web_search,
extract_relevant_info,
fetch_page_content,
extract_snippet_with_context
)
from evaluate import (
run_evaluation,
extract_answer
)
from prompts import (
get_gpqa_search_o1_instruction,
get_math_search_o1_instruction,
get_code_search_o1_instruction,
get_singleqa_search_o1_instruction,
get_multiqa_search_o1_instruction,
get_webpage_to_reasonchain_instruction,
get_task_instruction_openqa,
get_task_instruction_math,
get_task_instruction_multi_choice,
get_task_instruction_code,
)
# Define special tokens
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
END_SEARCH_QUERY = "<|end_search_query|>"
BEGIN_SEARCH_RESULT = "<|begin_search_result|>"
END_SEARCH_RESULT = "<|end_search_result|>"
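# Sentinel tags for the Search-o1 loop: the model emits a query between the search-query
# tags, and the distilled search result is inserted back between the search-result tags.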
def parse_args():
parser = argparse.ArgumentParser(description="Run Search O1 for various datasets and models.")
# Dataset and split configuration
parser.add_argument(
'--dataset_name',
type=str,
required=True,
choices=['gpqa', 'math500', 'aime', 'amc', 'livecode', 'nq', 'triviaqa', 'hotpotqa', '2wiki', 'musique', 'bamboogle'],
help="Name of the dataset to use."
)
parser.add_argument(
'--split',
type=str,
required=True,
choices=['test', 'diamond', 'main', 'extended'],
help="Dataset split to use."
)
parser.add_argument(
'--subset_num',
type=int,
default=-1,
help="Number of examples to process. Defaults to all if not specified."
)
# Search and document retrieval configuration
parser.add_argument(
'--max_search_limit',
type=int,
default=10,
help="Maximum number of searches per question."
)
parser.add_argument(
'--max_turn',
type=int,
default=15,
help="Maximum number of turns."
)
parser.add_argument(
'--top_k',
type=int,
default=10,
help="Maximum number of search documents to return."
)
parser.add_argument(
'--max_doc_len',
type=int,
default=3000,
help="Maximum length of each searched document."
)
    parser.add_argument(
        '--use_jina',
        type=lambda x: str(x).lower() in ('true', '1', 'yes'),  # plain type=bool would treat any non-empty string, including "False", as True
        default=True,
        help="Whether to use the Jina API for document fetching."
    )
    parser.add_argument(
        '--jina_api_key',
        type=str,
        default='None',
        help="Your Jina API key for fetching URL content."
    )
# Model configuration
parser.add_argument(
'--model_path',
type=str,
required=True,
help="Path to the pre-trained model."
)
# Sampling parameters
parser.add_argument(
'--temperature',
type=float,
default=0.7,
help="Sampling temperature."
)
parser.add_argument(
'--top_p',
type=float,
default=0.8,
help="Top-p sampling parameter."
)
parser.add_argument(
'--top_k_sampling',
type=int,
default=20,
help="Top-k sampling parameter."
)
parser.add_argument(
'--repetition_penalty',
type=float,
default=None,
help="Repetition penalty. If not set, defaults based on the model."
)
parser.add_argument(
'--max_tokens',
type=int,
default=32768,
help="Maximum number of tokens to generate. If not set, defaults based on the model and dataset."
)
# Bing API Configuration
parser.add_argument(
'--bing_subscription_key',
type=str,
required=True,
help="Bing Search API subscription key."
)
parser.add_argument(
'--bing_endpoint',
type=str,
default="https://api.bing.microsoft.com/v7.0/search",
help="Bing Search API endpoint."
)
return parser.parse_args()
def main():
args = parse_args()
# Extract arguments
dataset_name = args.dataset_name
split = args.split
subset_num = args.subset_num
MAX_SEARCH_LIMIT = args.max_search_limit
MAX_TURN = args.max_turn
top_k = args.top_k
max_doc_len = args.max_doc_len
model_path = args.model_path
temperature = args.temperature
top_p = args.top_p
top_k_sampling = args.top_k_sampling
repetition_penalty = args.repetition_penalty
max_tokens = args.max_tokens
bing_subscription_key = args.bing_subscription_key
bing_endpoint = args.bing_endpoint
use_jina = args.use_jina
jina_api_key = args.jina_api_key
# Adjust parameters based on dataset
if dataset_name in ['nq', 'triviaqa', 'hotpotqa', 'musique', 'bamboogle', '2wiki', 'medmcqa', 'pubhealth']:
MAX_SEARCH_LIMIT = 5
if dataset_name in ['hotpotqa', 'musique', 'bamboogle', '2wiki']:
MAX_SEARCH_LIMIT = 10
MAX_TURN = 15
top_k = 10
max_doc_len = 3000
if args.jina_api_key == 'None':
jina_api_key = None
# Set default repetition_penalty if not provided
if repetition_penalty is None:
repetition_penalty = 1.05 if 'qwq' in model_path.lower() else 1.0
# Data paths based on dataset
if dataset_name == 'livecode':
data_path = f'./data/LiveCodeBench/{split}.json'
elif dataset_name in ['math500', 'gpqa', 'aime', 'amc']:
data_path = f'./data/{dataset_name.upper()}/{split}.json'
else:
data_path = f'./data/QA_Datasets/{dataset_name}.json'
print('-----------------------')
print(f'Using {dataset_name} {split} set.')
print('-----------------------')
# ---------------------- Caching Mechanism ----------------------
# Define cache directories and file paths
cache_dir = './cache'
search_cache_path = os.path.join(cache_dir, 'search_cache.json')
url_cache_path = os.path.join(cache_dir, 'url_cache.json')
# Ensure cache directory exists
os.makedirs(cache_dir, exist_ok=True)
# Load existing caches or initialize empty dictionaries
if os.path.exists(search_cache_path):
with open(search_cache_path, 'r', encoding='utf-8') as f:
search_cache = json.load(f)
else:
search_cache = {}
if os.path.exists(url_cache_path):
with open(url_cache_path, 'r', encoding='utf-8') as f:
url_cache = json.load(f)
else:
url_cache = {}
# Function to save caches
def save_caches():
with open(search_cache_path, 'w', encoding='utf-8') as f:
json.dump(search_cache, f, ensure_ascii=False, indent=2)
with open(url_cache_path, 'w', encoding='utf-8') as f:
json.dump(url_cache, f, ensure_ascii=False, indent=2)
# ---------------------- Model Loading ----------------------
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
# Define output directory based on model and dataset
if 'qwq' in model_path.lower():
if dataset_name in ['math500', 'gpqa', 'aime', 'amc', 'livecode']:
output_dir = f'./outputs/{dataset_name}.qwq.search_o1'
if dataset_name == 'gpqa' and (MAX_SEARCH_LIMIT != 5 or top_k != 10):
output_dir = f'./outputs/runs.analysis/{dataset_name}.qwq.search_o1.{MAX_SEARCH_LIMIT}.{top_k}'
else:
output_dir = f'./outputs/runs.qa/{dataset_name}.qwq.search_o1'
else:
model_short_name = model_path.split('/')[-1].lower().replace('-instruct', '')
output_dir = f'./outputs/runs.baselines/{dataset_name}.{model_short_name}.search_o1'
os.makedirs(output_dir, exist_ok=True)
# Initialize the LLM
llm = LLM(
model=model_path,
tensor_parallel_size=torch.cuda.device_count(),
gpu_memory_utilization=0.95,
)
# ---------------------- Data Loading ----------------------
with open(data_path, 'r', encoding='utf-8') as json_file:
filtered_data = json.load(json_file)
# ---------------------- Batch Generation Function ----------------------
def generate_webpage_to_reasonchain_batch(
original_questions: List[str],
prev_reasonings: List[str],
search_queries: List[str],
documents: List[str],
dataset_name: str,
batch_output_records: List[Dict], # New parameter to collect outputs
max_tokens: int = 32768,
coherent: bool = False,
) -> List[str]:
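        """
        Distill each fetched webpage into the information needed to continue the current
        reasoning chain: one prompt per (previous reasoning, search query, documents)
        triple, with answers extracted via extract_answer(mode='infogen').
        """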
user_prompts = [
get_webpage_to_reasonchain_instruction(r, sq, doc)
for r, sq, doc in zip(prev_reasonings, search_queries, documents)
]
prompts = [{"role": "user", "content": up} for up in user_prompts]
prompts = [tokenizer.apply_chat_template([p], tokenize=False, add_generation_prompt=True) for p in prompts]
output = llm.generate(
prompts,
sampling_params=SamplingParams(
max_tokens=max_tokens,
temperature=0.7,
top_p=0.8,
top_k=20,
repetition_penalty=1.05,
)
)
raw_outputs = [out.outputs[0].text for out in output]
extracted_infos = [extract_answer(raw, mode='infogen') for raw in raw_outputs]
for i, (p, r, e) in enumerate(zip(prompts, raw_outputs, extracted_infos)):
batch_output_records.append({
'prompt': p,
'raw_output': r,
'extracted_info': e
})
return extracted_infos
# ---------------------- Preparation of Input Prompts ----------------------
input_list = []
for item in filtered_data:
question = item['Question']
if dataset_name in ['nq', 'triviaqa', 'hotpotqa', 'musique', 'bamboogle', '2wiki']:
if dataset_name in ['nq', 'triviaqa']:
instruction = get_singleqa_search_o1_instruction(MAX_SEARCH_LIMIT)
elif dataset_name in ['hotpotqa', 'musique', 'bamboogle', '2wiki']:
instruction = get_multiqa_search_o1_instruction(MAX_SEARCH_LIMIT)
if 'qwq' in model_path.lower():
user_prompt = get_task_instruction_openqa(question, model_name='qwq')
else:
user_prompt = get_task_instruction_openqa(question)
elif dataset_name in ['math500', 'aime', 'amc']:
instruction = get_math_search_o1_instruction(MAX_SEARCH_LIMIT)
if 'qwq' in model_path.lower():
user_prompt = get_task_instruction_math(question, model_name='qwq')
else:
user_prompt = get_task_instruction_math(question)
elif dataset_name == 'gpqa':
instruction = get_gpqa_search_o1_instruction(MAX_SEARCH_LIMIT)
if 'qwq' in model_path.lower():
user_prompt = get_task_instruction_multi_choice(question, model_name='qwq')
elif 'llama' in model_path.lower():
user_prompt = get_task_instruction_multi_choice(question, model_name='llama')
else:
user_prompt = get_task_instruction_multi_choice(question)
elif dataset_name == 'livecode':
instruction = get_code_search_o1_instruction(MAX_SEARCH_LIMIT)
question_title = item.get('question_title', '')
if 'qwq' in model_path.lower():
user_prompt = get_task_instruction_code(question, question_title=question_title, model_name='qwq')
else:
user_prompt = get_task_instruction_code(question)
        else:
            instruction = ""  # Avoid a NameError below if the dataset name is not matched
            user_prompt = ""  # Default to empty if dataset not matched
prompt = [{"role": "user", "content": instruction + user_prompt}]
prompt = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
input_list.append(prompt)
if subset_num != -1:
input_list = input_list[:subset_num]
filtered_data = filtered_data[:subset_num]
# Initialize active sequences
active_sequences = [{
'item': item,
'prompt': prompt,
'output': '',
'finished': False,
'history': [],
'search_count': 0,
'executed_search_queries': set(),
} for item, prompt in zip(filtered_data, input_list)]
# ---------------------- Set Max Tokens ----------------------
if 'qwq' in model_path.lower():
if dataset_name in ['aime', 'amc', 'livecode']:
max_tokens = 32768
else:
max_tokens = 20480
else:
max_tokens = 8192
# ---------------------- Generation Function ----------------------
def run_generation(sequences: List[Dict], max_tokens: int) -> List:
prompts = [s['prompt'] for s in sequences]
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k_sampling,
repetition_penalty=repetition_penalty,
stop=[END_SEARCH_QUERY, tokenizer.eos_token],
include_stop_str_in_output=True,
)
output_list = llm.generate(prompts, sampling_params=sampling_params)
return output_list
# Function to extract text between two tags
def extract_between(text: str, start_tag: str, end_tag: str) -> Optional[str]:
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
matches = re.findall(pattern, text, flags=re.DOTALL)
if matches:
return matches[-1].strip()
return None
def replace_recent_steps(origin_str, replace_str):
"""
Replaces specific steps in the original reasoning steps with new steps.
If a replacement step contains "DELETE THIS STEP", that step is removed.
Parameters:
- origin_str (str): The original reasoning steps.
- replace_str (str): The steps to replace or delete.
Returns:
- str: The updated reasoning steps after applying replacements.
"""
def parse_steps(text):
"""
Parses the reasoning steps from a given text.
Parameters:
- text (str): The text containing reasoning steps.
Returns:
- dict: A dictionary mapping step numbers to their content.
"""
step_pattern = re.compile(r"Step\s+(\d+):\s*")
steps = {}
current_step_num = None
current_content = []
for line in text.splitlines():
step_match = step_pattern.match(line)
if step_match:
# If there's an ongoing step, save its content
if current_step_num is not None:
steps[current_step_num] = "\n".join(current_content).strip()
current_step_num = int(step_match.group(1))
content = line[step_match.end():].strip()
current_content = [content] if content else []
else:
if current_step_num is not None:
current_content.append(line)
# Save the last step if any
if current_step_num is not None:
steps[current_step_num] = "\n".join(current_content).strip()
return steps
# Parse the original and replacement steps
origin_steps = parse_steps(origin_str)
replace_steps = parse_steps(replace_str)
# Apply replacements
for step_num, content in replace_steps.items():
if "DELETE THIS STEP" in content:
# Remove the step if it exists
if step_num in origin_steps:
del origin_steps[step_num]
else:
# Replace or add the step
origin_steps[step_num] = content
# Sort the steps by step number
sorted_steps = sorted(origin_steps.items())
# Reconstruct the reasoning steps as a single string
new_reasoning_steps = "\n\n".join([f"{content}" for num, content in sorted_steps])
return new_reasoning_steps
# ---------------------- Initialize Collection Structure ----------------------
# Initialize a list to collect batch outputs
batch_output_records = []
start_time = time.time()
turn = 0
# Main loop until all sequences are finished or maximum turns reached
while True:
# Identify sequences that need generation
sequences_needing_generation = [seq for seq in active_sequences if not seq['finished']]
if sequences_needing_generation:
turn += 1
print(f'\n-------------- Turn {turn} --------------')
print(f"We have {len(sequences_needing_generation)} sequences needing generation...")
outputs = run_generation(sequences_needing_generation, max_tokens)
print("Generation completed, processing outputs...")
# Initialize batch variables
batch_relevant_info = []
batch_original_questions = []
batch_prev_reasonings = []
batch_search_queries = []
batch_documents = []
batch_sequences = []
# Collect URLs to fetch across all sequences
all_urls_to_fetch = set()
url_snippets = {}
url_sequence_map = {} # Map URL to list of sequences needing it
# Process each sequence and collect URLs
for seq, out in zip(sequences_needing_generation, outputs):
text = out.outputs[0].text
seq['history'].append(text)
# Append generated text to prompt and output
seq['prompt'] += text
seq['output'] += text
# Extract search query
search_query = extract_between(text, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
# If a search query is present and needs to be executed
if search_query and seq['output'].rstrip().endswith(END_SEARCH_QUERY):
if seq['search_count'] < MAX_SEARCH_LIMIT and search_query not in seq['executed_search_queries']:
# Execute search, use cache if available
if search_query in search_cache:
results = search_cache[search_query]
print(f"Using cached search results for query: \"{search_query}\"")
else:
try:
results = bing_web_search(search_query, bing_subscription_key, bing_endpoint, market='en-US', language='en')
search_cache[search_query] = results
print(f"Executed and cached search for query: \"{search_query}\"")
except Exception as e:
print(f"Error during search query '{search_query}': {e}")
search_cache[search_query] = {}
results = {}
# Extract relevant information from Bing search results
relevant_info = extract_relevant_info(results)[:top_k]
seq['relevant_info'] = relevant_info
# Extract URLs and snippets
urls_to_fetch = [it['url'] for it in relevant_info]
snippets = {info['url']: info['snippet'] for info in relevant_info if 'snippet' in info}
# Filter URLs that are not cached
urls_to_fetch_filtered = [u for u in urls_to_fetch if u not in url_cache]
cached_urls = [u for u in urls_to_fetch if u in url_cache]
# Store info for all_urls_to_fetch and url_snippets
for url in urls_to_fetch_filtered:
all_urls_to_fetch.add(url)
url_snippets[url] = snippets.get(url, "")
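                        # Build a truncated view of the reasoning so far: keep the first step,
                        # the last four steps, and any step containing search tags; elide the rest.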
all_reasoning_steps = seq['output']
all_reasoning_steps = all_reasoning_steps.replace('\n\n', '\n').split("\n")
truncated_prev_reasoning = ""
for i, step in enumerate(all_reasoning_steps):
truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
prev_steps = truncated_prev_reasoning.split('\n\n')
if len(prev_steps) <= 5:
truncated_prev_reasoning = '\n\n'.join(prev_steps)
else:
truncated_prev_reasoning = ''
for i, step in enumerate(prev_steps):
if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
truncated_prev_reasoning += step + '\n\n'
else:
if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
truncated_prev_reasoning += '...\n\n'
truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
# Collect parameters for batch processing
batch_relevant_info.append(relevant_info)
batch_original_questions.append(seq['item']['Question'])
batch_prev_reasonings.append(truncated_prev_reasoning)
batch_search_queries.append(search_query)
batch_sequences.append(seq)
# Update search count and executed queries
seq['search_count'] += 1
seq['executed_search_queries'].add(search_query)
elif seq['search_count'] >= MAX_SEARCH_LIMIT:
limit_message = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
seq['prompt'] += limit_message
seq['output'] += limit_message
seq['history'].append(limit_message)
print(f"Search limit reached for query: \"{search_query}\"")
elif search_query in seq['executed_search_queries']:
limit_message = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
seq['prompt'] += limit_message
seq['output'] += limit_message
seq['history'].append(limit_message)
print(f"Repeated search for query: \"{search_query}\"")
else:
# If no search query needs to be executed, mark the sequence as finished
seq['finished'] = True
print("Sequence marked as complete.")
# Batch fetch all URLs at once to optimize speed
if all_urls_to_fetch:
print(f"Fetching {len(all_urls_to_fetch)} URLs...")
try:
fetched_contents = fetch_page_content(
list(all_urls_to_fetch),
use_jina=use_jina,
jina_api_key=jina_api_key,
# snippets=url_snippets # Do not pass snippets when updating url_cache directly
)
print(f"Fetched {len(fetched_contents)} URLs successfully.")
except Exception as e:
print(f"Error during batch URL fetching: {e}")
fetched_contents = {url: f"Error fetching URL: {e}" for url in all_urls_to_fetch}
# Update cache with fetched contents
for url, content in fetched_contents.items():
url_cache[url] = content
# After fetching, prepare formatted documents for batch processing
for relevant_info in batch_relevant_info:
formatted_documents = ""
for i, doc_info in enumerate(relevant_info):
url = doc_info['url']
raw_context = url_cache.get(url, "")
doc_info['snippet'] = doc_info['snippet'].replace('<b>','').replace('</b>','')
success, filtered_context = extract_snippet_with_context(raw_context, doc_info['snippet'], context_chars=max_doc_len)
if success:
context = filtered_context
else:
context = raw_context[:max_doc_len*2]
doc_info['context'] = context
formatted_documents += f"**Web Page {i + 1}:**\n"
formatted_documents += json.dumps(doc_info, ensure_ascii=False, indent=2) + "\n"
batch_documents.append(formatted_documents)
# After fetching, prepare for batch processing if there are any
if batch_sequences:
print(f"Batch processing {len(batch_sequences)} sequences with generate_webpage_to_reasonchain_batch...")
webpage_analyses = generate_webpage_to_reasonchain_batch(
original_questions=batch_original_questions,
prev_reasonings=batch_prev_reasonings,
search_queries=batch_search_queries,
documents=batch_documents,
dataset_name=dataset_name,
batch_output_records=batch_output_records, # Pass the collection list
max_tokens=max_tokens,
)
print("Batch generation completed, assigning outputs to sequences...")
for seq, analysis in zip(batch_sequences, webpage_analyses):
if isinstance(analysis, str):
append_text = f"\n\n{BEGIN_SEARCH_RESULT}{analysis}{END_SEARCH_RESULT}\n\n"
seq['prompt'] += append_text
seq['output'] += append_text
seq['history'].append(append_text)
else:
append_text = replace_recent_steps(seq['output'], analysis)
seq['prompt'] += append_text
seq['output'] += append_text
seq['history'].append(append_text)
# Check if all sequences are finished
unfinished = [seq for seq in active_sequences if not seq['finished']]
if not unfinished:
break
else:
if turn >= MAX_TURN:
print(f"Maximum number of turns ({MAX_TURN}) reached, stopping.")
break
total_time = time.time() - start_time
# ---------------------- Save Batch Output Records to JSON File ----------------------
# Define output JSON file path
t = time.localtime()
batch_output_file = os.path.join(output_dir, f'{split}.{t.tm_mon}.{t.tm_mday},{t.tm_hour}:{t.tm_min}.info_extract.json')
# Save batch_output_records to JSON file
with open(batch_output_file, 'w', encoding='utf-8') as f:
json.dump(batch_output_records, f, ensure_ascii=False, indent=2)
print(f"Batch outputs saved to {batch_output_file}")
# Prepare output list for evaluation
output_list = [seq['output'] for seq in active_sequences]
# Run evaluation
run_evaluation(filtered_data, input_list, output_list, dataset_name, output_dir, total_time, split)
# ---------------------- Update Search and URL Cache ----------------------
print('Updating Search and URL Cache...')
    # Re-load caches from disk (another run may have updated them), merge, then save
if os.path.exists(search_cache_path):
with open(search_cache_path, 'r', encoding='utf-8') as f:
search_cache_new = json.load(f)
else:
search_cache_new = {}
if os.path.exists(url_cache_path):
with open(url_cache_path, 'r', encoding='utf-8') as f:
url_cache_new = json.load(f)
else:
url_cache_new = {}
search_cache.update(search_cache_new)
url_cache.update(url_cache_new)
save_caches()
print("Process completed.")
if __name__ == "__main__":
main()
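# (Separate file, name not shown in this commit view: math answer normalization utilities.
#  These helpers normalize LaTeX answer strings for the is_equiv comparison defined below.)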
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except:
return string
def _remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def _fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def _strip_string(string):
# linebreaks
string = string.replace("\n", "")
#print(string)
# remove inverse spaces
string = string.replace("\\!", "")
#print(string)
# replace \\ with \
string = string.replace("\\\\", "\\")
#print(string)
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
#print(string)
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
#print(string)
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = _remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
return string
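# Compare two answer strings for equivalence after normalizing both with _strip_string;
# falls back to a raw string comparison if normalization fails.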
def is_equiv(str1, str2, verbose=False):
if str1 is None and str2 is None:
print("WARNING: Both None")
return True
if str1 is None or str2 is None:
return False
try:
ss1 = _strip_string(str1)
ss2 = _strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except:
return str1 == str2
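Example command for running run_search_o1.py (substitute your own Bing Search API subscription key):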
python scripts/run_search_o1.py \
--dataset_name aime \
--split test \
--max_search_limit 5 \
--max_turn 10 \
--top_k 10 \
--max_doc_len 3000 \
--use_jina False \
--model_path "Qwen2.5-72B-Instruct" \
--bing_subscription_key "1d3e8fb2-6a79-471d-89d5-dda796ca15fc"