#include <Crnn.h>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/reshape2.hpp>
#include <opencv2/dnn.hpp>
#include <CommonUtility.h>
#include <Filesystem.h>
#include <SimpleLog.h>

using namespace cv::dnn;

namespace migraphxSamples
{

// Default constructor: starts with a null log-file handle; the handle is
// assigned later by DoCommonInitialization().
Crnn::Crnn()
    : logFile(nullptr)
{
}

// Destructor: releases the configuration file storage opened by
// DoCommonInitialization().
Crnn::~Crnn()
{
    configurationFile.release();
}

/**
 * Load, compile and warm up the CRNN model.
 *
 * Steps: common initialization (log file + configuration file), model path
 * lookup in the configuration, ONNX parsing with the input dimensions fixed
 * for the requested mode, GPU compilation, and one evaluation on generated
 * input.
 *
 * @param initializationParameterOfOcr  Paths and log name used for setup.
 * @param dynamic  true: parse with input dims {1,1,32,512} (logged as the
 *                 maximum input size); false: parse with {1,1,32,100}.
 * @return SUCCESS, the error code of DoCommonInitialization(), or
 *         MODEL_NOT_EXIST when the model file is missing.
 */
ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterOfOcr, bool dynamic)
{
    // Common setup: log file, configuration file, parent path
    const ErrorCode commonStatus = DoCommonInitialization(initializationParameterOfOcr);
    if(commonStatus!=SUCCESS)
    {
        LOG_ERROR(logFile,"fail to DoCommonInitialization\n");
        return commonStatus;
    }
    LOG_INFO(logFile,"succeed to DoCommonInitialization\n");

    // Resolve the model path from the configuration file.
    // NOTE(review): the "CrnnDynamic" node is read in both modes — confirm
    // that the static model is meant to share this entry.
    const cv::FileNode crnnNode = configurationFile["CrnnDynamic"];
    const std::string modelPath = initializationParameter.parentPath + static_cast<string>(crnnNode["ModelPath"]);

    // The model file must exist before it can be parsed
    if(!Exists(modelPath))
    {
        LOG_ERROR(logFile,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }

    // The two modes differ only in the width handed to the ONNX parser
    migraphx::onnx_options onnxOptions;
    if(dynamic)
    {
        onnxOptions.map_input_dims["input"]={1,1,32,512};
    }
    else
    {
        onnxOptions.map_input_dims["input"]={1,1,32,100};
    }

    net = migraphx::parse_onnx(modelPath, onnxOptions);
    LOG_INFO(logFile,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    // Record the attributes of the first program parameter as the model input.
    // Assumes the input shape has at least four dimensions (N,C,H,W).
    const auto parameterShapes = net.get_parameter_shapes();
    const auto firstParameter = parameterShapes.begin();
    inputName = firstParameter->first;
    inputShape = firstParameter->second;
    inputSize = cv::Size(inputShape.lens()[3], inputShape.lens()[2]);

    // Log the input attributes
    if(dynamic)
    {
        LOG_INFO(logFile,"InputMaxSize:%dx%d\n",inputSize.width,inputSize.height);
    }
    else
    {
        LOG_INFO(logFile,"InputSize:%dx%d\n",inputSize.width,inputSize.height);
    }
    LOG_INFO(logFile,"InputName:%s\n",inputName.c_str());

    // Target the GPU
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Compile the model
    migraphx::compile_options compileOptions;
    compileOptions.device_id = 0;         // GPU device to use; device 0 by default
    compileOptions.offload_copy = true;   // enable offload_copy
    net.compile(gpuTarget, compileOptions);
    LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());

    // Evaluate once on generated input so the first real call is not the first run
    migraphx::parameter_map warmupInput;
    warmupInput[inputName] = migraphx::generate_argument(inputShape);
    net.eval(warmupInput);

    return SUCCESS;
}

/**
 * Recognize the text in one image with the compiled CRNN network.
 *
 * The image is converted to grayscale, resized, normalized with
 * (x/255 - 0.5)/0.5, evaluated by the network, and the per-step best class
 * indices are mapped to characters of the alphabet
 * "-0123456789abcdefghijklmnopqrstuvwxyz" (index 0 is treated as the blank).
 *
 * Initialize() must have succeeded before calling this function.
 *
 * @param srcImage     Non-empty CV_8UC3 image (converted with BGR->GRAY).
 * @param resultsChar  Output; recognized characters are APPENDED, the vector
 *                     is not cleared first.
 * @param raw          true: emit one character per output step, including
 *                     blanks and repeats; false: drop blanks and collapse
 *                     consecutive repeats.
 * @param dynamic      true: keep the source width and resize only the height
 *                     to 32; false: resize to the model input size recorded
 *                     by Initialize().
 * @return IMAGE_ERROR for an empty or non-CV_8UC3 image, otherwise SUCCESS.
 */
ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, bool raw, bool dynamic)
{
    if(srcImage.empty() || srcImage.type()!=CV_8UC3)
    {
        LOG_ERROR(logFile, "image error!\n");
        return IMAGE_ERROR;
    }

    // Grayscale conversion
    cv::Mat grayImage;
    cv::cvtColor(srcImage, grayImage, CV_BGR2GRAY);

    // Static mode resizes to the size recorded from the parsed model input so
    // the blob always matches inputShape (a hard-coded 100x32 would be
    // over-read whenever the parsed shape differs).
    // NOTE(review): in dynamic mode the width is not checked against the
    // maximum width the model was parsed with (inputSize.width) — confirm
    // whether wider images must be rejected or shrunk.
    const cv::Size targetSize = dynamic ? cv::Size(grayImage.cols, 32) : inputSize;
    cv::Mat resizedImage;
    cv::resize(grayImage, resizedImage, targetSize);
    const int height = resizedImage.rows;
    const int width = resizedImage.cols;

    // Single-channel float blob, normalized in place
    cv::Mat inputBlob = cv::dnn::blobFromImage(resizedImage);
    float *blobData = inputBlob.ptr<float>();
    const int pixelCount = width * height;
    for(int i=0; i<pixelCount; i++)
    {
        blobData[i] = (blobData[i]/255.f - 0.5)/0.5;
    }

    // Input data: the argument wraps the blob memory, so inputBlob must stay
    // alive until eval() has returned.
    migraphx::parameter_map inputData;
    if(dynamic)
    {
        const std::vector<std::size_t> dynamicShape = {1, 1, 32, static_cast<std::size_t>(width)};
        inputData[inputName] = migraphx::argument{migraphx::shape(inputShape.type(), dynamicShape), blobData};
    }
    else
    {
        inputData[inputName] = migraphx::argument{inputShape, blobData};
    }

    // Inference
    std::vector<migraphx::argument> inferenceResults = net.eval(inputData);

    // Copy the first output into a 3-D cv::Mat.
    // Assumes the output is a rank-3 float tensor — TODO confirm for other models.
    const migraphx::argument &result = inferenceResults[0];
    const migraphx::shape outputShape = result.get_shape();
    const std::vector<std::size_t> &outputLens = outputShape.lens();
    int outputDims[] = {static_cast<int>(outputLens[0]),
                        static_cast<int>(outputLens[1]),
                        static_cast<int>(outputLens[2])};
    cv::Mat out(3, outputDims, CV_32F);
    memcpy(out.data, result.data(), sizeof(float)*outputShape.elements());

    const std::string alphabet = "-0123456789abcdefghijklmnopqrstuvwxyz";

    // Best class index for every step along dimension 0. Only the first
    // size[2] scores of each step are inspected, i.e. dimension 1 is
    // presumably a batch of 1.
    const int stepCount = out.size[0];
    const int classCount = out.size[2];
    std::vector<int> predChars;
    predChars.reserve(static_cast<std::size_t>(stepCount));
    for(int step = 0; step < stepCount; step++)
    {
        const cv::Mat scores(1, classCount, CV_32F, out.ptr<float>(step));
        cv::Point charIdPoint;
        double maxCharScore;
        cv::minMaxLoc(scores, 0, &maxCharScore, 0, &charIdPoint);
        predChars.push_back(charIdPoint.x);
    }

    // Transcription
    for(std::size_t i = 0; i < predChars.size(); i++)
    {
        const int charId = predChars[i];

        // Never index past the alphabet, even if the model emits more classes
        if(charId < 0 || static_cast<std::size_t>(charId) >= alphabet.size())
        {
            LOG_ERROR(logFile, "character index %d is outside the alphabet\n", charId);
            continue;
        }

        if(raw)
        {
            resultsChar.push_back(alphabet[charId]);
            continue;
        }

        // Drop blanks and collapse consecutive repeats
        const bool isBlank = (charId == 0);
        const bool isRepeat = (i > 0 && predChars[i-1] == charId);
        if(!isBlank && !isRepeat)
        {
            resultsChar.push_back(alphabet[charId]);
        }
    }

    return SUCCESS;
}

/**
 * Shared setup used by Initialize(): stores the parameters, fetches the log
 * file, opens the configuration file for reading and makes sure a non-empty
 * parent path ends with a path separator.
 *
 * @param initializationParameterOfOcr  Parameters copied into this object.
 * @return CONFIG_FILE_NOT_EXIST when the configuration file is missing,
 *         FAIL_TO_OPEN_CONFIG_FILE when it cannot be opened, else SUCCESS.
 */
ErrorCode Crnn::DoCommonInitialization(InitializationParameterOfOcr initializationParameterOfOcr)
{
    initializationParameter=initializationParameterOfOcr;

    // Fetch the log file by name
    logFile=LogManager::GetInstance()->GetLogFile(initializationParameter.logName);

    // Open the configuration file
    const std::string configPath=initializationParameter.configFilePath;
    if(!Exists(configPath))
    {
        LOG_ERROR(logFile, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }
    const bool opened = configurationFile.open(configPath, cv::FileStorage::READ);
    if(!opened)
    {
       LOG_ERROR(logFile, "fail to open configuration file\n");
       return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(logFile, "succeed to open configuration file\n");

    // Normalize the parent path: a non-empty path must end with a separator
    std::string &parent = initializationParameter.parentPath;
    const bool needsSeparator = !parent.empty() && !IsPathSeparator(parent[parent.size() - 1]);
    if(needsSeparator)
    {
        parent+=PATH_SEPARATOR;
    }

    return SUCCESS;
}

}
