import os
import numpy as np
from PIL import Image
# import onnxruntime as ort
import argparse
import lfw 
import sys
from sklearn import metrics
from scipy.optimize import brentq
from scipy import interpolate
import time
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
import migraphx

# Per-GPU directory that receives the raw per-batch embedding/label .bin dumps.
# Fall back to a fixed name when HIP_VISIBLE_DEVICES is unset: without a
# default, os.path.join('results', None) raises TypeError at import time.
gpuid = os.getenv('HIP_VISIBLE_DEVICES', 'default')
resultdir = os.path.join('results', gpuid)
os.makedirs(resultdir, exist_ok=True)

def AllocateOutputMemory(model):
    """Allocate one GPU buffer per output of a loaded MIGraphX program.

    Args:
        model: Program returned by ``migraphx.load``.

    Returns:
        dict mapping every output name reported by ``model.get_outputs()`` to a
        GPU argument allocated (via ``migraphx.allocate_gpu``) with that
        output's shape.
    """
    output_shapes = model.get_outputs()
    return {
        name: migraphx.allocate_gpu(s=output_shapes[name])
        for name in output_shapes.keys()
    }

def preprocess_image(image_path, target_size=(160, 160), flip=False):
    """Load one image from disk, resize it and apply fixed normalization.

    Args:
        image_path: Path of the image file to read.
        target_size: Size tuple passed to ``PIL.Image.Image.resize``, which
            interprets it as (width, height).
        flip: If True, mirror the image left-right.

    Returns:
        float32 array of shape (height, width, 3) in RGB channel order, with
        every pixel mapped as ``(x - 127.5) / 128.0``.
    """
    # Use the context manager so the file handle is closed deterministically
    # rather than whenever Pillow's lazy loading / garbage collection does it.
    with Image.open(image_path) as src:
        img = src.convert('RGB').resize(target_size, Image.Resampling.BILINEAR)
    img_np = np.array(img, dtype=np.float32)

    if flip:
        # Axis 1 of an HWC array is width, so this is a horizontal mirror.
        img_np = np.fliplr(img_np)

    return (img_np - 127.5) / 128.0

def load_lfw_for_onnx(lfw_dir, pairs_file, batch_size, image_size=(160, 160), use_flipped_images=False):
    """Yield preprocessed LFW images in batches, decoding one batch at a time.

    Images are emitted in pair order: first image of pair 0, second image of
    pair 0, first image of pair 1, ... When ``use_flipped_images`` is set, each
    image is immediately followed by its horizontally flipped copy, giving the
    order A, A_flipped, B, B_flipped, ...

    Only the images of the batch currently being yielded are decoded and held
    in memory; the whole dataset is never materialized at once.

    Args:
        lfw_dir: Directory with the aligned LFW images (``~`` is expanded).
        pairs_file: LFW pairs file (``~`` is expanded).
        batch_size: Images per yielded batch; the last batch may be smaller.
        image_size: Size tuple forwarded to ``preprocess_image``.
        use_flipped_images: Also emit a flipped copy of every image.

    Yields:
        ``(batch_array, batch_labels, batch_actual_issame_part)`` where
        ``batch_array`` is a float32 array of shape (N, 3, H, W),
        ``batch_labels`` is an int32 array holding each image's running index
        in the emission order, and ``batch_actual_issame_part`` is a list with
        the same-person flag of the pair each image belongs to.

    Raises:
        ValueError: If ``batch_size`` is not positive (raised on first ``next``).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    pairs = lfw.read_pairs(os.path.expanduser(pairs_file))
    paths, actual_issame = lfw.get_paths(os.path.expanduser(lfw_dir), pairs)

    # Cheap schedule of (image_path, flip, pair_issame). An entry's position in
    # this list is its label and matches the order images are emitted in.
    flips = (False, True) if use_flipped_images else (False,)
    schedule = []
    for pair_idx, same in enumerate(actual_issame):
        for img_path in (paths[2 * pair_idx], paths[2 * pair_idx + 1]):
            for flip in flips:
                schedule.append((img_path, flip, same))

    nrof_images = len(schedule)
    if nrof_images % batch_size != 0:
        print(f"Warning: Number of images ({nrof_images}) is not evenly divisible by batch size ({batch_size}). Last batch will be smaller.")

    for start_idx in range(0, nrof_images, batch_size):
        entries = schedule[start_idx:start_idx + batch_size]
        # HWC -> CHW per image; np.stack produces a fresh contiguous batch.
        chw_images = [
            np.transpose(preprocess_image(img_path, target_size=image_size, flip=flip), (2, 0, 1))
            for img_path, flip, _ in entries
        ]
        batch_array = np.stack(chw_images, axis=0).astype(np.float32, copy=False)
        batch_labels = np.arange(start_idx, start_idx + len(entries), dtype=np.int32)
        batch_actual_issame_part = [same for _, _, same in entries]

        yield batch_array, batch_labels, batch_actual_issame_part

def main_onnx(args):
    """Evaluate a compiled MIGraphX FaceNet model on LFW; print accuracy and throughput.

    Images are streamed through the model in batches of ``args.lfw_batch_size``.
    Each batch is padded up to the model's static batch size, and the padding
    outputs are discarded. Every batch's embeddings and labels are also written
    as raw ``.bin`` files into ``resultdir``. Afterwards, accuracy,
    validation rate, AUC and EER are computed via ``lfw.evaluate``, and
    inference-only and end-to-end (load + infer) throughput are printed.

    Args:
        args: Namespace produced by ``parse_arguments_onnx``.

    Raises:
        ValueError: If ``args.lfw_batch_size`` is not in
            ``[1, model_batch_size]``.
    """
    embedding_size = 512  # assumed width of the model output; TODO confirm against the model
    # Static batch dimension the model is compiled with (the default model file
    # is "facenet_static_bs64.mxr"). Smaller batches are padded up to this size.
    model_batch_size = 64
    batch_size = args.lfw_batch_size
    if not 1 <= batch_size <= model_batch_size:
        raise ValueError(
            f"--lfw_batch_size must be between 1 and the model's static batch size "
            f"({model_batch_size}), got {batch_size}"
        )

    pairs = lfw.read_pairs(os.path.expanduser(args.lfw_pairs))
    paths, actual_issame = lfw.get_paths(os.path.expanduser(args.lfw_dir), pairs)
    nrof_pairs = len(actual_issame)

    # 2 images per pair (A, B); 4 with flipping (A, A_flipped, B, B_flipped).
    images_per_pair = 4 if args.use_flipped_images else 2
    nrof_images = nrof_pairs * images_per_pair
    print(f"Number of pairs: {nrof_pairs}, Number of images: {nrof_images}, Embedding size: {embedding_size}")

    data_generator = load_lfw_for_onnx(
        args.lfw_dir,
        args.lfw_pairs,
        batch_size,
        image_size=(args.image_size, args.image_size),
        use_flipped_images=args.use_flipped_images,
    )
    num_batches = (nrof_images + batch_size - 1) // batch_size  # ceil division, for tqdm's total
    all_embeddings = np.zeros((nrof_images, embedding_size), dtype=np.float32)
    current_image_index = 0

    model = migraphx.load(args.model_path)
    input_name = list(model.get_inputs().keys())[0]
    model_data = AllocateOutputMemory(model)

    # Warm-up run outside the timed loop, using the same input shape as real batches.
    warmup = np.ones((model_batch_size, 3, args.image_size, args.image_size), dtype=np.float32)
    model_data[input_name] = migraphx.to_gpu(migraphx.argument(warmup))
    model.run(model_data)

    infer_times = []        # model.run() only
    total_infer_times = []  # data loading + transfer + run + dumping, per batch
    total_start = time.time()
    for i, (batch_images, batch_label, _) in enumerate(tqdm(data_generator, total=num_batches, desc="Processing Batches")):
        original_batch_size = batch_images.shape[0]
        if original_batch_size < model_batch_size:
            # Static batch: pad with copies of the last image; their outputs are dropped below.
            pad_size = model_batch_size - original_batch_size
            padding_images = np.repeat(batch_images[-1:], pad_size, axis=0)
            batch_images = np.concatenate((batch_images, padding_images), axis=0)
        model_data[input_name] = migraphx.to_gpu(migraphx.argument(batch_images))

        start = time.time()
        outputs = model.run(model_data)
        infer_times.append(time.time() - start)
        embeddings = np.array(migraphx.from_gpu(outputs[0]))[:original_batch_size]

        n_batch = embeddings.shape[0]
        all_embeddings[current_image_index:current_image_index + n_batch] = embeddings
        current_image_index += n_batch

        stem = str(i).zfill(6)
        embeddings.tofile(os.path.join(resultdir, f'{stem}_0.bin'))
        batch_label.tofile(os.path.join(resultdir, f'{stem}.bin'))
        if i % 10 == 9:
            print('.', end='')
            sys.stdout.flush()
        total_infer_times.append(time.time() - total_start)
        total_start = time.time()
    print("\nAll batches processed.")

    # Sanity-check that every expected embedding was produced.
    print(f"Total embeddings collected: {current_image_index}")
    print(f"Expected embeddings: {nrof_images}")
    if current_image_index != nrof_images:
        print(f"Warning: Expected {nrof_images} embeddings but collected {current_image_index}")

    if args.use_flipped_images:
        # Rows are ordered [img0, img0_flipped, img1, img1_flipped, ...]. Reshaping
        # the C-contiguous matrix merges each consecutive row pair into one row,
        # i.e. concatenates the original and flipped embeddings of each image.
        final_embeddings = all_embeddings.reshape(nrof_pairs * 2, embedding_size * 2)
    else:
        final_embeddings = all_embeddings

    # NOTE(review): this was previously commented "Euclidean", but in upstream
    # facenet's lfw/facenet modules 0 selects Euclidean and 1 selects
    # cosine-similarity distance. Value kept unchanged; confirm against the local lfw module.
    distance_metric = 1
    subtract_mean = True
    nrof_folds = 10
    tpr, fpr, accuracy, val, val_std, far = lfw.evaluate(
        final_embeddings,
        actual_issame,
        nrof_folds=nrof_folds,
        distance_metric=distance_metric,
        subtract_mean=subtract_mean
    )

    print('Accuracy: %2.5f+-%2.5f' % (np.mean(accuracy), np.std(accuracy)))
    print('Validation rate: %2.5f+-%2.5f @ FAR=%2.5f' % (val, val_std, far))

    auc = metrics.auc(fpr, tpr)
    print('Area Under Curve (AUC): %1.3f' % auc)
    # EER: the point on the ROC curve where the false positive rate equals 1 - TPR.
    eer = brentq(lambda x: 1. - x - interpolate.interp1d(fpr, tpr)(x), 0., 1.)
    print('Equal Error Rate (EER): %1.3f' % eer)

    if args.use_flipped_images:
        print("Configuration: Using flipped images (original + flipped concatenated)")
        print(f"Embedding dimension: {embedding_size*2}")

    print("***************************")
    # Throughput counts only real images; padding rows added to fill the static
    # batch are excluded so small batches do not inflate the reported FPS.
    infer_time = sum(infer_times)
    avg_infer_fps = current_image_index / infer_time if infer_time > 0 else 0.0
    print(f"total_infer_time: {infer_time}s")
    print(f'avg_infer_fps: {avg_infer_fps}samples/s')
    load_data_infer_time = sum(total_infer_times)
    load_data_avg_infer_fps = current_image_index / load_data_infer_time if load_data_infer_time > 0 else 0.0
    print(f'load_data_total_infer_time: {load_data_infer_time}s')
    print(f'load_data_avg_total_Infer_fps: {load_data_avg_infer_fps} samples/s')
    print("******************************")

def parse_arguments_onnx():
    """Define the LFW-evaluation command line and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``lfw_dir``, ``lfw_batch_size``, ``model_path``,
        ``image_size``, ``lfw_pairs`` and ``use_flipped_images``.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ('--lfw_dir', dict(
            type=str,
            default="/datasets/lfw_mtcnnpy_160",
            help='Path to the data directory containing aligned LFW face patches.')),
        ('--lfw_batch_size', dict(
            type=int,
            help='Number of images to process in a batch in the LFW test set.',
            default=64)),  # matches the static batch of the default model
        ('--model_path', dict(
            type=str,
            default="/home/sunzhq/workspace/yidong-infer/facenet/facenet/tools/onnx-models/facenet_static_bs64.mxr",
            help='Path to the ONNX model file.')),
        ('--image_size', dict(
            type=int,
            help='Image size (height, width) in pixels.',
            default=160)),
        ('--lfw_pairs', dict(
            type=str,
            help='The file containing the pairs to use for validation.',
            default='data/pairs.txt')),
        ('--use_flipped_images', dict(
            action='store_true',
            help='Use flipped images for evaluation (original + flipped concatenated).')),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

if __name__ == '__main__':
    # Script entry point: parse CLI options and run the evaluation.
    main_onnx(parse_arguments_onnx())