#include <YOLOV5.h>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/quantization.hpp>
#include <Filesystem.h>
#include <SimpleLog.h>


namespace migraphxSamples
{

// Default constructor. All real setup (configuration, model parsing/compilation, class names)
// happens in Initialize(); construction itself does no work.
DetectorYOLOV5::DetectorYOLOV5() = default;

// Destructor: explicitly releases the cv::FileStorage that Initialize() opened on the
// configuration file.
DetectorYOLOV5::~DetectorYOLOV5()
{
    this->configurationFile.release();
}

/**
 * Loads configuration, parses and compiles the ONNX model for the GPU, runs one warm-up
 * inference and loads the class names.
 *
 * @param initializationParameterOfDetector  supplies configFilePath (an OpenCV FileStorage file
 *                                           with a "DetectorYOLOV5" node).
 * @param dynamic  true: load "ModelPathDynamic" and parse it with input "images" mapped to
 *                 1x3x800x800 (logged as the maximum input size);
 *                 false: load "ModelPathStatic" with the dimensions stored in the model.
 * @return SUCCESS, CONFIG_FILE_NOT_EXIST, FAIL_TO_OPEN_CONFIG_FILE or MODEL_NOT_EXIST.
 */
ErrorCode DetectorYOLOV5::Initialize(InitializationParameterOfDetector initializationParameterOfDetector, bool dynamic)
{
    // Read the configuration file.
    std::string configFilePath=initializationParameterOfDetector.configFilePath;
    if(Exists(configFilePath)==false)
    {
        LOG_ERROR(stdout, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }
    if(!configurationFile.open(configFilePath, cv::FileStorage::READ))
    {
       LOG_ERROR(stdout, "fail to open configuration file\n");
       return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(stdout, "succeed to open configuration file\n");

    // Read the detector parameters from the configuration file.
    cv::FileNode netNode = configurationFile["DetectorYOLOV5"];
    if(dynamic)
    {
        modelPath=(std::string)netNode["ModelPathDynamic"];
    }
    else
    {
        modelPath=(std::string)netNode["ModelPathStatic"];
    }
    std::string pathOfClassNameFile=(std::string)netNode["ClassNameFile"];
    yolov5Parameter.confidenceThreshold = (float)netNode["ConfidenceThreshold"];
    yolov5Parameter.nmsThreshold = (float)netNode["NMSThreshold"];
    yolov5Parameter.objectThreshold = (float)netNode["ObjectThreshold"];
    yolov5Parameter.numberOfClasses=(int)netNode["NumberOfClasses"];
    useFP16=(bool)(int)netNode["UseFP16"];

    // Load the model. Both modes share the existence check and the I/O inspection below; they
    // only differ in whether the input dimensions are overridden at parse time.
    if(Exists(modelPath)==false)
    {
        LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    if(dynamic)
    {
        // Fix the input "images" to 1x3x800x800; presumably the upper bound for the shapes later
        // passed to Detect() (it is logged as InputMaxSize below).
        migraphx::onnx_options onnx_options;
        onnx_options.map_input_dims["images"]={1,3,800,800};
        net = migraphx::parse_onnx(modelPath, onnx_options);
    }
    else
    {
        net = migraphx::parse_onnx(modelPath);
    }
    LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    // Print the model's input/output nodes.
    std::cout<<"inputs:"<<'\n';
    std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
    for(const auto& input:inputs)
    {
        std::cout<<input.first<<":"<<input.second<<'\n';
    }
    std::cout<<"outputs:"<<'\n';
    std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
    for(const auto& output:outputs)
    {
        std::cout<<output.first<<":"<<output.second<<std::endl;
    }

    // The first reported input is taken as the image input; lens() is read as NCHW.
    inputName=inputs.begin()->first;
    inputShape=inputs.begin()->second;
    const int H=static_cast<int>(inputShape.lens()[2]);
    const int W=static_cast<int>(inputShape.lens()[3]);
    inputSize=cv::Size(W,H);

    // The label stays a literal in each branch because LOG_INFO is a macro whose handling of a
    // non-literal format string is not visible here.
    if(dynamic)
    {
        LOG_INFO(stdout,"InputMaxSize:%dx%d\n",inputSize.width,inputSize.height);
    }
    else
    {
        LOG_INFO(stdout,"InputSize:%dx%d\n",inputSize.width,inputSize.height);
    }

    LOG_INFO(stdout,"InputName:%s\n",inputName.c_str());
    LOG_INFO(stdout,"ConfidenceThreshold:%f\n",yolov5Parameter.confidenceThreshold);
    LOG_INFO(stdout,"NMSThreshold:%f\n",yolov5Parameter.nmsThreshold);
    LOG_INFO(stdout,"objectThreshold:%f\n",yolov5Parameter.objectThreshold);
    LOG_INFO(stdout,"NumberOfClasses:%d\n",yolov5Parameter.numberOfClasses);

    // Target the GPU.
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Optional FP16 quantization.
    if(useFP16)
    {
        migraphx::quantize_fp16(net);
    }

    // Compile the model on device 0; offload_copy lets eval() take host buffers directly.
    migraphx::compile_options options;
    options.device_id=0;
    options.offload_copy=true;
    net.compile(gpuTarget,options);
    LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());

    // Warm up with one inference on a default-constructed input of the model's input shape.
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName]=migraphx::argument{inputShape};
    net.eval(inputData);

    // Read the class names, one per line. The list is cleared first so that a repeated
    // Initialize() does not append duplicates.
    classNames.clear();
    if(!pathOfClassNameFile.empty())
    {
        std::ifstream classNameFile(pathOfClassNameFile);
        if(!classNameFile.is_open())
        {
            // Best effort: keep going with placeholder (empty) names instead of failing.
            LOG_ERROR(stdout,"fail to open class name file: %s\n",pathOfClassNameFile.c_str());
        }
        std::string line;
        while (std::getline(classNameFile, line))
        {
            // Drop the '\r' left behind by files with Windows (CRLF) line endings.
            if(!line.empty() && line.back()=='\r')
            {
                line.pop_back();
            }
            classNames.push_back(line);
        }
    }
    // Pad with empty names so every class id in [0, NumberOfClasses) has an entry, even when
    // no file was given, it could not be opened, or it has too few lines.
    if(yolov5Parameter.numberOfClasses > 0 &&
       classNames.size() < static_cast<std::size_t>(yolov5Parameter.numberOfClasses))
    {
        classNames.resize(yolov5Parameter.numberOfClasses);
    }

    return SUCCESS;
}

/**
 * Runs YOLOv5 detection on one image and appends the detections that survive NMS to
 * resultsOfDetection. The vector is not cleared.
 *
 * @param srcImage       CV_8UC3 image. blobFromImage is called with swapRB=true.
 * @param relInputShape  NCHW shape fed to the network in dynamic mode. In static mode the
 *                       model's own input shape (captured in Initialize) is used, because the
 *                       network argument is always built with that shape.
 * @param resultsOfDetection  output: detections are appended.
 * @param dynamic        whether the model was initialized in dynamic mode.
 * @return SUCCESS, or IMAGE_ERROR for an empty/non-CV_8UC3 image or a feed shape that does not
 *         match the preprocessed blob (not rank 4, N != 1 or C != 3).
 */
ErrorCode DetectorYOLOV5::Detect(const cv::Mat &srcImage, std::vector<std::size_t> &relInputShape, std::vector<ResultOfDetection> &resultsOfDetection, bool dynamic)
{
    if(srcImage.empty()||srcImage.type()!=CV_8UC3)
    {
        LOG_ERROR(stdout, "image error!\n");
        return IMAGE_ERROR;
    }

    // Shape of the tensor handed to the network. In static mode this must be the model's own
    // shape: preprocessing to any other size would make the network read past (or short of)
    // the blob.
    const migraphx::shape feedShape = dynamic ? migraphx::shape(inputShape.type(), relInputShape) : inputShape;
    const std::vector<std::size_t>& feedLens = feedShape.lens();
    if(feedLens.size() != 4)
    {
        LOG_ERROR(stdout, "input shape error!\n");
        return IMAGE_ERROR;
    }

    // Preprocess into an NCHW float blob: scale by 1/255, resize to WxH, swap R and B.
    inputSize = cv::Size(static_cast<int>(feedLens[3]), static_cast<int>(feedLens[2]));
    cv::Mat inputBlob;
    cv::dnn::blobFromImage(srcImage,
                    inputBlob,
                    1 / 255.0,
                    inputSize,
                    cv::Scalar(0, 0, 0),
                    true,
                    false);

    // The blob holds exactly one 3-channel image. A feed shape with N != 1 or C != 3 would make
    // the network read beyond the blob's memory.
    if(inputBlob.total() != feedShape.elements())
    {
        LOG_ERROR(stdout, "input shape error!\n");
        return IMAGE_ERROR;
    }

    // Wrap the blob memory (no copy); inputBlob outlives eval() below.
    migraphx::parameter_map inputData;
    inputData[inputName] = migraphx::argument{feedShape, reinterpret_cast<float*>(inputBlob.data)};

    // Inference.
    std::vector<migraphx::argument> inferenceResults = net.eval(inputData);

    // Copy the first output into a cv::Mat. Assumed to be rank-3 float32
    // [1, numProposal, 5 + numClasses] -- TODO confirm for FP16-quantized models.
    migraphx::argument output = inferenceResults[0];
    const migraphx::shape outputShape = output.get_shape();
    const std::vector<std::size_t>& outLens = outputShape.lens();
    int shape[] = {static_cast<int>(outLens[0]), static_cast<int>(outLens[1]), static_cast<int>(outLens[2])};
    cv::Mat out(3, shape, CV_32F);
    memcpy(out.data, output.data(), sizeof(float)*outputShape.elements());

    // One row per proposal: cx, cy, w, h, objectness, then per-class scores.
    const int numProposal = out.size[1];
    const int numOut = out.size[2];
    out = out.reshape(0, numProposal);

    // Collect candidate boxes, scaled from network-input coordinates back to the source image.
    std::vector<float> confidences;
    std::vector<cv::Rect> boxes;
    std::vector<int> classIds;
    const float ratioh = static_cast<float>(srcImage.rows) / inputSize.height;
    const float ratiow = static_cast<float>(srcImage.cols) / inputSize.width;

    for (int row = 0; row < numProposal; ++row)
    {
        const float* pdata = out.ptr<float>(row);
        const float boxScore = pdata[4];
        if (boxScore > yolov5Parameter.objectThreshold)
        {
            cv::Mat scores = out.row(row).colRange(5, numOut);
            cv::Point classIdPoint;
            double maxClassScore;
            cv::minMaxLoc(scores, 0, &maxClassScore, 0, &classIdPoint);
            // Final confidence = best class score weighted by objectness.
            maxClassScore *= boxScore;
            if (maxClassScore > yolov5Parameter.confidenceThreshold)
            {
                const float cx = pdata[0] * ratiow;
                const float cy = pdata[1] * ratioh;
                const float w = pdata[2] * ratiow;
                const float h = pdata[3] * ratioh;

                const int left = static_cast<int>(cx - 0.5 * w);
                const int top = static_cast<int>(cy - 0.5 * h);

                confidences.push_back(static_cast<float>(maxClassScore));
                boxes.push_back(cv::Rect(left, top, static_cast<int>(w), static_cast<int>(h)));
                classIds.push_back(classIdPoint.x);
            }
        }
    }

    // Non-maximum suppression removes redundant overlapping boxes.
    std::vector<int> indices;
    cv::dnn::NMSBoxes(boxes, confidences, yolov5Parameter.confidenceThreshold, yolov5Parameter.nmsThreshold, indices);
    for (const int idx : indices)
    {
        const int classID = classIds[idx];

        // The class-name list may be shorter than the number of classes the model emits;
        // fall back to an empty name instead of indexing out of range.
        std::string className;
        if (classID >= 0 && static_cast<std::size_t>(classID) < classNames.size())
        {
            className = classNames[classID];
        }

        ResultOfDetection detection;
        detection.boundingBox = boxes[idx];
        detection.confidence = confidences[idx];
        detection.classID = classID;
        detection.className = className;
        resultsOfDetection.push_back(detection);
    }

    return SUCCESS;
}

}
