import json
import numpy as np
import argparse

def load_json_file(file_path):
    """读取 JSON 文件并返回内容"""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"错误：文件 {file_path} 不存在")
        exit(1)
    except json.JSONDecodeError:
        print(f"错误：文件 {file_path} 不是有效的 JSON 格式")
        exit(1)

def calculate_mae(file1_path, file2_path):
    """计算两个 JSON 文件中 logprobs 的平均绝对误差"""

    data1 = load_json_file(file1_path)
    data2 = load_json_file(file2_path)

    if len(data1) != len(data2):
        print("错误：两个 JSON 文件的字典数量不同")
        exit(1)

    mae_per_token = np.zeros(10)  
    num_entries = len(data1)
    overall_error = 0

    for i in range(num_entries):
        prompt = data1[i]["input"]
        logprobs1 = data1[i]["logprobs_of_rank1_for_the_first_10_tokens"]
        logprobs2 = data2[i]["logprobs_of_rank1_for_the_first_10_tokens"]

        if len(logprobs1) != 10 or len(logprobs2) != 10:
            print(f"错误：第 {i+1} 个字典的 logprobs 数组长度不为 10")
            exit(1)

        current_mae = np.mean(np.abs(np.array(logprobs1) - np.array(logprobs2)))
        
        print(f"提示词:{prompt},平均绝对误差:{current_mae}")
        overall_error +=current_mae

    overall_mae = overall_error/num_entries
    print(f"\n总体平均绝对误差：{overall_mae:.6e}")

if __name__ == "__main__":
    
    parser = argparse.ArgumentParser(description="计算两个 JSON 文件中 logprobs 的平均绝对误差")
    parser.add_argument("--file1", help="第一个 JSON 文件的路径")
    parser.add_argument("--file2", help="第二个 JSON 文件的路径")
    args = parser.parse_args()

    # 调用计算函数
    calculate_mae(args.file1, args.file2)