# -*- coding: utf-8 -*-
"""Compare the model weights stored in two PyTorch checkpoint files.

The script loads two checkpoints (.pt or .ckpt), extracts the state_dict
found under a configurable key, and compares the weights layer by layer.
Each shared layer is judged against a preset "mean absolute difference"
threshold to decide whether the pair of checkpoints passes.
"""
import math
import sys
from collections import OrderedDict

import torch

# ==============================================================================
# 1. Configuration: file paths, model-weights key and pass/fail threshold
# ==============================================================================
CKPT_PATH_1 = '/home/zwq/project/shangchaun/external/graphormer_pytorch/res/res_of_A800/checkpoint1.pt'
CKPT_PATH_2 = '/home/zwq/project/shangchaun/external/graphormer_pytorch/res/res_of_K100AI/checkpoint1.pt'

# Inspection of the checkpoints showed the weights live under the 'model' key.
MODEL_WEIGHTS_KEY = 'model'

# !! Core pass criterion !!
# Threshold on the mean absolute difference: the comparison passes only if
# every shared layer stays at or below this value.
MEAN_ABS_DIFF_THRESHOLD = 0.02
# ==============================================================================


def extract_state_dict(checkpoint, model_key):
    """Return the state_dict stored under ``model_key`` in a loaded checkpoint.

    Args:
        checkpoint: The object returned by ``torch.load``; must be a dict.
        model_key: Key under which the model weights are stored.

    Returns:
        The value stored at ``checkpoint[model_key]``.

    Raises:
        TypeError: If ``checkpoint`` is not a dict.
        KeyError: If ``model_key`` is missing; the message lists the keys
            that are actually present.
    """
    if not isinstance(checkpoint, dict):
        raise TypeError(f"Checkpoint 文件加载后不是一个字典,而是一个 {type(checkpoint)}。")

    if model_key in checkpoint:
        return checkpoint[model_key]

    keys_found = list(checkpoint.keys())
    raise KeyError(
        f"在 checkpoint 中找不到指定的键 '{model_key}'。\n"
        f"文件中实际存在的键是: {keys_found}"
    )


def normalize_keys(state_dict):
    """Return a copy of ``state_dict`` with a leading 'module.' prefix removed.

    Keys that do not start with 'module.' are kept unchanged. Insertion order
    is preserved.

    Args:
        state_dict: Mapping from layer name to value.

    Returns:
        An ``OrderedDict`` with the normalized keys and the original values.
    """
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict


def compare_checkpoints(ckpt_path1, ckpt_path2, model_key, threshold):
    """Load two checkpoint files and compare their weights layer by layer.

    Prints a layer-name summary, the mean absolute difference of every shared
    layer that is not bit-identical, and a final PASS/FAIL conclusion.

    A shared layer fails when its shapes differ, when its mean absolute
    difference exceeds ``threshold``, or when that difference is NaN (which
    happens if either tensor contains NaN). Layers present in only one file
    are reported as a warning but do not change the verdict.

    Args:
        ckpt_path1: Path to the first checkpoint file.
        ckpt_path2: Path to the second checkpoint file.
        model_key: Key under which each checkpoint stores its state_dict.
        threshold: Maximum allowed mean absolute difference per layer.

    Returns:
        True if no shared layer failed, False otherwise.

    Raises:
        TypeError, KeyError: Propagated from ``extract_state_dict``.
    """
    # NOTE(security): torch.load unpickles the file and, depending on the
    # PyTorch version's default for ``weights_only``, may execute arbitrary
    # code. Only run this script on checkpoints from a trusted source.
    print(f"[*] 正在加载 Checkpoint 1: {ckpt_path1}")
    ckpt1 = torch.load(ckpt_path1, map_location='cpu')
    print(f"[*] 正在加载 Checkpoint 2: {ckpt_path2}")
    ckpt2 = torch.load(ckpt_path2, map_location='cpu')

    print(f"\n[*] 正在从键 '{model_key}' 中提取并标准化 state_dict...")
    sd1 = normalize_keys(extract_state_dict(ckpt1, model_key))
    sd2 = normalize_keys(extract_state_dict(ckpt2, model_key))

    keys1, keys2 = set(sd1.keys()), set(sd2.keys())
    common_keys = sorted(keys1 & keys2)
    unique_to_1, unique_to_2 = sorted(keys1 - keys2), sorted(keys2 - keys1)

    print("\n" + "=" * 60)
    print(" 层名称比较摘要 (Layer Name Comparison Summary)")
    print("=" * 60)
    print(f"总层数 (文件1): {len(keys1)}")
    print(f"总层数 (文件2): {len(keys2)}")
    print(f"共有层数: {len(common_keys)}")
    if unique_to_1:
        print(f"文件1独有层数: {len(unique_to_1)}")
    if unique_to_2:
        print(f"文件2独有层数: {len(unique_to_2)}")

    print("\n" + "=" * 60)
    print(" 共有层权重差异详细分析 (Shared Layer Weight-Diff Analysis)")
    print(f" - 阈值 (Threshold for Mean Abs Diff): {threshold}")
    print("=" * 60)

    # Each entry is (layer name, mean abs diff, reason).
    failing_layers = []
    for key in common_keys:
        tensor1, tensor2 = sd1[key], sd2[key]

        if tensor1.shape != tensor2.shape:
            print(f"层: {key} - [形状不匹配!] Shape Mismatch! {tensor1.shape} vs {tensor2.shape}")
            failing_layers.append((key, float('inf'), "形状不匹配"))  # mark as failed
            continue

        if torch.equal(tensor1, tensor2):
            continue  # identical layers are skipped to keep the output short

        # Cast to float so integer / half-precision tensors can be averaged.
        abs_diff = torch.abs(tensor1.float() - tensor2.float())
        mean_abs_diff = abs_diff.mean().item()

        # Core check. NaN must be handled explicitly: ``nan > threshold`` is
        # False, so without this branch a NaN-polluted layer would "pass".
        if math.isnan(mean_abs_diff):
            status = "❌ [不通过] (NaN)"
            failing_layers.append((key, mean_abs_diff, "差异为 NaN"))
        elif mean_abs_diff > threshold:
            status = f"❌ [不通过] (>{threshold})"
            failing_layers.append((key, mean_abs_diff, "超过阈值"))
        else:
            status = f"✅ [通过] (<={threshold})"

        print(f"层: {key}")
        print(f" - 平均绝对差 (Mean Abs Diff): {mean_abs_diff:.8f} --- {status}")

    print("\n" + "=" * 60)
    print(" 最终总结 (Final Conclusion)")
    print("=" * 60)

    # Report whether the two models have exactly the same set of layers.
    if unique_to_1 or unique_to_2:
        print("警告: 两个模型的层结构不完全一致,存在独有层。")
        print(" - 文件1 独有层:", unique_to_1 if unique_to_1 else "无")
        print(" - 文件2 独有层:", unique_to_2 if unique_to_2 else "无")
        print("-" * 20)

    # The final verdict is driven solely by the failing_layers list.
    if not failing_layers:
        print(f"✅ 过关 (PASS): 所有共有层的平均绝对差异都在阈值 {threshold} 之内。")
    else:
        print(f"❌ 不通过 (FAIL): 发现 {len(failing_layers)} 个层的差异不满足要求。")
        print("\n详细信息如下:")
        for layer_name, diff_value, reason in failing_layers:
            if reason == "形状不匹配":
                print(f" - 层: {layer_name}, 原因: {reason}")
            else:
                print(f" - 层: {layer_name}, 平均绝对差: {diff_value:.8f} (原因: {reason})")

    return not failing_layers


if __name__ == '__main__':
    try:
        passed = compare_checkpoints(
            CKPT_PATH_1, CKPT_PATH_2, MODEL_WEIGHTS_KEY, MEAN_ABS_DIFF_THRESHOLD
        )
    except Exception as e:
        # Top-level boundary: report the error and signal failure to the
        # caller through the exit status instead of silently exiting 0.
        print(f"\n[程序执行出错]: {e}")
        sys.exit(1)
    sys.exit(0 if passed else 1)