run_detailed_benchmark.py 7.32 KB
Newer Older
sunzhq2's avatar
init  
sunzhq2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import json
import re
from typing import Dict, Any

WRONG_IDS_FILE = "wrong_ids.txt"   # output file collecting the unique_id of every incorrectly answered sample

def clean_think_tags(text: str) -> str:
    """Strip reasoning-trace markers from raw model output.

    Removes, non-greedily and across newlines:
      * ``\\final ... \\final`` delimited spans,
      * ``\\think ... \\think`` delimited spans,
      * a ``\\instant`` marker together with everything up to the next
        newline (or end of text).

    Returns the remaining text with surrounding whitespace stripped.
    """
    # Paired markers: drop the whole delimited span.
    for paired in (r'\\final.*?\\final', r'\\think.*?\\think'):
        text = re.sub(paired, '', text, flags=re.DOTALL)
    # Unpaired marker: drop from \instant to the end of that line.
    text = re.sub(r'\\instant.*?(?=\n|$)', '', text, flags=re.DOTALL)
    return text.strip()

def extract_boxed_answer(text: str) -> str:
    """Extract the content of the LAST ``\\boxed{...}`` in *text*.

    Nested braces are handled by depth counting. If no ``\\boxed{`` exists,
    the last non-blank line is returned as a best-effort answer (or *text*
    itself when everything is blank). If the braces never balance, a simple
    regex tolerating one level of nesting is used as a last resort.
    """
    box_starts = list(re.finditer(r'\\boxed\{', text))
    if not box_starts:
        # No \boxed at all: fall back to the final non-empty line.
        candidates = [ln.strip() for ln in text.split('\n') if ln.strip()]
        return candidates[-1] if candidates else text

    # Scan forward from the opening '{' of the last \boxed, tracking depth.
    open_pos = box_starts[-1].end() - 1
    depth = 0
    for pos in range(open_pos, len(text)):
        ch = text[pos]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                # Matching close brace found: return the enclosed content.
                return text[open_pos + 1:pos].strip()
    # Unbalanced braces: degrade to a regex allowing one nesting level.
    fallback = re.search(r'\\boxed\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', text)
    if fallback:
        return fallback.group(1).strip()
    return text.strip()

def normalize_math_answer(answer: str) -> str:
    """Canonicalize a math answer string for textual comparison.

    Removes all whitespace, ``\\left``/``\\right`` sizing commands, LaTeX
    grouping braces and ``\\displaystyle``; unifies square brackets to
    parentheses; drops a trailing period. Effective for fractions and
    simple expressions, but not guaranteed for arbitrary symbolic answers.
    """
    out = re.sub(r'\s+', '', answer)
    out = re.sub(r'\\left|\\right', '', out)
    # Dropping grouping braces lets e.g. \frac{13}{15} compare equal to
    # \frac1315; this changes LaTeX structure but is fine for comparison.
    for src, dst in (('{', ''), ('}', ''), ('[', '('), (']', ')')):
        out = out.replace(src, dst)
    out = re.sub(r'\\displaystyle', '', out)
    return out.rstrip('.').strip()

def compare_answers(predicted: str, ground_truth: str) -> bool:
    """Return True if *predicted* matches *ground_truth* after normalization.

    First compares the normalized strings textually. If that fails, attempts
    a numeric comparison by evaluating both sides as simple arithmetic
    expressions ('^' is rewritten to '**'). Any evaluation failure — syntax
    error, unknown name, non-numeric result — counts as a mismatch.
    """
    pred_norm = normalize_math_answer(predicted)
    truth_norm = normalize_math_answer(ground_truth)
    if pred_norm == truth_norm:
        return True
    # Numeric fallback so e.g. "1/2" and "0.5" compare equal.
    try:
        # NOTE(review): eval on model output is risky even with builtins
        # disabled; tolerable here only because inputs come from local
        # benchmark files, not untrusted users.
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed during evaluation.
        pred_val = eval(pred_norm.replace('^', '**'), {"__builtins__": None}, {})
        truth_val = eval(truth_norm.replace('^', '**'), {"__builtins__": None}, {})
        return abs(pred_val - truth_val) < 1e-9
    except Exception:
        return False

def load_ground_truth_jsonl(filepath: str) -> Dict[str, Dict[str, str]]:
    """Index a JSONL dataset file by each record's ``unique_id``.

    Blank lines are ignored. Records without a ``unique_id`` are skipped
    with a console warning. Returns a ``{unique_id: record}`` mapping.
    """
    index: Dict[str, Dict[str, str]] = {}
    with open(filepath, 'r', encoding='utf-8') as fh:
        for raw_line in fh:
            stripped = raw_line.strip()
            if not stripped:
                continue
            item = json.loads(stripped)
            uid = item.get('unique_id')
            if not uid:
                # Warn (in the dataset's original locale) and skip the record.
                print(f"警告: 数据集中发现样本缺少 unique_id: {item.get('problem', '')[:50]}...")
                continue
            index[uid] = item
    return index

def main():
    """Evaluate vLLM-generated MATH-500 answers against the ground truth.

    Pipeline: load the dataset keyed by unique_id, load the vLLM result
    JSON (parallel lists of generated texts and per-sample metadata),
    extract each predicted \\boxed answer, compare to the reference answer,
    then write a detailed JSON report plus a text file of wrong unique_ids.

    NOTE(review): input paths are hard-coded to a specific machine —
    consider promoting them to CLI arguments.
    """
    DATASET_JSONL = "/data1/sunzhq/llm-benchmark/MATH-500/test.jsonl"
    VLLM_RESULT_JSON = "/data1/sunzhq/llm-benchmark/results-1/performance_results/qwen3_8b_math500_perf.json"
    OUTPUT_EVAL_FILE = "evaluation_results.json"

    print(f"正在加载数据集: {DATASET_JSONL}")
    ground_truth_dict = load_ground_truth_jsonl(DATASET_JSONL)
    print(f"数据集索引建立完成,共包含 {len(ground_truth_dict)} 个唯一样本。")

    print(f"正在加载 vLLM 结果: {VLLM_RESULT_JSON}")
    with open(VLLM_RESULT_JSON, 'r', encoding='utf-8') as f:
        vllm_data = json.load(f)

    # generated_texts[i] corresponds to dataset_metadata[i]; both lists
    # must be the same length for zip() below to be meaningful.
    generated_texts = vllm_data.get('generated_texts', [])
    dataset_metadata_list = vllm_data.get('dataset_metadata', [])

    if len(generated_texts) != len(dataset_metadata_list):
        print(f"错误:生成结果数量 ({len(generated_texts)}) 与元数据数量 ({len(dataset_metadata_list)}) 不匹配!")
        return

    correct_count = 0
    eval_results = []
    wrong_ids = []

    for idx, (raw_output, meta) in enumerate(zip(generated_texts, dataset_metadata_list)):
        unique_id = meta.get('unique_id')
        if not unique_id:
            # Cannot be matched to ground truth without an ID — skip.
            print(f"警告: 第 {idx} 个样本在 metadata 中没有找到 unique_id。跳过。")
            continue

        ground_truth_item = ground_truth_dict.get(unique_id)
        if not ground_truth_item:
            # Result references a sample absent from the dataset — skip.
            print(f"警告: 在原始数据集中未找到 ID 为 {unique_id} 的样本。跳过。")
            continue

        # Strip reasoning markers first, then pull the \boxed answer.
        ground_truth = ground_truth_item.get('answer', '')
        cleaned = clean_think_tags(raw_output)
        predicted = extract_boxed_answer(cleaned)

        is_correct = compare_answers(predicted, ground_truth)
        if is_correct:
            correct_count += 1
        else:
            wrong_ids.append(unique_id)

        eval_results.append({
            'index': idx,
            'unique_id': unique_id,
            'problem': ground_truth_item.get('problem', '')[:100] + '...',
            'predicted': predicted,
            'ground_truth': ground_truth,
            'correct': is_correct
        })

        # Lightweight progress indicator every 10 samples.
        if (idx + 1) % 10 == 0:
            print(f"已处理 {idx+1} 个样本...")

    # Accuracy is computed over successfully matched samples only;
    # skipped samples do not count toward the denominator.
    total_valid_samples = len(eval_results)
    accuracy = correct_count / total_valid_samples if total_valid_samples > 0 else 0.0
    print("\n" + "=" * 60)
    print(f"评估完成!")
    print(f"参与评估的有效样本数: {total_valid_samples}")
    print(f"正确数量: {correct_count}")
    print(f"准确率 (Accuracy): {accuracy:.4f} ({accuracy*100:.2f}%)")

    # Persist the detailed per-sample results plus a summary header.
    output_data = {
        'summary': {
            'total_processed': total_valid_samples,
            'correct': correct_count,
            'accuracy': accuracy
        },
        'details': eval_results
    }
    with open(OUTPUT_EVAL_FILE, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    print(f"详细评估结果已保存至: {OUTPUT_EVAL_FILE}")

    # Persist the IDs of wrong samples, one per line, for later re-runs.
    if wrong_ids:
        with open(WRONG_IDS_FILE, 'w', encoding='utf-8') as f:
            for uid in wrong_ids:
                f.write(uid + '\n')
        print(f"错误 unique_id 列表已保存至: {WRONG_IDS_FILE} (共 {len(wrong_ids)} 个)")
    else:
        print("没有错误样本。")

# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()