#include <Unet.h>

#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>

#include <Filesystem.h>
#include <SimpleLog.h>

namespace migraphxSamples
{

// Logistic function: maps x to 1 / (1 + e^(-x)).
static float Sigmoid(float x)
{
    // `auto` keeps whatever floating-point type the visible exp() overload yields.
    const auto denominator = 1 + exp(-x);
    return 1 / denominator;
}

// Nothing to set up at construction time; the real work happens in Initialize().
Unet::Unet() = default;

// Releases the configuration file storage that Initialize() opened.
Unet::~Unet()
{
    this->configurationFile.release();
}

// Loads the U-Net ONNX model named in the configuration file, compiles it for
// the GPU target and runs one warm-up inference.
//
// Parameters:
//   initParamOfSegmentationUnet - its configFilePath member names the
//                                 configuration file to read.
//
// Returns:
//   CONFIG_FILE_NOT_EXIST    - the configuration file path does not exist.
//   FAIL_TO_OPEN_CONFIG_FILE - the configuration file could not be opened.
//   MODEL_NOT_EXIST          - the "ModelPath" entry of the "Unet" node does not exist.
//   SUCCESS                  - the model was parsed, compiled and warmed up.
//
// Side effects: sets net, inputName, inputShape and inputSize, and keeps
// configurationFile open.
ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegmentationUnet)
{
    // Read the configuration file
    const std::string &configFilePath=initParamOfSegmentationUnet.configFilePath;
    if(Exists(configFilePath)==false)
    {
        LOG_ERROR(stdout, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }
    if(!configurationFile.open(configFilePath, cv::FileStorage::READ))
    {
       LOG_ERROR(stdout, "fail to open configuration file\n");
       return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(stdout, "succeed to open configuration file\n");

    // Fetch the model path from the "Unet" node of the configuration
    const cv::FileNode netNode = configurationFile["Unet"];
    const std::string modelPath=static_cast<std::string>(netNode["ModelPath"]);

    // Set the maximum input shape for the input named "inputs"
    migraphx::onnx_options onnx_options;
    onnx_options.map_input_dims["inputs"]={1,3,256,256};

    // Load the model
    if(Exists(modelPath)==false)
    {
        LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath, onnx_options);
    LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    // Print the model's input/output node information.
    // Entries are iterated by const reference to avoid copying each name/shape pair.
    std::cout<<"inputs:"<<std::endl;
    std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
    for(const auto &entry:inputs)
    {
        std::cout<<entry.first<<":"<<entry.second<<std::endl;
    }
    std::cout<<"outputs:"<<std::endl;
    std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
    for(const auto &entry:outputs)
    {
        std::cout<<entry.first<<":"<<entry.second<<std::endl;
    }

    // Use the first reported input as the network input.
    // NOTE(review): assumes the model has at least one input with a 4-D
    // (N,C,H,W) shape - confirm against the deployed model.
    inputName=inputs.begin()->first;
    inputShape=inputs.begin()->second;
    const auto &inputDims=inputShape.lens();
    const int H=static_cast<int>(inputDims[2]);
    const int W=static_cast<int>(inputDims[3]);
    inputSize=cv::Size(W,H);

    // Target the GPU
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Compile the model
    migraphx::compile_options options;
    options.device_id=0;                          // GPU device to use; device 0 by default
    options.offload_copy=true;
    net.compile(gpuTarget,options);
    LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());

    // Warm up with one inference on a buffer of the input shape
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName]=migraphx::argument{inputShape};
    net.eval(inputData);

    // Log the resolved input information
    LOG_INFO(stdout,"InputSize:%dx%d\n",inputSize.width,inputSize.height);
    LOG_INFO(stdout,"InputName:%s\n",inputName.c_str());

    return SUCCESS;
}

// Runs the network on srcImage and writes a binary mask to maskImage.
//
// Parameters:
//   srcImage  - input image; must be non-empty and of type CV_8UC3.
//   maskImage - output; (re)allocated as a single-channel CV_8U image of the
//               network output's height x width, holding 255 where
//               sigmoid(output) > 0.996 and 0 elsewhere.
//
// Returns:
//   IMAGE_ERROR - srcImage is empty or is not CV_8UC3.
//   SUCCESS     - maskImage has been filled.
//
// Only the first H*W values of the output buffer (the first plane of an
// (N,C,H,W) output) contribute to the mask.
ErrorCode Unet::Segmentation(const cv::Mat &srcImage, cv::Mat &maskImage)
{
    if(srcImage.empty()||srcImage.type()!=CV_8UC3)
    {
        LOG_ERROR(stdout, "image error!\n");
        return IMAGE_ERROR;
    }

    // Preprocess: scale by 1/255, resize to the network input size, swap the
    // R and B channels and convert to an NCHW blob
    cv::Mat inputBlob;
    cv::dnn::blobFromImage(srcImage,
                    inputBlob,
                    1 / 255.0,
                    inputSize,
                    cv::Scalar(0, 0, 0),
                    true,
                    false);

    // Wrap the blob as the network input
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName]= migraphx::argument{inputShape, reinterpret_cast<float*>(inputBlob.data)};

    // Inference
    std::vector<migraphx::argument> results = net.eval(inputData);

    // Properties of the first output node
    migraphx::argument result = results[0];
    migraphx::shape outputShape=result.get_shape();
    // Size of each dimension, ordered (N,C,H,W)
    // NOTE(review): assumes a 4-D float output - confirm against the model.
    const std::vector<std::size_t> outputSize=outputShape.lens();
    const int outputHeight=static_cast<int>(outputSize[2]);
    const int outputWidth=static_cast<int>(outputSize[3]);
    const float *data = reinterpret_cast<const float*>(result.data());

    // Threshold the sigmoid of each value straight into the 8-bit mask.
    // The row stride is the real output width rather than a hard-coded 256,
    // and no stack-allocated scratch array sized by the model output is used.
    maskImage.create(outputHeight, outputWidth, CV_8UC1);
    for(int row=0; row<outputHeight; ++row)
    {
        const float *rowValues = data + static_cast<std::size_t>(row)*static_cast<std::size_t>(outputWidth);
        unsigned char *maskRow = maskImage.ptr<unsigned char>(row);
        for(int col=0; col<outputWidth; ++col)
        {
            maskRow[col] = (Sigmoid(rowValues[col]) > 0.996) ? 255 : 0;
        }
    }

    return SUCCESS;
}



}
