#include <DetectorYOLOV5.h>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <opencv2/dnn.hpp>
#include <CommonUtility.h>
#include <Filesystem.h>
#include <SimpleLog.h>

using namespace cv::dnn;

namespace migraphxSamples
{

// Creates an uninitialized detector with no log file attached.
// Initialize() must be called before Detect(): the network, input name and
// input shape are only set up there.
DetectorYOLOV5::DetectorYOLOV5()
    : logFile(nullptr)
{
}

// Tears down the detector, closing the configuration file that
// DoCommonInitialization() opened.
DetectorYOLOV5::~DetectorYOLOV5()
{
    configurationFile.release();
}

/**
 * Loads the YOLOv5 settings and ONNX model, compiles the model for the GPU
 * target, runs one warm-up inference and loads the class-name list.
 *
 * Keys read from the "DetectorYOLOV5" configuration node: ModelPath (appended
 * to initializationParameter.parentPath), ClassNameFile, ConfidenceThreshold,
 * NMSThreshold, ObjectThreshold, NumberOfClasses, UseFP16.
 *
 * @param initializationParameterOfDetector log name, configuration file path
 *        and parent path, forwarded to DoCommonInitialization().
 * @return SUCCESS; the error code returned by DoCommonInitialization(); or
 *         MODEL_NOT_EXIST when the model file is missing.
 */
ErrorCode DetectorYOLOV5::Initialize(InitializationParameterOfDetector initializationParameterOfDetector)
{
    // Common initialization (obtain the log file, load the configuration file, ...)
    ErrorCode errorCode=DoCommonInitialization(initializationParameterOfDetector);
    if(errorCode!=SUCCESS)
    {
        LOG_ERROR(logFile,"fail to DoCommonInitialization\n");
        return errorCode;
    }
    LOG_INFO(logFile,"succeed to DoCommonInitialization\n");

    // Read the detector parameters from the configuration file.
    FileNode netNode = configurationFile["DetectorYOLOV5"];
    string modelPath=initializationParameter.parentPath+(string)netNode["ModelPath"];
    string pathOfClassNameFile=(string)netNode["ClassNameFile"];
    yolov5Parameter.confidenceThreshold = (float)netNode["ConfidenceThreshold"];
    yolov5Parameter.nmsThreshold = (float)netNode["NMSThreshold"];
    yolov5Parameter.objectThreshold = (float)netNode["ObjectThreshold"];
    yolov5Parameter.numberOfClasses=(int)netNode["NumberOfClasses"];
    useFP16=(bool)(int)netNode["UseFP16"];

    // Load the model.
    if(Exists(modelPath)==false)
    {
        LOG_ERROR(logFile,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath);
    LOG_INFO(logFile,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    // Take name and shape of the first model input.
    // NOTE(review): assumes the model has at least one input parameter
    // (begin() of an empty container must not be dereferenced) and a rank-4
    // NCHW input shape, since lens()[2]/lens()[3] are used as height/width.
    std::pair<std::string, migraphx::shape> inputAttribute=*(net.get_parameter_shapes().begin());
    inputName=inputAttribute.first;
    inputShape=inputAttribute.second;
    inputSize=cv::Size(inputShape.lens()[3],inputShape.lens()[2]);

    // Target the GPU.
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Optional FP16 quantization.
    if(useFP16)
    {
        migraphx::quantize_fp16(net);
    }

    // Compile the model.
    migraphx::compile_options options;
    options.device_id=0; // GPU device to use; device 0 by default
    options.offload_copy=true; // let MIGraphX copy inputs/outputs between host and device
    net.compile(gpuTarget,options);
    LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());

    // Warm-up: run once on generated data.
    migraphx::parameter_map inputData;
    inputData[inputName]=migraphx::generate_argument(inputShape);
    net.eval(inputData);

    // Load the class names. Clear first so that calling Initialize() again
    // does not append a second copy of the list.
    classNames.clear();
    if(!pathOfClassNameFile.empty())
    {
        ifstream classNameFile(pathOfClassNameFile);
        if(!classNameFile.is_open())
        {
            LOG_ERROR(logFile,"fail to open class name file: %s\n",pathOfClassNameFile.c_str());
        }
        string line;
        while (getline(classNameFile, line))
        {
            classNames.push_back(line);
        }
    }

    // Guarantee at least NumberOfClasses entries (empty names as placeholders)
    // so a class index in [0, NumberOfClasses) is always a valid lookup, even
    // when the name file is absent, unreadable or shorter than expected.
    // A non-positive NumberOfClasses is ignored rather than passed to resize().
    if(yolov5Parameter.numberOfClasses>0 &&
       classNames.size()<static_cast<size_t>(yolov5Parameter.numberOfClasses))
    {
        classNames.resize(yolov5Parameter.numberOfClasses);
    }

    // log
    LOG_INFO(logFile,"InputSize:%dx%d\n",inputSize.width,inputSize.height);
    LOG_INFO(logFile,"InputName:%s\n",inputName.c_str());
    LOG_INFO(logFile,"ConfidenceThreshold:%f\n",yolov5Parameter.confidenceThreshold);
    LOG_INFO(logFile,"NMSThreshold:%f\n",yolov5Parameter.nmsThreshold);
    LOG_INFO(logFile,"objectThreshold:%f\n",yolov5Parameter.objectThreshold);
    LOG_INFO(logFile,"NumberOfClasses:%d\n",yolov5Parameter.numberOfClasses);

    return SUCCESS;

}

/**
 * Runs YOLOv5 inference on one image and appends the detections that pass
 * the objectness / confidence thresholds and non-maximum suppression.
 *
 * @param srcImage            input image; must be non-empty and CV_8UC3.
 * @param resultsOfDetection  detections are appended (the vector is not
 *                            cleared); boxes are in srcImage pixel coordinates.
 * @return IMAGE_ERROR for an empty or non-CV_8UC3 image, otherwise SUCCESS.
 */
ErrorCode DetectorYOLOV5::Detect(const cv::Mat &srcImage, std::vector<ResultOfDetection> &resultsOfDetection)
{
    if(srcImage.empty()||srcImage.type()!=CV_8UC3)
    {
        LOG_ERROR(logFile, "image error!\n");
        return IMAGE_ERROR;
    }

    // Preprocess into an NCHW float blob: scale by 1/255, resize to the
    // network input size, swap R/B channels, no center crop.
    cv::Mat inputBlob;
    blobFromImage(srcImage,
                    inputBlob,
                    1 / 255.0,
                    inputSize,
                    Scalar(0, 0, 0),
                    true,
                    false);

    // Wrap the blob's buffer as the model input; inputBlob must stay alive
    // until eval() returns.
    migraphx::parameter_map inputData;
    inputData[inputName]= migraphx::argument{inputShape, (float*)inputBlob.data};

    // Inference.
    std::vector<migraphx::argument> inferenceResults = net.eval(inputData);

    // Copy output 0 into a cv::Mat. The indexing below assumes a rank-3 float
    // output laid out as (batch, proposals, values-per-proposal).
    migraphx::argument output = inferenceResults[0];
    migraphx::shape outputShape = output.get_shape();
    const std::vector<std::size_t> &lens = outputShape.lens();
    // Explicit casts: initializing int from size_t inside braces is a
    // narrowing conversion, which is ill-formed in list-initialization.
    int shape[]={static_cast<int>(lens[0]),static_cast<int>(lens[1]),static_cast<int>(lens[2])};
    cv::Mat out(3,shape,CV_32F);
    memcpy(out.data,output.data(),sizeof(float)*outputShape.elements());

    const int numProposal = out.size[1];
    const int numOut = out.size[2];
    // Flatten to 2-D: one proposal per row.
    out = out.reshape(0, numProposal);

    std::vector<float> confidences;
    std::vector<cv::Rect> boxes;
    std::vector<int> classIds;
    // Scale factors mapping network-input coordinates back to the source image.
    const float ratioh = (float)srcImage.rows / inputSize.height;
    const float ratiow = (float)srcImage.cols / inputSize.width;

    // Each row holds cx, cy, w, h, objectness score, then one score per class.
    for (int row = 0; row < numProposal; ++row)
    {
        const float* pdata = out.ptr<float>(row);
        const float boxScore = pdata[4];
        if (boxScore > yolov5Parameter.objectThreshold)
        {
            cv::Mat scores = out.row(row).colRange(5, numOut);
            cv::Point classIdPoint;
            double maxClassScore;
            cv::minMaxLoc(scores, 0, &maxClassScore, 0, &classIdPoint);
            maxClassScore *= boxScore;
            if (maxClassScore > yolov5Parameter.confidenceThreshold)
            {
                const float cx = pdata[0] * ratiow;
                const float cy = pdata[1] * ratioh;
                const float w = pdata[2] * ratiow;
                const float h = pdata[3] * ratioh;

                // Convert center/size to the top-left corner expected by cv::Rect.
                const int left = int(cx - 0.5 * w);
                const int top = int(cy - 0.5 * h);

                confidences.push_back((float)maxClassScore);
                boxes.push_back(cv::Rect(left, top, (int)(w), (int)(h)));
                classIds.push_back(classIdPoint.x);
            }
        }
    }

    // Non-maximum suppression removes redundant overlapping boxes.
    std::vector<int> indices;
    dnn::NMSBoxes(boxes, confidences, yolov5Parameter.confidenceThreshold, yolov5Parameter.nmsThreshold, indices);
    for (const int idx : indices)
    {
        const int classID=classIds[idx];

        ResultOfDetection result;
        result.boundingBox=boxes[idx];
        result.confidence=confidences[idx];
        result.classID=classID;
        // The class index is derived from the model's output width, which is
        // not guaranteed to match classNames (e.g. a short or missing name
        // file), so never index past the end of the list.
        const bool hasName=classID>=0 && static_cast<size_t>(classID)<classNames.size();
        result.className=hasName ? classNames[classID] : std::string();
        resultsOfDetection.push_back(result);
    }

    return SUCCESS;
}

/**
 * Shared setup step for Initialize(): stores the parameters, looks up the log
 * file, opens the configuration file for reading and makes sure a non-empty
 * parent path ends with a path separator.
 *
 * @param initializationParameterOfDetector log name, configuration file path
 *        and parent path; copied into initializationParameter.
 * @return CONFIG_FILE_NOT_EXIST if the configuration file does not exist,
 *         FAIL_TO_OPEN_CONFIG_FILE if it cannot be opened, SUCCESS otherwise.
 */
ErrorCode DetectorYOLOV5::DoCommonInitialization(InitializationParameterOfDetector initializationParameterOfDetector)
{
    initializationParameter=initializationParameterOfDetector;

    // Look up the log file registered under the configured log name.
    logFile=LogManager::GetInstance()->GetLogFile(initializationParameter.logName);

    // Open the configuration file read-only.
    const std::string &configFilePath=initializationParameter.configFilePath;
    if(!Exists(configFilePath))
    {
        LOG_ERROR(logFile, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }
    const bool configOpened=configurationFile.open(configFilePath, FileStorage::READ);
    if(!configOpened)
    {
        LOG_ERROR(logFile, "fail to open configuration file\n");
        return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(logFile, "succeed to open configuration file\n");

    // Normalize the parent path so relative file names can be appended directly.
    std::string &parentPath=initializationParameter.parentPath;
    if(!parentPath.empty() && !IsPathSeparator(parentPath.back()))
    {
        parentPath+=PATH_SEPARATOR;
    }

    return SUCCESS;
}

}
