import json import numpy as np import argparse def load_json_file(file_path): """读取 JSON 文件并返回内容""" try: with open(file_path, 'r', encoding='utf-8') as f: return json.load(f) except FileNotFoundError: print(f"错误:文件 {file_path} 不存在") exit(1) except json.JSONDecodeError: print(f"错误:文件 {file_path} 不是有效的 JSON 格式") exit(1) def calculate_mae(file1_path, file2_path): """计算两个 JSON 文件中 logprobs 的平均绝对误差""" data1 = load_json_file(file1_path) data2 = load_json_file(file2_path) if len(data1) != len(data2): print("错误:两个 JSON 文件的字典数量不同") exit(1) mae_per_token = np.zeros(10) num_entries = len(data1) overall_error = 0 for i in range(num_entries): prompt = data1[i]["input"] logprobs1 = data1[i]["logprobs_of_rank1_for_the_first_10_tokens"] logprobs2 = data2[i]["logprobs_of_rank1_for_the_first_10_tokens"] if len(logprobs1) != 10 or len(logprobs2) != 10: print(f"错误:第 {i+1} 个字典的 logprobs 数组长度不为 10") exit(1) current_mae = np.mean(np.abs(np.array(logprobs1) - np.array(logprobs2))) print(f"提示词:{prompt},平均绝对误差:{current_mae}") overall_error +=current_mae overall_mae = overall_error/num_entries print(f"\n总体平均绝对误差:{overall_mae:.6e}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="计算两个 JSON 文件中 logprobs 的平均绝对误差") parser.add_argument("--file1", help="第一个 JSON 文件的路径") parser.add_argument("--file2", help="第二个 JSON 文件的路径") args = parser.parse_args() # 调用计算函数 calculate_mae(args.file1, args.file2)