#include <resnet50.h>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/quantization.hpp>
#include <opencv2/dnn.hpp>
#include <CommonUtility.h>
#include <Filesystem.h>
#include <SimpleLog.h>

namespace migraphxSamples
{

// Default constructor. No work is done here; configuration and model loading
// happen in Initialize().
Classifier::Classifier() {}

// Destructor: releases the configuration file handle that Initialize() opens.
Classifier::~Classifier()
{
    configurationFile.release(); // close the configuration storage
}

/**
 * @brief Reads the classifier configuration, loads the ONNX model, optionally
 *        quantizes it (INT8 / FP16), compiles it for the GPU and runs one
 *        inference with a default-allocated input.
 *
 * Configuration is read from the "Classifier" node of CONFIG_FILE:
 * ModelPath, Scale, MeanValue1..3, SwapRB, Crop, UseInt8, UseFP16.
 *
 * @param initializationParameterOfClassifier Currently not read by this function
 *        (see the NOTE(review) at the model-path line).
 * @return SUCCESS on success;
 *         CONFIG_FILE_NOT_EXIST / FAIL_TO_OPEN_CONFIG_FILE if the configuration file
 *         is missing or cannot be opened;
 *         MODEL_NOT_EXIST if the model file is missing, or the parsed model has no
 *         input parameter, or its input is not 4-D (N,C,H,W);
 *         IMAGE_ERROR if the INT8 calibration image cannot be read.
 */
ErrorCode Classifier::Initialize(InitializationParameterOfClassifier initializationParameterOfClassifier)
{
    // Read the configuration file
    std::string configFilePath=CONFIG_FILE;
    if(Exists(configFilePath)==false)
    {
        LOG_ERROR(stdout, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }
    if(!configurationFile.open(configFilePath, FileStorage::READ))
    {
        LOG_ERROR(stdout, "fail to open configuration file\n");
        return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(stdout, "succeed to open configuration file\n");

    // Read the classifier parameters from the configuration file
    FileNode netNode = configurationFile["Classifier"];
    // NOTE(review): the parameter initializationParameterOfClassifier is never read;
    // the path prefix comes from `initializationParameter`, which is not declared in
    // this file (presumably a member declared in resnet50.h). Confirm this is the
    // intended source of parentPath.
    std::string modelPath=initializationParameter.parentPath+static_cast<std::string>(netNode["ModelPath"]);
    scale=static_cast<float>(netNode["Scale"]);
    meanValue.val[0]=static_cast<float>(netNode["MeanValue1"]);
    meanValue.val[1]=static_cast<float>(netNode["MeanValue2"]);
    meanValue.val[2]=static_cast<float>(netNode["MeanValue3"]);
    swapRB=static_cast<int>(netNode["SwapRB"])!=0;
    crop=static_cast<int>(netNode["Crop"])!=0;
    useInt8=static_cast<int>(netNode["UseInt8"])!=0;
    useFP16=static_cast<int>(netNode["UseFP16"])!=0;

    // Load the model
    if(Exists(modelPath)==false)
    {
        LOG_ERROR(logFile,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath);
    LOG_INFO(logFile,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    // Query the model input. The code below indexes lens()[0..3], so reject a model
    // with no input parameter or a non-4-D input instead of reading out of bounds.
    std::unordered_map<std::string, migraphx::shape> inputMap=net.get_parameter_shapes();
    if(inputMap.empty())
    {
        LOG_ERROR(logFile,"model has no input parameter: %s\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    inputName=inputMap.begin()->first;
    inputShape=inputMap.begin()->second;
    const std::vector<std::size_t> inputLens=inputShape.lens();
    if(inputLens.size()!=4)
    {
        LOG_ERROR(logFile,"input %s must be 4-D (N,C,H,W), got %d dims\n",
                  inputName.c_str(),static_cast<int>(inputLens.size()));
        return MODEL_NOT_EXIST;
    }
    const std::size_t batchSize=inputLens[0];
    const int H=static_cast<int>(inputLens[2]);
    const int W=static_cast<int>(inputLens[3]);
    inputSize=cv::Size(W,H);

    // Target the GPU
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Quantization
    if(useInt8)
    {
        // Build INT8 calibration data. Several representative images from the test
        // set are recommended; here a single image is replicated to fill the batch.
        const std::string calibrationImagePath="../Resource/Images/ImageNet_01.jpg";
        cv::Mat srcImage=cv::imread(calibrationImagePath,1);
        if(srcImage.empty())
        {
            LOG_ERROR(logFile,"fail to read calibration image: %s\n",calibrationImagePath.c_str());
            return IMAGE_ERROR;
        }
        std::vector<cv::Mat> srcImages(batchSize,srcImage);

        // NOTE(review): crop is hard-coded to false here, so the "Crop" config value is
        // only logged; confirm whether `crop` should be passed instead.
        cv::Mat inputBlob;
        cv::dnn::blobFromImages(srcImages,
                        inputBlob,
                        scale,
                        inputSize,
                        meanValue,
                        swapRB,
                        false);
        std::unordered_map<std::string, migraphx::argument> calibrationInput;
        calibrationInput[inputName]=migraphx::argument{inputShape, inputBlob.ptr<float>()};
        std::vector<std::unordered_map<std::string, migraphx::argument>> calibrationData={calibrationInput};

        // INT8 quantization
        migraphx::quantize_int8(net, gpuTarget, calibrationData);
    }
    if(useFP16)
    {
        migraphx::quantize_fp16(net);
    }

    // Compile the model
    migraphx::compile_options options;
    options.device_id=0;          // GPU device index (device 0 by default)
    options.offload_copy=true;    // let MIGraphX handle host<->device copies so eval() takes host buffers
    net.compile(gpuTarget,options);
    LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());

    // Run once with a default-allocated input (presumably a warm-up so the first
    // Classify() call does not pay one-time setup costs).
    std::unordered_map<std::string, migraphx::argument> warmupInput;
    warmupInput[inputName]=migraphx::argument{inputShape};
    net.eval(warmupInput);

    // Log the effective settings
    LOG_INFO(logFile,"InputSize:%dx%d\n",inputSize.width,inputSize.height);
    LOG_INFO(logFile,"InputName:%s\n",inputName.c_str());
    LOG_INFO(logFile,"Scale:%.6f\n",scale);
    LOG_INFO(logFile,"Mean:%.2f,%.2f,%.2f\n",meanValue.val[0],meanValue.val[1],meanValue.val[2]);
    LOG_INFO(logFile,"SwapRB:%d\n",static_cast<int>(swapRB));
    LOG_INFO(logFile,"Crop:%d\n",static_cast<int>(crop));
    LOG_INFO(logFile,"UseInt8:%d\n",static_cast<int>(useInt8));
    LOG_INFO(logFile,"UseFP16:%d\n",static_cast<int>(useFP16));

    return SUCCESS;
}

/**
 * @brief Runs the network on a batch of images and returns per-class scores.
 *
 * @param srcImages   Input images. The count must equal the model's batch
 *                    dimension (inputShape.lens()[0]); each image must be
 *                    non-empty and 8-bit (CV_8U).
 * @param predictions Output. One vector of ResultOfPrediction per input image is
 *                    appended (existing contents are kept). Each vector has one
 *                    entry per class: label = class index, confidence = the raw
 *                    value of the first network output for that class (no softmax
 *                    is applied here).
 * @return SUCCESS, or IMAGE_ERROR if the input batch is invalid.
 */
ErrorCode Classifier::Classify(const std::vector<cv::Mat> &srcImages,std::vector<std::vector<ResultOfPrediction>> &predictions)
{
    if(srcImages.empty())
    {
        LOG_ERROR(logFile, "image error!\n");
        return IMAGE_ERROR;
    }
    // Every image is preprocessed below, so validate all of them, not just the first.
    for(const cv::Mat &image : srcImages)
    {
        if(image.empty()||image.depth()!=CV_8U)
        {
            LOG_ERROR(logFile, "image error!\n");
            return IMAGE_ERROR;
        }
    }
    // The input argument is built from inputShape and a raw pointer into inputBlob,
    // so the image count must match the batch dimension exactly: fewer images would
    // make the runtime read past the end of inputBlob, more would mis-split the output.
    const std::vector<std::size_t> inputLens=inputShape.lens();
    if(inputLens.empty()||srcImages.size()!=inputLens[0])
    {
        LOG_ERROR(logFile, "batch size mismatch: got %d image(s), model expects %d\n",
                  static_cast<int>(srcImages.size()),
                  inputLens.empty()?0:static_cast<int>(inputLens[0]));
        return IMAGE_ERROR;
    }

    // Preprocess and convert to an NCHW float blob
    cv::Mat inputBlob;
    cv::dnn::blobFromImages(srcImages,
                    inputBlob,
                    scale,
                    inputSize,
                    meanValue,
                    swapRB,
                    false);

    // Bind the input data
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName]=migraphx::argument{inputShape, inputBlob.ptr<float>()};

    // Inference
    std::vector<migraphx::argument> results=net.eval(inputData);

    // Read the first output node
    const migraphx::argument &result=results[0];
    const std::size_t numberOfOutput=result.get_shape().elements();   // total output elements
    const float *logits=reinterpret_cast<const float *>(result.data()); // output data pointer

    // Split the output evenly across the images of the batch
    const int numberOfClasses=static_cast<int>(numberOfOutput/srcImages.size());
    predictions.reserve(predictions.size()+srcImages.size());
    for(std::size_t i=0;i<srcImages.size();++i)
    {
        const float *imageLogits=logits+static_cast<std::size_t>(numberOfClasses)*i;

        predictions.emplace_back();
        std::vector<ResultOfPrediction> &resultOfPredictions=predictions.back();
        resultOfPredictions.reserve(static_cast<std::size_t>(numberOfClasses));
        for(int j=0;j<numberOfClasses;++j)
        {
            ResultOfPrediction prediction;
            prediction.label=j;
            prediction.confidence=imageLogits[j];
            resultOfPredictions.push_back(prediction);
        }
    }

    return SUCCESS;
}

}
