#!/usr/bin/env python3
""" ImageNet Validation Script

This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import csv
import glob
import json
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
import migraphx
import numpy as np

from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
    decay_batch_step, check_batch_size_retry

# Optional-dependency probes: each flag records whether a feature is importable.

# NVIDIA Apex mixed precision.
try:
    from apex import amp
except ImportError:
    has_apex = False
else:
    has_apex = True

# Native torch AMP: usable when torch.cuda.amp exposes a non-None `autocast`.
has_native_amp = getattr(getattr(torch.cuda, 'amp', None), 'autocast', None) is not None

# functorch AOT-autograd fusion.
try:
    from functorch.compile import memory_efficient_fusion
except ImportError:
    has_functorch = False
else:
    has_functorch = True

# Enable the cuDNN benchmark flag for the whole process.
torch.backends.cudnn.benchmark = True

def AllocateOutputMemory(model):
    """Allocate one GPU buffer per model output.

    Args:
        model: Object whose ``get_outputs()`` returns a mapping from output
            name to the value passed to ``migraphx.allocate_gpu`` as ``s``.

    Returns:
        dict mapping every output name to the buffer returned by
        ``migraphx.allocate_gpu`` for that output.
    """
    return {
        name: migraphx.allocate_gpu(s=model.get_outputs()[name])
        for name in model.get_outputs().keys()
    }

# Command-line interface. The option set is the timm validation script's;
# only part of it is read by validate()/main() in this file.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')

# --- dataset selection -------------------------------------------------------
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')

# --- model and loader --------------------------------------------------------
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                    help='model architecture (default: dpn92)')
# The help text used to advertise "default: 2" although the default is 4.
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')

# --- input preprocessing -----------------------------------------------------
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
                    help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float,  nargs='+', default=None, metavar='STD',
                    help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')

# --- classes, pooling, logging, weights --------------------------------------
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
                    help='enable test time pool')

# --- data transfer / precision -----------------------------------------------
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
                    help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                    help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                    help='use ema version of weights if present')

# --- scripting / fusion (the two scripting modes are mutually exclusive) -----
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='torch.jit.script the full model')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                    help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
parser.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")

# --- outputs and label files -------------------------------------------------
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                    help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
                    help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                    help='Valid label indices txt file for validation of partial label space')
parser.add_argument('--retry', default=False, action='store_true',
                    help='Enable batch size decay & retry for single model validation')

# Create the per-device result directory: results/<HIP_VISIBLE_DEVICES>.
# os.getenv() returns None when the variable is unset, which made
# os.path.join() raise TypeError at import time; fall back to '0', the
# device id main() compiles the model for.
gpuid = os.getenv('HIP_VISIBLE_DEVICES', '0')
resultdir = os.path.join('results', gpuid)
os.makedirs(resultdir, exist_ok=True)

def migraphx_efficient(model, data_tensor):
    """Run ``model`` on one batch held in a torch tensor.

    Args:
        model: Object whose ``run`` accepts a dict keyed by ``"input"``.
        data_tensor: Torch tensor holding the batch; it is indexed with four
            indices below, so it must be 4-dimensional.

    Returns:
        The first entry of the ``run`` result as a torch tensor moved to the
        ``cuda`` device.
    """
    # Bring the batch to host memory as a numpy array.
    host_batch = data_tensor.detach().cpu().numpy()

    # Note: the per-sample assignment into a fresh float32 buffer is required;
    # without it the input strides seen by migraphx are wrong.
    staged = np.zeros(host_batch.shape).astype("float32")
    for sample_idx in range(host_batch.shape[0]):
        staged[sample_idx, :, :, :] = host_batch[sample_idx, :, :, :]

    # Run inference.
    outputs = model.run({"input": staged})

    # Wrap the first output as a tensor on the GPU.
    first_output = np.array(outputs[0], copy=False)
    return torch.from_numpy(first_output).to(torch.device("cuda"))

def validate(args, model, inputName):
    """Run ``model`` over the dataset in fixed-size batches and report top-1 accuracy.

    Builds the dataset and loader from ``args``, pushes every full batch
    through ``model`` on the GPU, writes each per-image output vector to
    ``{resultdir}/<image stem>_0.bin``, writes the image -> label mapping to
    ``{resultdir}/val_label.txt`` and prints timing / throughput figures.

    Args:
        args: Parsed command-line namespace. Reads ``data``, ``dataset``,
            ``split``, ``dataset_download``, ``tf_preprocessing``,
            ``class_map``, ``valid_labels``, ``real_labels``, ``batch_size``,
            ``no_prefetcher``, ``workers``, ``pin_mem`` and ``model``;
            sets ``args.prefetcher``.
        model: Program object providing ``get_outputs()``,
            ``get_parameter_names()`` and ``run()``.
        inputName: Ignored; the input name is re-read from ``model``. Kept so
            existing callers keep working.

    Returns:
        OrderedDict with ``model``, ``top1``, ``top1_err``, ``param_count``,
        ``img_size``, ``crop_pct`` and ``interpolation``.

    Raises:
        RuntimeError: If the loader never produced a batch of exactly 24
            samples, so there is nothing to score (previously this surfaced
            as a ``ZeroDivisionError``).
    """
    # The warm-up input, the loop guard and the throughput figures all use a
    # fixed batch of 24 samples (presumably the batch size the model was
    # exported with -- confirm against the model file).
    model_batch = 24

    # might as well try to validate something
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing

    # Hard-coded preprocessing configuration; not derived from the model.
    data_config = {'input_size': (3, 288, 288), 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'crop_pct': 1.0}

    dataset = create_dataset(root=args.data, name=args.dataset, split=args.split, download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # NOTE: valid_labels and real_labels are built here but are not consulted
    # by the accuracy computation below.
    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(1000)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = data_config['crop_pct']

    # Create the data loader.
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    modelData = AllocateOutputMemory(model)
    inputName = model.get_parameter_names()[0]
    val_label = {}
    device = torch.device("cuda")
    with torch.no_grad():
        # Warm-up: one run on an all-ones batch before anything is timed.
        modelData[inputName] = migraphx.to_gpu(migraphx.argument(np.ones([model_batch, 3, 288, 288]).astype(np.float32)))
        model.run(modelData)

        # Timed inference over the dataset.
        infer_times = []        # model.run() only
        total_infer_times = []  # loader fetch + upload + run + download, per batch
        end1 = time.time()
        total_start = time.time()
        total = 0
        correct = 0
        # The loader is expected to yield (images, targets, file names).
        for images, target, imgfiles in loader:
            nb = images.shape[0]  # batch size

            # Anything but a full batch (e.g. the trailing partial batch)
            # ends the evaluation.
            if nb != model_batch:
                break

            # Run inference.
            with amp_autocast():
                data_numpy = images.detach().cpu().numpy()

                # astype() returns a fresh float32 copy of the batch.
                img_data = data_numpy.astype(np.float32)
                modelData[inputName] = migraphx.to_gpu(migraphx.argument(img_data))

                end = time.time()
                start = time.time()
                result_dcu = model.run(modelData)
                infer_times.append(time.time() - start)

                batch_time.update(time.time() - end)
                result = np.array(migraphx.from_gpu(result_dcu[0]))
                # Convert the result to a tensor.
                output = torch.from_numpy(np.array(result, copy=False)).to(device)

            # Save the result files.
            odata = output.cpu().numpy()
            total_infer_times.append(time.time() - total_start)

            imgfiles = [os.path.basename(b) for b in imgfiles]
            for idx, imgfile in enumerate(imgfiles):
                f = os.path.splitext(imgfile)[0]
                odata[idx].tofile(f'{resultdir}/{f}_0.bin')
                bin_data = odata[idx]
                this_target = target[idx]
                this_target = [this_target]
                bin_data = bin_data.reshape(-1, 1000)
                probabilities = torch.nn.functional.softmax(torch.tensor(bin_data), dim=1)
                pred = torch.argmax(probabilities, 1)
                correct += torch.sum(pred[:len(this_target)] == torch.tensor(this_target))
                total += len(this_target)
            val_label.update(dict(zip(imgfiles, target.cpu().numpy())))
            total_start = time.time()

        if total == 0:
            # Without this guard the accuracy and fps divisions below fail
            # with an uninformative ZeroDivisionError.
            raise RuntimeError(
                f"No batch of exactly {model_batch} samples was processed; "
                f"run with '-b {model_batch}' on a dataset of at least "
                f"{model_batch} images.")
        acc1 = float(correct / total)
        print("总体耗时：", time.time()-end1)
    results = OrderedDict(
        model=args.model,
        top1=round(acc1, 6), top1_err=round(1 - acc1, 6),
        # Hard-coded parameter count (in millions); not read from the model.
        param_count=round(13649388 / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'])

    with open(f'{resultdir}/val_label.txt', 'w') as file:
        for key, value in val_label.items():
            file.write(f"{key} {value}\n")

    print("适应本次测试的指标如下所示")
    print("***************************")
    infer_time = sum(infer_times)
    avg_infer_time = model_batch * len(infer_times) / sum(infer_times)
    print(f"total_infer_time: {infer_time}s")
    print(f'avg_infer_fps: {avg_infer_time}samples/s')
    load_data_infer_time = sum(total_infer_times)
    load_data_avg_infer_time = len(total_infer_times) * model_batch / sum(total_infer_times)
    print(f'load_data_total_infer_time: {load_data_infer_time}s')
    print(f'load_data_avg_total_Infer_fps: {load_data_avg_infer_time} samples/s')

    return results

def main():
    """Load or build the MIGraphX model, validate it and print the results.

    If the pre-compiled ``.mxr`` file exists at the hard-coded path it is
    loaded directly. Otherwise ``args.model`` is parsed as an ONNX file,
    quantized to fp16 and compiled for the GPU target. The result dict from
    ``validate`` is printed as JSON after a ``--result`` delimiter line.
    """
    setup_default_logging()
    args = parser.parse_args()

    compiled_path = "/home/sunzhq/workspace/yidong/efficientnet/efficientNet_imagemodels/efficient_b2_fp16.mxr"
    if not os.path.isfile(compiled_path):
        # Load the ONNX model.
        model = migraphx.parse_onnx(args.model)
        inputName = model.get_parameter_names()[0]
        # Use fp16.
        migraphx.quantize_fp16(model)
        # Compile.
        model.compile(t=migraphx.get_target("gpu"), offload_copy=False, device_id=0)
    else:
        model = migraphx.load(compiled_path)
        inputName = model.get_parameter_names()[0]

    results = validate(args, model, inputName)

    # output results in JSON to stdout w/ delimiter for runner script
    print(f'--result\n{json.dumps(results, indent=4)}')


def write_results(results_file, results):
    """Write a list of result dicts to ``results_file`` as CSV.

    The header row is taken from the keys of the first entry. An empty
    ``results`` leaves an empty file instead of raising ``IndexError``.

    Args:
        results_file: Path of the CSV file to create or overwrite.
        results: Sequence of dicts that share the keys of the first entry.
    """
    # newline='' lets the csv module control line endings, as its
    # documentation requires for files passed to a writer.
    with open(results_file, mode='w', newline='') as cf:
        if not results:
            return
        dw = csv.DictWriter(cf, fieldnames=list(results[0]))
        dw.writeheader()
        dw.writerows(results)


# Run the validation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
