#include <OcrSVTR.h>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/reshape2.hpp>
#include <opencv2/dnn.hpp>
#include <CommonUtility.h>
#include <Filesystem.h>
#include <SimpleLog.h>

using namespace cv::dnn;

namespace migraphxSamples
{
// Creates a recognizer with no log file attached yet; the log file is
// resolved later by DoCommonInitialization().
SVTR::SVTR()
    : logFile(nullptr)
{
}

// Releases the configuration file storage held by this object.
SVTR::~SVTR()
{
    configurationFile.release();
}

/**
 * Prepares the recognizer for inference.
 *
 * Steps, in order: common initialization (log file + configuration file),
 * reading the "OcrSVTR" configuration node, loading and compiling the ONNX
 * model for the GPU target, running one inference on generated input, and
 * loading the character dictionary.
 *
 * @param InitializationParameterOfSVTR  Settings forwarded to DoCommonInitialization().
 * @return SUCCESS on success; the error code of DoCommonInitialization() if it
 *         fails; MODEL_NOT_EXIST if the model file is missing.
 *
 * @note If the dictionary file cannot be opened the process is terminated
 *       with exit(1) instead of returning an error code.
 */
ErrorCode SVTR::Initialize(InitializationParameterOfSVTR InitializationParameterOfSVTR)
{
    // Common initialization (obtain the log file, load the configuration file, ...)
    ErrorCode errorCode = DoCommonInitialization(InitializationParameterOfSVTR);
    if (errorCode!=SUCCESS)
    {
        LOG_ERROR(logFile, "fail to DoCommonInitialization\n");
        return errorCode;
    }
    LOG_INFO(logFile, "success to DoCommonInitialization\n");

    // Read the parameters from the configuration file
    FileNode netNode = configurationFile["OcrSVTR"];
    string modelPath = initializationParameter.parentPath + (string)netNode["ModelPath"];
    // NOTE(review): unlike ModelPath, DictPath is not prefixed with parentPath —
    // confirm whether that is intentional before changing it.
    string dictPath = (string)netNode["DictPath"];

    // Load the model
    if(Exists(modelPath)==false)
    {
        LOG_ERROR(logFile,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }

    migraphx::onnx_options onnx_options;
    onnx_options.map_input_dims["x"]={1,3,48,320};
    net = migraphx::parse_onnx(modelPath, onnx_options);
    LOG_INFO(logFile,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    // Query the model input attributes (the first parameter is taken as the input)
    std::pair<std::string, migraphx::shape> inputAttribute=*(net.get_parameter_shapes().begin());
    inputName=inputAttribute.first;
    inputShape=inputAttribute.second;
    inputSize=cv::Size(inputShape.lens()[3],inputShape.lens()[2]);

    // Use the GPU target
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Compile the model
    migraphx::compile_options options;
    options.device_id=0;                          // GPU device to use, device 0 by default
    options.offload_copy=true;                    // enable offload_copy
    net.compile(gpuTarget,options);
    LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());

    // Run the model once on generated input (presumably a warm-up run)
    migraphx::parameter_map inputData;
    inputData[inputName]=migraphx::generate_argument(inputShape);
    net.eval(inputData);

    // Load the character dictionary, one entry per line.
    std::ifstream in(dictPath);
    std::string line;
    if (in)
    {
        // Start from an empty dictionary so that calling Initialize() again does
        // not append a second copy and shift every class index.
        charactorDict.clear();
        while (getline(in, line))
        {
            charactorDict.push_back(line);
        }
        // Index 0 is occupied by a "#" placeholder (Infer() never emits index 0)
        // and a space character is appended as the last entry.
        charactorDict.insert(charactorDict.begin(), "#");
        charactorDict.push_back(" ");
    }
    else
    {
        // NOTE(review): terminating the whole process from a library call is
        // harsh; consider returning a dedicated error code instead.
        std::cout << "no such label file: " << dictPath << ", exit the program..." << std::endl;
        exit(1);
    }

    // log
    LOG_INFO(logFile,"InputMaxSize:%dx%d\n",inputSize.width,inputSize.height);
    LOG_INFO(logFile,"InputName:%s\n",inputName.c_str());

    return SUCCESS;
}

/**
 * Recognizes the text contained in one image.
 *
 * The image is resized to a height of 48 pixels, right-padded with gray
 * (127,127,127) up to a width of int(48 * maxWHRatio), normalized, and fed to
 * the network. The output is decoded greedily: for every step the best class
 * is taken, class 0 and immediate repeats of the previous class are dropped.
 *
 * @param img            Input image; must be non-empty and of type CV_8UC3. Not modified.
 * @param resultsChar    Decoded characters are APPENDED to this string (it is not cleared).
 * @param resultsdScore  Receives the mean of the best-class values of the emitted
 *                       characters, or 0 when no character was emitted.
 * @param maxWHRatio     Width/height ratio that determines the padded input width. Only read.
 * @return IMAGE_ERROR for an empty or non-CV_8UC3 image, SUCCESS otherwise.
 */
ErrorCode SVTR::Infer(cv::Mat &img, std::string &resultsChar, float &resultsdScore, float &maxWHRatio)
{
    if(img.empty()||img.type()!=CV_8UC3)
    {
        LOG_ERROR(logFile, "image error!\n");
        return IMAGE_ERROR;
    }

    // Target geometry: fixed height, width limited by the caller-supplied ratio.
    const int imgH = 48;
    const int imgW = int((48 * maxWHRatio));
    const float ratio = float(img.cols) / float(img.rows);
    const double scaledW = ceil(imgH * ratio);
    const int resizeW = (scaledW > imgW) ? imgW : int(scaledW);

    // cv::resize does not modify its source, so no private copy of img is needed.
    cv::Mat resizeImg;
    cv::resize(img, resizeImg, cv::Size(resizeW, imgH));
    cv::copyMakeBorder(resizeImg, resizeImg, 0, 0, 0,
                     int(imgW - resizeImg.cols), cv::BORDER_CONSTANT,
                     {127, 127, 127});

    // Scale to [0,1], then apply the per-channel (x - mean) / std normalization.
    resizeImg.convertTo(resizeImg, CV_32FC3, 1.0/255.0);
    std::vector<cv::Mat> bgrChannels(3);
    cv::split(resizeImg, bgrChannels);
    std::vector<float> mean = {0.485f, 0.456f, 0.406f};
    std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
    for (std::size_t i = 0; i < bgrChannels.size(); i++)
    {
        bgrChannels[i].convertTo(bgrChannels[i], CV_32FC1, 1.0 * scale[i],
                              (0.0 - mean[i]) * scale[i]);
    }
    cv::merge(bgrChannels, resizeImg);
    cv::Mat inputBlob = cv::dnn::blobFromImage(resizeImg);

    // The blob holds the PADDED image, so the declared tensor shape must use the
    // padded size. Declaring the pre-padding width would make the network read
    // the buffer with the wrong row stride whenever padding was added.
    std::vector<std::size_t> inputShapeOfInfer={1,3,
                                                static_cast<std::size_t>(resizeImg.rows),
                                                static_cast<std::size_t>(resizeImg.cols)};

    // Input data
    migraphx::parameter_map inputData;
    inputData[inputName]= migraphx::argument{migraphx::shape(inputShape.type(),inputShapeOfInfer), reinterpret_cast<float*>(inputBlob.data)};

    // Inference
    std::vector<migraphx::argument> inferenceResults = net.eval(inputData);

    // Fetch the inference result; it is read as float data laid out as
    // [*, n2 steps, n3 classes]. Only the first n2*n3 values are accessed.
    migraphx::argument result = inferenceResults[0];
    migraphx::shape outputShape = result.get_shape();
    const int n2 = outputShape.lens()[1];
    const int n3 = outputShape.lens()[2];
    const float *scores = reinterpret_cast<const float*>(result.data());

    int lastIndex = 0;
    float score = 0.f;
    int count = 0;
    for (int j = 0; j < n2; j++)
    {
        const float *rowBegin = scores + static_cast<std::size_t>(j) * n3;
        const float *rowEnd = rowBegin + n3;
        const float *best = std::max_element(rowBegin, rowEnd);
        const int argmaxIdx = int(std::distance(rowBegin, best));

        // Emit a character only for a non-zero class that differs from the
        // class chosen at the previous step.
        if (argmaxIdx > 0 && argmaxIdx != lastIndex)
        {
            if (static_cast<std::size_t>(argmaxIdx) < charactorDict.size())
            {
                score += *best;
                count += 1;
                resultsChar += charactorDict[argmaxIdx];
            }
            else
            {
                // Dictionary and model output disagree; skip instead of reading
                // past the end of the dictionary.
                LOG_ERROR(logFile, "class index %d out of dictionary range %d\n",
                          argmaxIdx, int(charactorDict.size()));
            }
        }
        lastIndex = argmaxIdx;
    }

    // Avoid 0/0 (NaN) when nothing was decoded.
    resultsdScore = (count > 0) ? score / count : 0.f;

    return SUCCESS;
}

// Shared set-up step: stores the initialization parameters, resolves the log
// file, opens the configuration file for reading and makes sure a non-empty
// parent path ends with a path separator.
// Returns CONFIG_FILE_NOT_EXIST or FAIL_TO_OPEN_CONFIG_FILE on failure,
// SUCCESS otherwise.
ErrorCode SVTR::DoCommonInitialization(InitializationParameterOfSVTR InitializationParameterOfSVTR)
{
    initializationParameter = InitializationParameterOfSVTR;

    // Look up the log file by the configured log name.
    logFile = LogManager::GetInstance()->GetLogFile(initializationParameter.logName);

    // The configuration file has to exist ...
    std::string cfgPath = initializationParameter.configFilePath;
    if (Exists(cfgPath) == false)
    {
        LOG_ERROR(logFile, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }

    // ... and has to be readable.
    const bool opened = configurationFile.open(cfgPath, FileStorage::READ);
    if (opened == false)
    {
       LOG_ERROR(logFile, "fail to open configuration file\n");
       return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(logFile, "succeed to open configuration file\n");

    // Append a trailing separator to a non-empty parent path that lacks one,
    // so file names can be concatenated to it directly.
    std::string &dir = initializationParameter.parentPath;
    if (!dir.empty() && !IsPathSeparator(dir.back()))
    {
        dir += PATH_SEPARATOR;
    }

    return SUCCESS;
}


}
