acc.py 4.58 KB
Newer Older
zhangwq5's avatar
edit  
zhangwq5 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import os
import csv

def calculate_mae(image_a, image_b):
    """计算两张图片的平均绝对误差 (Mean Absolute Error, MAE)"""
    image_a = image_a.astype(np.float32)
    image_b = image_b.astype(np.float32)
    mae_value = np.mean(np.abs(image_a - image_b))
    return mae_value

def compare_images(image_path1, image_path2):
    """加载两张图片并计算它们的MAE, PSNR, 和 SSIM"""
    try:
        img1_pil = Image.open(image_path1).convert('RGB')
        img2_pil = Image.open(image_path2).convert('RGB')
    except FileNotFoundError as e:
        print(f"错误: 无法找到文件。 {e}")
        return None
        
    img1_np = np.array(img1_pil)
    img2_np = np.array(img2_pil)

    if img1_np.shape != img2_np.shape:
        print(f"错误: 图片 '{os.path.basename(image_path1)}' 尺寸不匹配。")
        print(f"  - 图片1尺寸: {img1_np.shape}")
        print(f"  - 图片2尺寸: {img2_np.shape}")
        return None

    mae_value = calculate_mae(img1_np, img2_np)
    psnr_value = psnr(img1_np, img2_np, data_range=255)

    try:
        ssim_value = ssim(img1_np, img2_np, data_range=255, channel_axis=-1, win_size=7)
    except TypeError:
        ssim_value = ssim(img1_np, img2_np, data_range=255, multichannel=True, win_size=7)

    return {"MAE": mae_value, "PSNR": psnr_value, "SSIM": ssim_value}


if __name__ == "__main__":
    
    # *****************************************************************
    # *  请在这里修改你的文件夹路径                                   *
    # *****************************************************************
    folder_gpu = "./output_images_A800"         # 基准文件夹 (GPU推理结果)
    folder_dcu = "./output_images_K100AI"         # 待测试文件夹 (DCU推理结果)
    report_filename = "comparison_report.csv"  

    if not os.path.isdir(folder_gpu) or not os.path.isdir(folder_dcu):
        print(f"错误:请确保文件夹 '{folder_gpu}' 和 '{folder_dcu}' 都存在。")
        exit()

    print(f"开始对比文件夹 '{folder_gpu}' (基准) 和 '{folder_dcu}' (测试)...")

    all_metrics = []
    
    base_filenames = sorted(os.listdir(folder_gpu))
    
    for filename in base_filenames:
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
            continue

        path1 = os.path.join(folder_gpu, filename)
        path2 = os.path.join(folder_dcu, filename)

        if not os.path.exists(path2):
            print(f"  [跳过] 在 '{folder_dcu}' 中未找到对应的文件: {filename}")
            continue

        results = compare_images(path1, path2)

        if results:
            print(f"  - 对比 {filename}: "
                  f"MAE={results['MAE']:.4f}, "
                  f"PSNR={results['PSNR']:.2f}dB, "
                  f"SSIM={results['SSIM']:.4f}")
            
            all_metrics.append({
                'filename': filename,
                'MAE': results['MAE'],
                'PSNR': results['PSNR'],
                'SSIM': results['SSIM']
            })

    # --- 计算并打印平均结果 ---
    if not all_metrics:
        print("\n未找到任何可以对比的图片对。请检查文件夹内容和文件名。")
    else:
        # 使用Numpy高效计算平均值
        avg_mae = np.mean([m['MAE'] for m in all_metrics])
        avg_psnr = np.mean([m['PSNR'] for m in all_metrics])
        avg_ssim = np.mean([m['SSIM'] for m in all_metrics])

        print("\n" + "="*50)
        print("--- 批量对比平均结果 ---")
        print(f"成功对比图片对数: {len(all_metrics)}")
        print(f"平均绝对误差 (MAE): {avg_mae:.4f}")
        print(f"平均峰值信噪比 (PSNR): {avg_psnr:.2f} dB")
        print(f"平均结构相似性 (SSIM): {avg_ssim:.4f}")
        print("="*50)
        
        # --- 将详细结果写入CSV文件 ---
        try:
            with open(report_filename, 'w', newline='', encoding='utf-8') as csvfile:
                fieldnames = ['filename', 'MAE', 'PSNR', 'SSIM']
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

                writer.writeheader()
                writer.writerows(all_metrics)
                # 写入平均值
                writer.writerow({}) 
                writer.writerow({'filename': 'Average', 'MAE': avg_mae, 'PSNR': avg_psnr, 'SSIM': avg_ssim})
            
            print(f"\n详细报告已保存至: {report_filename}")
        except Exception as e:
            print(f"\n保存报告失败: {e}")