#include "DetectorYOLOV5.h"

#include <array>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

using namespace std;
using namespace cv::dnn;
namespace fs = std::filesystem;
/**
 * Releases the ONNX Runtime session allocated by initModel() and drops the
 * cached input/output tensor-name handles.
 */
DetectorYOLOV5::~DetectorYOLOV5()
{
    // The session was created with `new` in initModel(); this class owns it.
    delete this->session;

    // Release the name handles obtained from the session's model metadata.
    this->inputNamesPtr.clear();
    this->outputNamesPtr.clear();
}

void DetectorYOLOV5::setGpuIndex(int gpuIndex)
{
    if (gpuIndex >= 0) {
        OrtROCMProviderOptions rocm_options;
        rocm_options.device_id = gpuIndex;
        rocm_options.arena_extend_strategy = 0;
        rocm_options.miopen_conv_exhaustive_search = 0;
        rocm_options.do_copy_in_default_stream = 1;

        sessionOptions.AppendExecutionProvider_ROCM(rocm_options);
    }
    else {
        printf("det use CPU\n");
    }
}

/**
 * Configures ONNX Runtime threading and graph optimization for the session
 * that initModel() will create from sessionOptions.
 *
 * @param numOfThread Thread count passed to ONNX Runtime (0 lets ORT choose).
 */
void DetectorYOLOV5::setNumThread(int numOfThread)
{
    numThread = numOfThread;
    // Intra-op threads parallelize work inside a single node. This is the
    // setting that takes effect under ORT's default sequential execution mode;
    // previously only the inter-op count was set, so numThread had no effect.
    sessionOptions.SetIntraOpNumThreads(numThread);
    // Inter-op threads are only used in parallel execution mode (ignored in
    // sequential mode); kept so existing behavior in parallel mode is unchanged.
    sessionOptions.SetInterOpNumThreads(numThread);
    sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
}

/**
 * Loads the ONNX model, caches its input dimensions and input/output tensor
 * names, and reads the class-name list used to label detections.
 *
 * sessionOptions (see setGpuIndex()/setNumThread()) is consumed here, so those
 * must be called before this function.
 *
 * @param pathStr             Path to the ONNX model file.
 * @param pathOfClassNameFile Text file with one class name per line; line i is
 *                            the name for class index i.
 *
 * NOTE(review): calling this twice leaks the previous session and appends the
 * class names again; it is intended to be called once per instance.
 */
void DetectorYOLOV5::initModel(const std::string &pathStr, const std::string &pathOfClassNameFile)
{
    fprintf(stdout, "Start init model for %s\n", pathStr.c_str());
    session = new Ort::Session(env, pathStr.c_str(), sessionOptions);
    inputDim = getInputDim(session);
    inputNamesPtr = getInputNames(session);
    outputNamesPtr = getOutputNames(session);

    std::ifstream classNameFile(pathOfClassNameFile);
    if (!classNameFile.is_open()) {
        // Do not continue silently with an empty class list: every detected
        // class id would then be left without a name.
        fprintf(stderr, "Failed to open class name file: %s\n", pathOfClassNameFile.c_str());
        return;
    }

    std::string line;
    while (std::getline(classNameFile, line))
    {
        // getline() keeps the '\r' of CRLF files; strip it so labels are clean.
        if (!line.empty() && line.back() == '\r') {
            line.pop_back();
        }
        classNames.push_back(line);
    }
}

std::vector<int64_t> DetectorYOLOV5::getshape()
{
    return this->inputDim;
}


/**
 * Runs the detector on a batch of images and post-processes every result.
 *
 * Each image is resized (plain resize, no letterbox) to the model input size,
 * converted with substractMeanNormalize(meanValues, normValues), and appended
 * to a single batched input tensor of shape {batch, 3, H, W}. The first model
 * output is expected to have shape {batch, numProposal, numOut}. The slice for
 * image i is handed to postprocess() together with originSrc[i] and
 * imageNames[i].
 *
 * @param originSrc  Images to run detection on (3 channels assumed).
 * @param imageNames Name for each image in originSrc; must have at least as
 *                   many entries as originSrc.
 */
void DetectorYOLOV5::Detect(std::vector<cv::Mat> originSrc, std::vector<cv::String> imageNames)
{
    const int batch_num = static_cast<int>(originSrc.size());
    if (batch_num == 0) {
        // Nothing to run; avoids building a zero-sized input tensor.
        return;
    }
    if (imageNames.size() < originSrc.size()) {
        fprintf(stderr, "Detect: %zu images but only %zu image names\n",
                originSrc.size(), imageNames.size());
        return;
    }

    // inputDim is read as {N, C, H, W}; cv::Size takes (width, height).
    const cv::Size inputSize(static_cast<int>(inputDim[3]), static_cast<int>(inputDim[2]));
    const int imgchannels = 3;

    // Pack all preprocessed images back to back into one contiguous buffer.
    std::vector<float> inputTensorData;
    for (const cv::Mat &src : originSrc)
    {
        cv::Mat srcresize;
        cv::resize(src, srcresize, inputSize, 0, 0, cv::INTER_LINEAR);
        const std::vector<float> inputTensorValues = substractMeanNormalize(srcresize, meanValues, normValues);
        inputTensorData.insert(inputTensorData.end(), inputTensorValues.begin(), inputTensorValues.end());
    }

    // The tensor wraps inputTensorData without copying, so the buffer must
    // stay alive until Run() returns.
    std::array<int64_t, 4> inputShape{batch_num, imgchannels, inputDim[2], inputDim[3]};
    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, inputTensorData.data(), inputTensorData.size(), inputShape.data(), inputShape.size());

    // Only the first model input and the first model output are used.
    std::vector<const char *> inputNames = {inputNamesPtr.front().get()};
    std::vector<const char *> outputNames = {outputNamesPtr.front().get()};

    // Inference
    auto outputTensor = session->Run(Ort::RunOptions{nullptr}, inputNames.data(), &inputTensor, inputNames.size(), outputNames.data(), outputNames.size());

    // Validate the output layout before doing pointer arithmetic on it.
    const std::vector<int64_t> outputShape = outputTensor.front().GetTensorTypeAndShapeInfo().GetShape();
    if (outputShape.size() != 3 || outputShape[0] < batch_num) {
        fprintf(stderr, "Detect: unexpected output shape (rank %zu) for batch of %d\n",
                outputShape.size(), batch_num);
        return;
    }
    float *floatArray = outputTensor.front().GetTensorMutableData<float>();
    const int numProposal = static_cast<int>(outputShape[1]);
    const int numOut = static_cast<int>(outputShape[2]);
    const size_t floatsPerImage = static_cast<size_t>(numProposal) * static_cast<size_t>(numOut);

    // Reused across images; postprocess() clears it before returning.
    std::vector<ResultOfDetection> resultsOfDetection;
    for (int batch = 0; batch < batch_num; batch++)
    {
        // Non-owning {1, numProposal, numOut} view onto this image's slice.
        int shape_b1[] = {1, numProposal, numOut};
        cv::Mat batch_image(3, shape_b1, CV_32F, floatArray + batch * floatsPerImage);
        postprocess(batch_image, originSrc[batch], inputSize, resultsOfDetection, imageNames[batch]);
    }
}

/**
 * Decodes one image's raw YOLOv5 output and runs NMS. It then prints and draws
 * the kept boxes and writes the annotated image to
 * ../build/<file name of imageName>.
 *
 * @param res                Output slice for one image, shape {1, numProposal, numOut}.
 *                           Each row is [cx, cy, w, h, objectness, class scores...]
 *                           in model-input pixels. It is reshaped in place to 2-D.
 * @param originSrc          Original image; boxes are scaled to its size.
 *                           NOTE(review): boxes/labels are drawn into its pixel
 *                           buffer despite the const reference, so the caller's
 *                           image data is modified.
 * @param inputSize          Model input size the image was resized to.
 * @param resultsOfDetection Scratch buffer filled with this image's detections for
 *                           printing/drawing; cleared before returning.
 * @param imageName          Source image path; only its file name is used.
 */
void DetectorYOLOV5::postprocess(cv::Mat& res, const cv::Mat& originSrc, const cv::Size& inputSize, std::vector<ResultOfDetection> &resultsOfDetection, cv::String& imageName)
{
    // Proposals whose objectness is not above this are skipped early.
    constexpr float kObjThreshold = 0.5f;
    // Minimum final score (objectness * best class score) to keep a box.
    constexpr double kScoreThreshold = 0.25;
    // Score and IoU thresholds for non-maximum suppression.
    constexpr float kNmsScoreThreshold = 0.25f;
    constexpr float kNmsIouThreshold = 0.5f;
    // Columns before the class scores: cx, cy, w, h, objectness.
    constexpr int kClassScoreOffset = 5;

    // Number of candidate boxes.
    const int numProposal = res.size[1];
    std::cout << "numProposal: " << numProposal << std::endl;
    // Values per candidate (e.g. 85 = 4 box coords + 1 objectness + 80 class scores).
    const int numOut = res.size[2];
    std::cout << "numOut: " << numOut << std::endl;
    if (numOut <= kClassScoreOffset) {
        // No class-score columns: reading pdata[4] / colRange(5, numOut) would be invalid.
        fprintf(stderr, "postprocess: unexpected output width %d\n", numOut);
        resultsOfDetection.clear();
        return;
    }

    // View the output as numProposal rows of numOut values.
    res = res.reshape(0, numProposal);

    std::vector<float> confidences;
    std::vector<cv::Rect> boxes;
    std::vector<int> classIds;
    const float ratioh = static_cast<float>(originSrc.rows) / inputSize.height;
    const float ratiow = static_cast<float>(originSrc.cols) / inputSize.width;

    // Decode cx, cy, w, h, objectness and class scores for every proposal.
    for (int n = 0; n < numProposal; ++n)
    {
        const float *pdata = res.ptr<float>(n);
        const float boxScores = pdata[4];
        if (boxScores > kObjThreshold)
        {
            cv::Mat scores = res.row(n).colRange(kClassScoreOffset, numOut);
            cv::Point classIdPoint;
            double maxClassScore = 0.0;
            cv::minMaxLoc(scores, nullptr, &maxClassScore, nullptr, &classIdPoint);
            maxClassScore *= boxScores;
            if (maxClassScore > kScoreThreshold)
            {
                // Scale from model-input pixels back to the original image.
                const float cx = pdata[0] * ratiow;
                const float cy = pdata[1] * ratioh;
                const float w = pdata[2] * ratiow;
                const float h = pdata[3] * ratioh;

                const int left = static_cast<int>(cx - 0.5 * w);
                const int top = static_cast<int>(cy - 0.5 * h);

                confidences.push_back(static_cast<float>(maxClassScore));
                boxes.emplace_back(left, top, static_cast<int>(w), static_cast<int>(h));
                classIds.push_back(classIdPoint.x);
            }
        }
    }

    // Run NMS (across all classes) to remove redundant overlapping boxes.
    std::vector<int> indices;
    cv::dnn::NMSBoxes(boxes, confidences, kNmsScoreThreshold, kNmsIouThreshold, indices);
    for (const int idx : indices)
    {
        const int classID = classIds[idx];

        ResultOfDetection result;
        result.boundingBox = boxes[idx];
        result.confidence = confidences[idx];
        result.classID = classID;
        // Guard against a class-name file with fewer lines than the model has
        // classes (or a file that failed to load): fall back to the numeric id.
        const bool hasName = classID >= 0 && static_cast<size_t>(classID) < classNames.size();
        result.className = hasName ? classNames[classID] : std::to_string(classID);
        resultsOfDetection.push_back(result);
    }
    std::cout << "num of boxes: " << resultsOfDetection.size() << std::endl;

    fprintf(stdout,"//////////////Detection Results//////////////\n");
    for (const ResultOfDetection &result : resultsOfDetection)
    {
        cv::rectangle(originSrc, result.boundingBox, cv::Scalar(0, 255, 255), 2);
        cv::putText(originSrc, result.className, cv::Point(result.boundingBox.x, result.boundingBox.y - 20), cv::FONT_HERSHEY_PLAIN, 2.0, cv::Scalar(0, 0, 255), 2);

        fprintf(stdout,"box:%d %d %d %d,label:%d,confidence:%.3f\n",result.boundingBox.x,
        result.boundingBox.y,result.boundingBox.width,result.boundingBox.height,result.classID,result.confidence);
    }

    // Save the annotated image under ../build, keeping only the file name.
    const fs::path outpath = fs::path("../build") / fs::path(imageName).filename();
    if (!cv::imwrite(outpath.string(), originSrc)) {
        fprintf(stderr, "Failed to write result image: %s\n", outpath.string().c_str());
    }
    resultsOfDetection.clear();
}