compare_ckpts.py 5.5 KB
Newer Older
zhangwq5's avatar
all  
zhangwq5 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# -*- coding: utf-8 -*-
"""
一个用于比较两个 PyTorch checkpoint (.pt 或 .ckpt) 文件中模型权重的脚本。
它会逐层比较权重,并根据预设的“平均绝对差异”阈值来判断是否“过关”。
"""
import torch
from collections import OrderedDict

# ==============================================================================
# 1. Configuration: checkpoint file paths, model-weights key, and pass/fail threshold
# ==============================================================================
# The two checkpoint files to compare against each other.
CKPT_PATH_1 = '/home/zwq/project/shangchaun/external/graphormer_pytorch/res/res_of_A800/checkpoint1.pt'
CKPT_PATH_2 = '/home/zwq/project/shangchaun/external/graphormer_pytorch/res/res_of_K100AI/checkpoint1.pt'

# Inspecting the checkpoints showed that the model weights are stored under the 'model' key.
MODEL_WEIGHTS_KEY = 'model'

# !! Core pass/fail criterion !!
# Threshold on the per-layer mean absolute difference: the comparison "passes" only when
# no layer's mean absolute difference exceeds this value.
MEAN_ABS_DIFF_THRESHOLD = 0.02
# ==============================================================================

def extract_state_dict(checkpoint, model_key):
    """Return the state_dict stored under ``model_key`` in a loaded checkpoint.

    Args:
        checkpoint: The object obtained by loading a checkpoint file; must be a dict.
        model_key: The key under which the model weights are stored.

    Returns:
        The value found at ``checkpoint[model_key]``.

    Raises:
        TypeError: If ``checkpoint`` is not a dict.
        KeyError: If ``model_key`` is missing; the message lists the keys that exist.
    """
    # Guard 1: anything other than a dict cannot be indexed by key name.
    if not isinstance(checkpoint, dict):
        raise TypeError(f"Checkpoint 文件加载后不是一个字典,而是一个 {type(checkpoint)}。")

    # Guard 2: report the available keys so the caller can pick the right one.
    if model_key not in checkpoint:
        available = [*checkpoint.keys()]
        raise KeyError(
            f"在 checkpoint 中找不到指定的键 '{model_key}'。\n"
            f"文件中实际存在的键是: {available}"
        )

    return checkpoint[model_key]

def normalize_keys(state_dict):
    """Return a new OrderedDict whose keys have one leading 'module.' removed.

    Keys that do not start with 'module.' are kept unchanged. Values are passed
    through untouched and the iteration order of ``state_dict`` is preserved.

    Args:
        state_dict: Mapping from parameter names to values.

    Returns:
        An ``OrderedDict`` with the normalized keys.
    """
    prefix = 'module.'
    cut = len(prefix)
    # Only a single leading occurrence of the prefix is stripped from each key.
    return OrderedDict(
        (name[cut:] if name.startswith(prefix) else name, value)
        for name, value in state_dict.items()
    )

def compare_checkpoints(ckpt_path1, ckpt_path2, model_key, threshold):
    """Load two checkpoint files and report per-layer weight differences.

    Both files are loaded onto the CPU, their state_dicts are extracted from
    ``model_key`` and key-normalized, and every layer present in both is
    compared by mean absolute difference. A summary and a final PASS/FAIL
    verdict are printed to stdout.

    A layer fails when its shapes differ, when its mean absolute difference
    exceeds ``threshold``, or when that difference is NaN. The overall verdict
    is FAIL if any layer fails or if the two files share no layer at all.
    Layers that exist in only one file are reported as a warning and do not,
    by themselves, cause a FAIL.

    Args:
        ckpt_path1: Path of the first checkpoint file.
        ckpt_path2: Path of the second checkpoint file.
        model_key: Key under which the state_dict is stored in each checkpoint.
        threshold: Largest mean absolute difference a layer may have and still pass.

    Returns:
        None. All results are printed.

    Raises:
        TypeError: If a loaded checkpoint is not a dict (from ``extract_state_dict``).
        KeyError: If ``model_key`` is missing (from ``extract_state_dict``).
    """
    import math  # local import: only needed for the NaN check below

    # NOTE(review): torch.load unpickles the file contents, so this should only
    # be pointed at checkpoints from a trusted source.
    print(f"[*] 正在加载 Checkpoint 1: {ckpt_path1}")
    ckpt1 = torch.load(ckpt_path1, map_location='cpu')
    print(f"[*] 正在加载 Checkpoint 2: {ckpt_path2}")
    ckpt2 = torch.load(ckpt_path2, map_location='cpu')

    print(f"\n[*] 正在从键 '{model_key}' 中提取并标准化 state_dict...")
    sd1 = normalize_keys(extract_state_dict(ckpt1, model_key))
    sd2 = normalize_keys(extract_state_dict(ckpt2, model_key))

    keys1, keys2 = set(sd1.keys()), set(sd2.keys())
    common_keys = sorted(keys1 & keys2)
    unique_to_1 = sorted(keys1 - keys2)
    unique_to_2 = sorted(keys2 - keys1)

    print("\n" + "="*60)
    print("      层名称比较摘要 (Layer Name Comparison Summary)")
    print("="*60)
    print(f"总层数 (文件1): {len(keys1)}")
    print(f"总层数 (文件2): {len(keys2)}")
    print(f"共有层数: {len(common_keys)}")
    if unique_to_1:
        print(f"文件1独有层数: {len(unique_to_1)}")
    if unique_to_2:
        print(f"文件2独有层数: {len(unique_to_2)}")

    print("\n" + "="*60)
    print("      共有层权重差异详细分析 (Shared Layer Weight-Diff Analysis)")
    print(f"      - 阈值 (Threshold for Mean Abs Diff): {threshold}")
    print("="*60)

    # Each entry is (layer name, mean abs diff, reason).
    failing_layers = []

    for key in common_keys:
        tensor1, tensor2 = sd1[key], sd2[key]

        if tensor1.shape != tensor2.shape:
            print(f"层: {key} - [形状不匹配!] Shape Mismatch! {tensor1.shape} vs {tensor2.shape}")
            # No numeric diff can be computed; inf marks the layer as failed.
            failing_layers.append((key, float('inf'), "形状不匹配"))
            continue

        if torch.equal(tensor1, tensor2):
            continue  # identical layers are skipped to keep the output short

        # Cast to float so the subtraction and mean are done in floating point.
        abs_diff = torch.abs(tensor1.float() - tensor2.float())
        mean_abs_diff = abs_diff.mean().item()

        if math.isnan(mean_abs_diff):
            # NaN compares False against any threshold, so without this branch
            # a layer containing NaN would be reported as passing.
            status = "❌ [不通过] (NaN)"
            failing_layers.append((key, mean_abs_diff, "差异为 NaN"))
        elif mean_abs_diff > threshold:
            status = f"❌ [不通过] (>{threshold})"
            failing_layers.append((key, mean_abs_diff, "超过阈值"))
        else:
            status = f"✅ [通过] (<={threshold})"

        print(f"层: {key}")
        print(f"  - 平均绝对差 (Mean Abs Diff): {mean_abs_diff:.8f} --- {status}")

    print("\n" + "="*60)
    print("      最终总结 (Final Conclusion)")
    print("="*60)

    # Structural mismatch is reported as a warning only.
    if unique_to_1 or unique_to_2:
        print("警告: 两个模型的层结构不完全一致,存在独有层。")
        print("  - 文件1 独有层:", unique_to_1 if unique_to_1 else "无")
        print("  - 文件2 独有层:", unique_to_2 if unique_to_2 else "无")
        print("-" * 20)

    if not common_keys:
        # Nothing was compared, so an empty failing_layers list proves nothing.
        print("❌ 不通过 (FAIL): 两个文件没有任何共有层,无法比较权重。")
    elif not failing_layers:
        print(f"✅ 过关 (PASS): 所有共有层的平均绝对差异都在阈值 {threshold} 之内。")
    else:
        print(f"❌ 不通过 (FAIL): 发现 {len(failing_layers)} 个层的差异不满足要求。")
        print("\n详细信息如下:")
        for layer_name, diff_value, reason in failing_layers:
            if reason == "形状不匹配":
                print(f"  - 层: {layer_name}, 原因: {reason}")
            else:
                print(f"  - 层: {layer_name}, 平均绝对差: {diff_value:.8f} (原因: {reason})")

if __name__ == '__main__':
    try:
        compare_checkpoints(CKPT_PATH_1, CKPT_PATH_2, MODEL_WEIGHTS_KEY, MEAN_ABS_DIFF_THRESHOLD)
    except Exception as e:
        # Top-level boundary: print a one-line error, then exit with a non-zero
        # status so that a calling shell or CI job can detect the failure.
        print(f"\n[程序执行出错]: {e}")
        raise SystemExit(1)