import argparse
import argparse

import cv2
import numpy as np
import torch
import torchvision.datasets as datasets
import os
import torchvision.transforms as transforms
import time
import shutil
import migraphx

# class CIFAR100WithIndex(datasets.CIFAR100):
#     def __getitem__(self, index):
#         image, target = super().__getitem__(index)
#         # Generate a pseudo filename, e.g. "cifar100_000123.png"
#         filename = f"cifar100_{index:06d}.png"
#         return image, target, filename

# Create the result directories under results/<HIP_VISIBLE_DEVICES>/.
# os.getenv returns None when the variable is unset, and
# os.path.join('results', None) raises TypeError, so fall back to "0".
gpuid = os.getenv('HIP_VISIBLE_DEVICES') or '0'
resultdir = os.path.join('results', gpuid)
os.makedirs(os.path.join(resultdir, 'data'), exist_ok=True)
os.makedirs(os.path.join(resultdir, 'label'), exist_ok=True)

def AllocateOutputMemory(model):
    """Allocate one GPU buffer for every output of a compiled MIGraphX program.

    Args:
        model: MIGraphX program whose ``get_outputs()`` mapping supplies, for
            each output name, the value passed as ``s`` to
            ``migraphx.allocate_gpu`` (presumably the output shape).

    Returns:
        dict mapping each output name to its freshly allocated GPU buffer.
    """
    output_info = model.get_outputs()
    return {
        name: migraphx.allocate_gpu(s=output_info[name])
        for name in output_info.keys()
    }

# Whether torch.cuda reports an available GPU; validate() reads this flag to
# decide whether to move each batch onto the device.
use_cuda = torch.cuda.is_available()

class Resnet50:
    """Evaluate a ResNet-50 classifier on the CIFAR-100 test split with MIGraphX.

    The program is loaded from a pre-compiled ``.mxr`` file when one exists at
    ``mxr_model``; otherwise ``onnx_model`` is parsed, optionally quantized to
    FP16 or INT8, and compiled for the ``gpu`` target.
    """

    # Pre-compiled program that, when the file exists, is used instead of the
    # ONNX model (default kept from the original script).
    DEFAULT_MXR_PATH = "/home/sunzhq/workspace/yidong/resnet/mmpretrain-main/resnet50_fp16.mxr"

    def __init__(self, onnx_model, gpu, dataset, workers, batch_size,
                 confidence_thres, iou_thres, use_fp16, use_int8,
                 mxr_model=DEFAULT_MXR_PATH):
        """
        Args:
            onnx_model: Path to the ONNX model; used when no pre-compiled
                ``.mxr`` file is found at ``mxr_model``.
            gpu: GPU id. Stored only; compilation always targets device 0.
            dataset: Root directory of the CIFAR-100 dataset.
            workers: Number of DataLoader worker processes.
            batch_size: Batch size. Only full batches of this size are run,
                timed and scored (the trailing partial batch is skipped).
            confidence_thres: Stored but not used by this classifier.
            iou_thres: Stored but not used by this classifier.
            use_fp16: Quantize the parsed ONNX model to FP16.
            use_int8: Quantize the parsed ONNX model to INT8 (ignored when
                ``use_fp16`` is set).
            mxr_model: Path to a pre-compiled MIGraphX program. If the file
                exists it is loaded instead of ``onnx_model`` and the
                quantization flags are ignored. Pass ``None`` to always
                compile from ONNX.
        """
        self.onnx_model = onnx_model
        self.confidence_thres = confidence_thres
        self.iou_thres = iou_thres
        self.gpu = gpu
        self.dataset = dataset
        self.workers = workers
        self.batch_size = batch_size
        self.use_fp16 = use_fp16
        self.use_int8 = use_int8
        self.mxr_model = mxr_model

    @staticmethod
    def _samples_per_second(batch_size, durations):
        """Return ``batch_size * len(durations) / sum(durations)``; 0.0 if nothing was timed."""
        total = sum(durations)
        if not durations or total <= 0:
            return 0.0
        return batch_size * len(durations) / total

    def _build_val_loader(self):
        """Return a non-shuffled DataLoader over the normalized CIFAR-100 test split."""
        # Per-channel normalization statistics applied to the CIFAR-100 images.
        normalize = transforms.Normalize(mean=[0.5070751592371323, 0.48654887331495095, 0.4409178433670343],
                                         std=[0.2673342858792401, 0.2564384629170883, 0.27615047132568404])
        val_dataset = datasets.CIFAR100(
            self.dataset, train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ]))
        return torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.batch_size, shuffle=False,
            num_workers=self.workers, pin_memory=True, sampler=None)

    def _load_model(self, val_loader):
        """Return ``(program, input_name)`` ready to run on the GPU.

        A pre-compiled program at ``self.mxr_model`` is used as-is when the file
        exists. Otherwise ``self.onnx_model`` is parsed, optionally quantized,
        and compiled with ``offload_copy=False``: the caller moves inputs and
        outputs to and from the GPU explicitly.
        """
        if self.mxr_model and os.path.isfile(self.mxr_model):
            model = migraphx.load(self.mxr_model)
            return model, model.get_parameter_names()[0]

        model = migraphx.parse_onnx(self.onnx_model)
        # Name of the first model input.
        input_name = list(model.get_inputs().keys())[0]

        if self.use_fp16:
            migraphx.quantize_fp16(model)
        elif self.use_int8:
            # Calibrate INT8 quantization on the first validation batch.
            images, _ = next(iter(val_loader))
            calibration_data = [{input_name: migraphx.argument(np.array(images))}]
            migraphx.quantize_int8(model, migraphx.get_target("gpu"), calibration_data)

        # device_id selects the GPU device (device 0 by default).
        model.compile(t=migraphx.get_target("gpu"), offload_copy=False, device_id=0)
        return model, input_name

    def main(self):
        """Run batched inference over the CIFAR-100 test set and report metrics.

        Prints per-batch progress, throughput (pure inference and including data
        loading), Top-1/Top-5 accuracy and total wall-clock time, then writes
        ``{resultdir}/label/gt.txt``.

        Returns:
            None
        """
        val_loader = self._build_val_loader()
        model, input_name = self._load_model(val_loader)
        model_data = AllocateOutputMemory(model)

        batch_time = AverageMeter('Time', ':6.3f')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        fps = AverageMeter('fps', ':6.2f')

        progress = ProgressMeter(
            len(val_loader),
            [batch_time, fps, top1, top5],
            prefix='Test: ')

        # Ground truth keyed by image filename. Nothing fills it currently (the
        # filename-aware dataset is disabled), so gt.txt is written empty.
        gt = {}
        starts = time.time()
        batch_start = time.time()
        infer_times = []        # model.run() duration of each timed batch
        total_infer_times = []  # data loading + model.run() of each timed batch
        total_start = time.time()

        for i, (images, target) in enumerate(val_loader):
            num_images = images.size(0)
            images_np = np.array(images)
            if i == 0:
                # Warm-up run on the first batch; excluded from timing.
                model_data[input_name] = migraphx.to_gpu(migraphx.argument(images_np))
                model.run(model_data)

            # Only full batches are run; the trailing partial batch is skipped
            # (presumably because the compiled program has a fixed batch size).
            if num_images == self.batch_size:
                model_data[input_name] = migraphx.to_gpu(migraphx.argument(images_np))
                start = time.time()
                out = model.run(model_data)
                end = time.time()
                infer_times.append(end - start)
                total_infer_times.append(end - total_start)

                outputs = np.array(migraphx.from_gpu(out[0]))
                # Score on CPU: top-k over one batch of logits is cheap and this
                # avoids requiring a torch CUDA device.
                outputs = torch.as_tensor(outputs, dtype=torch.float32)
                acc1, acc5 = accuracy(outputs, target, topk=(1, 5))
                batch_time.update(time.time() - batch_start)
                fps.update(num_images / (end - start))
                top1.update(acc1.item(), num_images)
                top5.update(acc5.item(), num_images)
                progress.display(i)

                batch_start = time.time()
            total_start = time.time()

        if not infer_times:
            print(f"warning: no batch of size {self.batch_size} was run; "
                  "throughput and accuracy below are not meaningful")

        print("***************************")
        infer_time = sum(infer_times)
        avg_infer_fps = self._samples_per_second(self.batch_size, infer_times)
        print(f"total_infer_time: {infer_time}s")
        print(f'avg_infer_fps: {avg_infer_fps}samples/s')
        load_data_infer_time = sum(total_infer_times)
        load_data_avg_infer_fps = self._samples_per_second(self.batch_size, total_infer_times)
        print(f'load_data_total_infer_time: {load_data_infer_time}s')
        print(f'load_data_avg_total_Infer_fps: {load_data_avg_infer_fps} samples/s')
        print("******************************")

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
        print("总耗时：", time.time() - starts)

        with open(f'{resultdir}/label/gt.txt', 'w') as f:
            for key, value in gt.items():
                f.write(f"{key} {value}\n")

def validate(val_loader, model, criterion, args):
    """Evaluate a PyTorch model on ``val_loader`` and return its mean Top-1 accuracy.

    Args:
        val_loader: Iterable of ``(images, target)`` batches; must support len().
        model: Callable module with an ``eval()`` method (e.g. torch.nn.Module).
        criterion: Loss function called as ``criterion(output, target)``.
        args: Namespace providing ``gpu``, the device index passed to ``.cuda()``.

    Returns:
        Average Top-1 accuracy in percent, weighted by batch size.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # Move the batch once, and only when a GPU is available; a second
            # bare .cuda() would move it back to the current device and
            # override args.gpu.
            if use_cuda:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                progress.display(i)
        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg

def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every k in *topk*.

    Args:
        output: (batch, classes) scores.
        target: batch of integer class labels.
        topk: k values to evaluate.

    Returns:
        list with one shape-(1,) float tensor per k, in the order of *topk*.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Row j of `ranked` holds each sample's (j+1)-th highest-scoring class.
        ranked = output.topk(largest_k, 1, True, True).indices.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        scale = 100.0 / num_samples
        return [
            hits[:k].contiguous().view(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]

class AverageMeter(object):
    """Keep the most recent value and a count-weighted running mean of a metric."""

    def __init__(self, name, fmt=':f'):
        # name: label printed by __str__; fmt: format spec (e.g. ':6.3f')
        # applied to both the latest value and the mean.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, sample count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* as observed *n* times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        spec = self.fmt
        template = '{name} {val%s} ({avg%s})' % (spec, spec)
        return template.format(**vars(self))

class ProgressMeter(object):
    """Print a tab-separated progress line: prefix, ``[batch/total]``, then each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for batch index *batch*."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([head, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current batch index to the width of the total count,
        # e.g. '[{:3d}/100]'.
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))

if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean such as True/false/1/no.

        argparse's ``type=bool`` treats any non-empty string (including
        "False") as True, so the string is parsed explicitly instead.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    # Create an argument parser to handle command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="/workspace/gaoruiqi_inference/cmcc-code/mmpretrain-main/resnet50.onnx", help="Input your ONNX model.")
    parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
    parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
    parser.add_argument('--gpu', default=0, type=int, help='GPU id to use.')
    parser.add_argument('--dataset', default="/workspace/datasets/cifar100/", type=str, help='Root directory of the CIFAR-100 dataset.')
    parser.add_argument('--workers', default=1, type=int)
    parser.add_argument('--batch_size', default=1, type=int, help='batch_size')
    # Accept both "--fp16" and the older "--fp16 True" / "--fp16 False" forms.
    parser.add_argument('--fp16', default=False, type=_str2bool, nargs='?', const=True)
    parser.add_argument('--int8', default=False, type=_str2bool, nargs='?', const=True)
    args = parser.parse_args()

    # Build the ResNet-50 evaluator from the parsed arguments and run it.
    evaluator = Resnet50(args.model, args.gpu, args.dataset,
                         args.workers, args.batch_size,
                         args.conf_thres, args.iou_thres,
                         args.fp16, args.int8)
    evaluator.main()
