import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import os
import csv

def calculate_mae(image_a, image_b):
    """计算两张图片的平均绝对误差 (Mean Absolute Error, MAE)"""
    image_a = image_a.astype(np.float32)
    image_b = image_b.astype(np.float32)
    mae_value = np.mean(np.abs(image_a - image_b))
    return mae_value

def compare_images(image_path1, image_path2):
    """加载两张图片并计算它们的MAE, PSNR, 和 SSIM"""
    try:
        img1_pil = Image.open(image_path1).convert('RGB')
        img2_pil = Image.open(image_path2).convert('RGB')
    except FileNotFoundError as e:
        print(f"错误: 无法找到文件。 {e}")
        return None
        
    img1_np = np.array(img1_pil)
    img2_np = np.array(img2_pil)

    if img1_np.shape != img2_np.shape:
        print(f"错误: 图片 '{os.path.basename(image_path1)}' 尺寸不匹配。")
        print(f"  - 图片1尺寸: {img1_np.shape}")
        print(f"  - 图片2尺寸: {img2_np.shape}")
        return None

    mae_value = calculate_mae(img1_np, img2_np)
    psnr_value = psnr(img1_np, img2_np, data_range=255)

    try:
        ssim_value = ssim(img1_np, img2_np, data_range=255, channel_axis=-1, win_size=7)
    except TypeError:
        ssim_value = ssim(img1_np, img2_np, data_range=255, multichannel=True, win_size=7)

    return {"MAE": mae_value, "PSNR": psnr_value, "SSIM": ssim_value}


if __name__ == "__main__":
    
    # *****************************************************************
    # *  请在这里修改你的文件夹路径                                   *
    # *****************************************************************
    folder_gpu = "/home/zwq/project/shangchaun/external/qwen-image_hf/infer/generated_images_GPU"         # 基准文件夹 (GPU推理结果)
    folder_dcu = "/home/zwq/project/shangchaun/external/qwen-image_hf/infer/generated_images_DCU"         # 待测试文件夹 (DCU推理结果)
    report_filename = "comparison_report.csv"  

    if not os.path.isdir(folder_gpu) or not os.path.isdir(folder_dcu):
        print(f"错误：请确保文件夹 '{folder_gpu}' 和 '{folder_dcu}' 都存在。")
        exit()

    print(f"开始对比文件夹 '{folder_gpu}' (基准) 和 '{folder_dcu}' (测试)...")

    all_metrics = []
    
    base_filenames = sorted(os.listdir(folder_gpu))
    
    for filename in base_filenames:
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
            continue

        path1 = os.path.join(folder_gpu, filename)
        path2 = os.path.join(folder_dcu, filename)

        if not os.path.exists(path2):
            print(f"  [跳过] 在 '{folder_dcu}' 中未找到对应的文件: {filename}")
            continue

        results = compare_images(path1, path2)

        if results:
            print(f"  - 对比 {filename}: "
                  f"MAE={results['MAE']:.4f}, "
                  f"PSNR={results['PSNR']:.2f}dB, "
                  f"SSIM={results['SSIM']:.4f}")
            
            all_metrics.append({
                'filename': filename,
                'MAE': results['MAE'],
                'PSNR': results['PSNR'],
                'SSIM': results['SSIM']
            })

    # --- 计算并打印平均结果 ---
    if not all_metrics:
        print("\n未找到任何可以对比的图片对。请检查文件夹内容和文件名。")
    else:
        # 使用Numpy高效计算平均值
        avg_mae = np.mean([m['MAE'] for m in all_metrics])
        avg_psnr = np.mean([m['PSNR'] for m in all_metrics])
        avg_ssim = np.mean([m['SSIM'] for m in all_metrics])

        print("\n" + "="*50)
        print("--- 批量对比平均结果 ---")
        print(f"成功对比图片对数: {len(all_metrics)}")
        print(f"平均绝对误差 (MAE): {avg_mae:.4f}")
        print(f"平均峰值信噪比 (PSNR): {avg_psnr:.2f} dB")
        print(f"平均结构相似性 (SSIM): {avg_ssim:.4f}")
        print("="*50)
        
        # --- 将详细结果写入CSV文件 ---
        try:
            with open(report_filename, 'w', newline='', encoding='utf-8') as csvfile:
                fieldnames = ['filename', 'MAE', 'PSNR', 'SSIM']
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

                writer.writeheader()
                writer.writerows(all_metrics)
                # 写入平均值
                writer.writerow({}) 
                writer.writerow({'filename': 'Average', 'MAE': avg_mae, 'PSNR': avg_psnr, 'SSIM': avg_ssim})
            
            print(f"\n详细报告已保存至: {report_filename}")
        except Exception as e:
            print(f"\n保存报告失败: {e}")