import os
import numpy as np
from PIL import Image
import argparse
import lfw
import sys
from sklearn import metrics
from scipy.optimize import brentq
from scipy import interpolate
import time
from tqdm import tqdm
import migraphx

def AllocateOutputMemory(model):
    """Allocate one GPU buffer for every output of a loaded MIGraphX program.

    Args:
        model: a loaded MIGraphX program exposing ``get_outputs()``.

    Returns:
        dict mapping each output name to ``migraphx.allocate_gpu(s=...)`` called
        with the value ``get_outputs()`` reports for that name (presumably the
        output's shape).
    """
    output_specs = model.get_outputs()
    return {
        name: migraphx.allocate_gpu(s=output_specs[name])
        for name in output_specs.keys()
    }

def _summarize_evaluation(tpr, fpr, accuracy):
    """Reduce ``lfw.evaluate`` outputs to mean/std fold accuracy and ROC AUC."""
    return {
        "accuracy": np.mean(accuracy),
        "std": np.std(accuracy),
        "auc": metrics.auc(fpr, tpr)
    }


def evaluate_embeddings_with_different_methods(embeddings, actual_issame, use_flipped_images, embedding_size):
    """Evaluate LFW verification performance of a set of embeddings.

    Args:
        embeddings: 2-D array with one embedding per row. Without flips it must
            hold 2 rows per pair (consecutive rows form a pair). With flips it
            must hold 4 rows per pair: each image's embedding immediately
            followed by the embedding of its horizontally flipped copy.
        actual_issame: sequence of booleans, one per pair.
        use_flipped_images: whether ``embeddings`` contains flipped views.
        embedding_size: width of one embedding row (checked in flipped mode,
            where original and flipped embeddings are concatenated).

    Returns:
        dict with a single entry, ``"original"`` or ``"original+flipped"``,
        mapping to ``{"accuracy": mean, "std": std, "auc": auc}`` over the
        10 cross-validation folds.

    Raises:
        ValueError: if the number or width of embedding rows does not match
            ``actual_issame`` / ``embedding_size``. Previously a mismatch in
            flipped mode silently returned an empty dict.
    """
    embeddings = np.asarray(embeddings)
    nrof_pairs = len(actual_issame)
    if embeddings.ndim != 2:
        raise ValueError(f"embeddings must be 2-D, got shape {embeddings.shape}")

    results = {}

    if not use_flipped_images:
        # Method 0: plain embeddings, two rows per pair.
        expected_rows = 2 * nrof_pairs
        if embeddings.shape[0] != expected_rows:
            raise ValueError(
                f"expected {expected_rows} embeddings for {nrof_pairs} pairs, "
                f"got {embeddings.shape[0]}"
            )
        tpr, fpr, accuracy, val, val_std, far = lfw.evaluate(
            embeddings,
            actual_issame,
            nrof_folds=10,
            distance_metric=1,
            subtract_mean=True
        )
        results["original"] = _summarize_evaluation(tpr, fpr, accuracy)
        return results

    # Method 1: original TF concatenation scheme. Each image's embedding is
    # concatenated with the embedding of its flipped copy.
    expected_rows = 4 * nrof_pairs
    if embeddings.shape[0] != expected_rows:
        raise ValueError(
            f"use_flipped_images=True expects {expected_rows} embeddings "
            f"(2 images x 2 views per pair), got {embeddings.shape[0]}"
        )
    if embeddings.shape[1] != embedding_size:
        raise ValueError(
            f"expected embeddings of width {embedding_size}, got {embeddings.shape[1]}"
        )

    # Even rows are original views, odd rows are flipped views; the result has
    # 2 rows per pair and 2 * embedding_size columns. float64 matches the
    # original zero-initialized buffer.
    final_embeddings = np.hstack((embeddings[0::2], embeddings[1::2])).astype(np.float64)

    tpr, fpr, accuracy, val, val_std, far = lfw.evaluate(
        final_embeddings,
        actual_issame,
        nrof_folds=10,
        distance_metric=1,
        subtract_mean=True
    )
    results["original+flipped"] = _summarize_evaluation(tpr, fpr, accuracy)
    return results

def _build_image_list(paths, nrof_pairs, use_flipped_images):
    """Expand LFW pair paths into the flat, ordered list of images to embed.

    For each pair the order is: first image, [its flipped copy], second image,
    [its flipped copy]. Flipped entries are present only when
    ``use_flipped_images`` is true.

    Returns:
        (image_paths, flip_flags): parallel lists of equal length.
    """
    views = (False, True) if use_flipped_images else (False,)
    image_paths = []
    flip_flags = []
    for i in tqdm(range(nrof_pairs), desc="Organizing pairs"):
        for path in (paths[2 * i], paths[2 * i + 1]):
            for flip in views:
                image_paths.append(path)
                flip_flags.append(flip)
    return image_paths, flip_flags


def _load_image(img_path, image_size, flip):
    """Load one image as a float32 CHW array with fixed standardization applied."""
    # The context manager closes the file handle right away instead of
    # leaving it open until the Image object is garbage-collected.
    with Image.open(img_path) as raw:
        img = raw.convert('RGB').resize((image_size, image_size), Image.Resampling.BILINEAR)
    img_np = np.array(img, dtype=np.float32)  # HWC

    if flip:
        img_np = np.fliplr(img_np)  # horizontal flip (width axis)

    # FaceNet fixed image standardization.
    img_np = (img_np - 127.5) / 128.0

    # HWC -> CHW, the layout fed to the model.
    return np.transpose(img_np, (2, 0, 1))


def _run_batch(model, model_data, input_name, model_batch_size, batch):
    """Run one batch through the model; return (embeddings, inference seconds).

    A batch smaller than ``model_batch_size`` is padded by repeating its last
    image, because the compiled model is assumed to have a static batch
    dimension (e.g. the default ``..._static_bs64...`` model). The embeddings
    of the padded rows are discarded.
    """
    nrof_real = batch.shape[0]
    if nrof_real < model_batch_size:
        padding = np.repeat(batch[-1:], model_batch_size - nrof_real, axis=0)
        batch = np.concatenate([batch, padding], axis=0)

    model_data[input_name] = migraphx.to_gpu(migraphx.argument(np.ascontiguousarray(batch)))

    infer_start = time.time()
    output = model.run(model_data)
    infer_time = time.time() - infer_start

    embeddings_np = np.array(migraphx.from_gpu(output[0]))
    return embeddings_np[:nrof_real], infer_time


def _print_performance(infer_times, nrof_images, nrof_pairs, use_flipped_images):
    """Print total inference time and throughput; prints nothing if no batch ran."""
    if not infer_times:
        return
    total_infer_time = sum(infer_times)
    avg_fps = nrof_images / total_infer_time
    print("\n" + "="*70)
    print("PERFORMANCE STATISTICS")
    print("-"*70)
    print(f"Total inference time: {total_infer_time:.3f}s")
    print(f"Average FPS: {avg_fps:.1f} images/s")
    print(f"Number of images: {nrof_images}")

    if use_flipped_images:
        print(f"  (Note: {nrof_pairs * 2} original images + their flips)")


def main_optimized(args):
    """Evaluate a MIGraphX FaceNet model on LFW and print accuracy and throughput.

    Args:
        args: argparse namespace providing ``migraphx_model_path``,
            ``lfw_pairs``, ``lfw_dir``, ``lfw_batch_size``, ``image_size``
            and ``use_flipped_images``.

    Raises:
        ValueError: if ``lfw_batch_size`` is not in [1, model batch size].
            Previously a value larger than the hard-coded 64 sent an
            unpadded, wrongly-shaped batch to the model.
    """
    # Load the model.
    model = migraphx.load(args.migraphx_model_path)
    input_name = list(model.get_inputs().keys())[0]
    # Batch dimension of the compiled input, replacing the previously
    # hard-coded 64.
    model_batch_size = model.get_inputs()[input_name].lens()[0]
    model_data = AllocateOutputMemory(model)
    # Assumed embedding width of the model output. The store into
    # all_embeddings fails loudly if it differs.
    embedding_size = 512
    print("="*70)

    batch_size = args.lfw_batch_size
    if not 0 < batch_size <= model_batch_size:
        raise ValueError(
            f"--lfw_batch_size must be between 1 and the model batch size "
            f"{model_batch_size}, got {batch_size}"
        )

    # Load the pair list.
    pairs = lfw.read_pairs(os.path.expanduser(args.lfw_pairs))
    paths, actual_issame = lfw.get_paths(os.path.expanduser(args.lfw_dir), pairs)
    nrof_pairs = len(actual_issame)

    print("\nPreparing image paths...")
    all_image_paths, flip_flags = _build_image_list(paths, nrof_pairs, args.use_flipped_images)

    nrof_images = len(all_image_paths)
    print(f"Total images to process: {nrof_images}")

    # Preallocate embedding storage; rows follow all_image_paths order.
    all_embeddings = np.zeros((nrof_images, embedding_size), dtype=np.float32)

    print("\nRunning inference...")
    infer_times = []

    for start_idx in tqdm(range(0, nrof_images, batch_size), desc="Processing"):
        end_idx = min(start_idx + batch_size, nrof_images)
        batch_array = np.stack([
            _load_image(img_path, args.image_size, flip_flag)
            for img_path, flip_flag in zip(all_image_paths[start_idx:end_idx],
                                           flip_flags[start_idx:end_idx])
        ], axis=0).astype(np.float32)

        embeddings_np, infer_time = _run_batch(
            model, model_data, input_name, model_batch_size, batch_array
        )
        infer_times.append(infer_time)
        all_embeddings[start_idx:end_idx] = embeddings_np

    print("\n" + "="*70)
    print("EVALUATION RESULTS")
    print("="*70)

    results = evaluate_embeddings_with_different_methods(
        all_embeddings,
        actual_issame,
        args.use_flipped_images,
        embedding_size
    )

    print("\nComparison of different methods:")
    print("-"*70)
    for method_name, result in results.items():
        print(f"{method_name:20} | Accuracy: {result['accuracy']:.5f} ± {result['std']:.5f} | AUC: {result['auc']:.3f}")

    _print_performance(infer_times, nrof_images, nrof_pairs, args.use_flipped_images)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lfw_dir', type=str, default="/datasets/lfw_mtcnnpy_160")
    parser.add_argument('--lfw_batch_size', type=int, default=64)
    parser.add_argument('--migraphx_model_path', type=str, 
                       default="/home/sunzhq/workspace/yidong-infer/facenet/facenet/tools/onnx-models/facenet_static_bs64_fp32.mxr")
    parser.add_argument('--image_size', type=int, default=160)
    parser.add_argument('--lfw_pairs', type=str, default='data/pairs.txt')
    parser.add_argument('--use_flipped_images', action='store_true')
    parser.add_argument('--use_fixed_image_standardization', action='store_true')
    
    args = parser.parse_args()
    main_optimized(args)