"""Evaluate vLLM generations against MATH-500 ground-truth answers.

Reads a JSONL ground-truth dataset and a vLLM result JSON, extracts each
model output's final ``\\boxed{...}`` answer, compares it against the
reference answer (normalized string match first, then numeric comparison),
and writes a summary JSON plus the unique_ids of incorrect samples.
"""

import ast
import json
import operator
import re
from typing import Any, Dict

# File that records the unique_id of every incorrectly answered sample.
WRONG_IDS_FILE = "wrong_ids.txt"


def clean_think_tags(text: str) -> str:
    """Strip reasoning/markup tags the model may emit around its answer.

    Removes non-greedy ``\\final ... \\final`` and ``\\think ... \\think``
    spans, plus ``\\instant`` up to the end of its line.
    """
    # Non-greedy removal of \final ... \final spans.
    text = re.sub(r'\\final.*?\\final', '', text, flags=re.DOTALL)
    # Remove \think ... \think spans if present.
    text = re.sub(r'\\think.*?\\think', '', text, flags=re.DOTALL)
    # Remove \instant up to end-of-line (or end of text).
    text = re.sub(r'\\instant.*?(?=\n|$)', '', text, flags=re.DOTALL)
    return text.strip()


def extract_boxed_answer(text: str) -> str:
    """Return the content of the LAST ``\\boxed{...}`` in *text*.

    Supports arbitrarily nested braces via explicit depth counting.
    If no ``\\boxed`` is present, falls back to the last non-empty line;
    if braces are unbalanced, falls back to a one-level-nesting regex.
    """
    matches = list(re.finditer(r'\\boxed\{', text))
    if not matches:
        # No \boxed at all: use the last non-empty line as a best-effort answer.
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        return lines[-1] if lines else text

    start = matches[-1].end() - 1  # index of the opening '{'
    depth = 0
    for i in range(start, len(text)):
        ch = text[i]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:  # matching close brace found
                return text[start + 1:i].strip()

    # Unbalanced braces: fall back to a regex that tolerates one nesting level.
    simple_match = re.search(r'\\boxed\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', text)
    if simple_match:
        return simple_match.group(1).strip()
    return text.strip()


def normalize_math_answer(answer: str) -> str:
    """Normalize a math answer so LaTeX/space/bracket variants compare equal.

    Effective for fractions and simple expressions; symbolic answers whose
    braces are semantically significant may over-normalize.
    """
    # Remove all whitespace.
    normalized = re.sub(r'\s+', '', answer)
    # Drop \left / \right sizing commands.
    normalized = re.sub(r'\\left|\\right', '', normalized)
    # Dropping grouping braces makes \frac{13}{15} compare equal to \frac1315.
    normalized = normalized.replace('{', '').replace('}', '')
    # Unify bracket styles.
    normalized = normalized.replace('[', '(').replace(']', ')')
    normalized = re.sub(r'\\displaystyle', '', normalized)
    # Drop a trailing period.
    return normalized.rstrip('.').strip()


# Operators permitted by _safe_eval_arith. NOTE: eval() was previously used
# here; even with {"__builtins__": None} it is unsafe on untrusted model
# output, so we evaluate a restricted arithmetic AST instead.
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _safe_eval_arith(expr: str) -> float:
    """Evaluate a basic arithmetic expression without using eval().

    Only numeric literals, the binary operators + - * / // % **, and unary
    +/- are allowed. Raises ValueError for any other construct; SyntaxError
    propagates from parsing.
    """
    def _eval(node: ast.AST) -> float:
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            return _BIN_OPS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](_eval(node.operand))
        raise ValueError("unsupported expression")

    return _eval(ast.parse(expr, mode='eval'))


def compare_answers(predicted: str, ground_truth: str) -> bool:
    """Return True when *predicted* matches *ground_truth*.

    Compares normalized strings first; on mismatch, attempts to evaluate
    both as arithmetic expressions and compares numerically.
    """
    pred_norm = normalize_math_answer(predicted)
    truth_norm = normalize_math_answer(ground_truth)
    if pred_norm == truth_norm:
        return True
    try:
        # '^' is commonly written for exponentiation in answers.
        pred_val = _safe_eval_arith(pred_norm.replace('^', '**'))
        truth_val = _safe_eval_arith(truth_norm.replace('^', '**'))
        return abs(pred_val - truth_val) < 1e-9
    except Exception:
        # Non-numeric answers (e.g. LaTeX expressions) legitimately fail to
        # parse/evaluate; treat them as "not numerically comparable".
        return False


def load_ground_truth_jsonl(filepath: str) -> Dict[str, Dict[str, Any]]:
    """Index a JSONL dataset by each item's 'unique_id'.

    Items missing a unique_id are skipped with a warning.
    """
    data_dict: Dict[str, Dict[str, Any]] = {}
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            item = json.loads(line)
            unique_id = item.get('unique_id')
            if unique_id:
                data_dict[unique_id] = item
            else:
                print(f"警告: 数据集中发现样本缺少 unique_id: {item.get('problem', '')[:50]}...")
    return data_dict


def main() -> None:
    """Run the evaluation end-to-end and write result files."""
    DATASET_JSONL = "/data1/sunzhq/llm-benchmark/MATH-500/test.jsonl"
    VLLM_RESULT_JSON = "/data1/sunzhq/llm-benchmark/results-1/performance_results/qwen3_8b_math500_perf.json"
    OUTPUT_EVAL_FILE = "evaluation_results.json"

    print(f"正在加载数据集: {DATASET_JSONL}")
    ground_truth_dict = load_ground_truth_jsonl(DATASET_JSONL)
    print(f"数据集索引建立完成,共包含 {len(ground_truth_dict)} 个唯一样本。")

    print(f"正在加载 vLLM 结果: {VLLM_RESULT_JSON}")
    with open(VLLM_RESULT_JSON, 'r', encoding='utf-8') as f:
        vllm_data = json.load(f)

    generated_texts = vllm_data.get('generated_texts', [])
    dataset_metadata_list = vllm_data.get('dataset_metadata', [])
    # Generations and metadata must align 1:1 or pairing below is meaningless.
    if len(generated_texts) != len(dataset_metadata_list):
        print(f"错误:生成结果数量 ({len(generated_texts)}) 与元数据数量 ({len(dataset_metadata_list)}) 不匹配!")
        return

    correct_count = 0
    eval_results = []
    wrong_ids = []

    for idx, (raw_output, meta) in enumerate(zip(generated_texts, dataset_metadata_list)):
        unique_id = meta.get('unique_id')
        if not unique_id:
            print(f"警告: 第 {idx} 个样本在 metadata 中没有找到 unique_id。跳过。")
            continue

        ground_truth_item = ground_truth_dict.get(unique_id)
        if not ground_truth_item:
            print(f"警告: 在原始数据集中未找到 ID 为 {unique_id} 的样本。跳过。")
            continue

        ground_truth = ground_truth_item.get('answer', '')
        # Strip model markup, then extract the final boxed answer.
        cleaned = clean_think_tags(raw_output)
        predicted = extract_boxed_answer(cleaned)
        is_correct = compare_answers(predicted, ground_truth)

        if is_correct:
            correct_count += 1
        else:
            wrong_ids.append(unique_id)

        eval_results.append({
            'index': idx,
            'unique_id': unique_id,
            'problem': ground_truth_item.get('problem', '')[:100] + '...',
            'predicted': predicted,
            'ground_truth': ground_truth,
            'correct': is_correct
        })

        if (idx + 1) % 10 == 0:
            print(f"已处理 {idx+1} 个样本...")

    total_valid_samples = len(eval_results)
    accuracy = correct_count / total_valid_samples if total_valid_samples > 0 else 0.0

    print("\n" + "=" * 60)
    print(f"评估完成!")
    print(f"参与评估的有效样本数: {total_valid_samples}")
    print(f"正确数量: {correct_count}")
    print(f"准确率 (Accuracy): {accuracy:.4f} ({accuracy*100:.2f}%)")

    # Persist detailed per-sample results.
    output_data = {
        'summary': {
            'total_processed': total_valid_samples,
            'correct': correct_count,
            'accuracy': accuracy
        },
        'details': eval_results
    }
    with open(OUTPUT_EVAL_FILE, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    print(f"详细评估结果已保存至: {OUTPUT_EVAL_FILE}")

    # Persist the IDs of wrong samples for later inspection.
    if wrong_ids:
        with open(WRONG_IDS_FILE, 'w', encoding='utf-8') as f:
            for uid in wrong_ids:
                f.write(uid + '\n')
        print(f"错误 unique_id 列表已保存至: {WRONG_IDS_FILE} (共 {len(wrong_ids)} 个)")
    else:
        print("没有错误样本。")


if __name__ == "__main__":
    main()