#include <Unet.h>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <opencv2/dnn.hpp>
#include <CommonUtility.h>
#include <Filesystem.h>
#include <SimpleLog.h>

using namespace cv::dnn;

namespace migraphxSamples
{

// Construct an uninitialized segmenter. Initialize() must be called before
// Segmentation(); the log-file handle stays null until DoCommonInitialization()
// fetches it from LogManager.
Unet::Unet():logFile(nullptr)
{
}

// Release the configuration file that DoCommonInitialization() opened.
Unet::~Unet()
{
    this->configurationFile.release();
}

// Load the UNet ONNX model named in the configuration file, compile it for the
// GPU and run one warm-up inference on generated input.
//
// initParamOfSegmentationUnet : log name, configuration-file path and parent path,
//                               handed to DoCommonInitialization().
// Returns SUCCESS on success. Otherwise it returns one of:
//   - the error reported by DoCommonInitialization();
//   - MODEL_NOT_EXIST when "Unet/ModelPath" is missing or empty, or the file does not exist;
//   - MODEL_NOT_EXIST when the parsed model has no input or a non-4-D input.
// NOTE(review): no dedicated "invalid model" ErrorCode is visible in this file, so
// MODEL_NOT_EXIST is reused for the last case — switch to a more specific code if
// the enum provides one. parse_onnx/compile/eval may throw; nothing is caught here.
ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegmentationUnet)
{
    // Common setup (log file, configuration file, parent-path normalization)
    ErrorCode errorCode=DoCommonInitialization(initParamOfSegmentationUnet);
    if(errorCode!=SUCCESS)
    {
        LOG_ERROR(logFile,"fail to DoCommonInitialization\n");
        return errorCode;
    }
    LOG_INFO(logFile,"succeed to DoCommonInitialization\n");

    // Model file name comes from the "Unet" section of the configuration file.
    cv::FileNode netNode = configurationFile["Unet"];
    const std::string modelFileName=static_cast<std::string>(netNode["ModelPath"]);
    const std::string modelPath=initializationParameter.parentPath+modelFileName;

    // A missing/empty ModelPath would otherwise resolve to parentPath itself, which
    // may exist as a directory and slip past the Exists() check.
    if(modelFileName.empty()||!Exists(modelPath))
    {
        LOG_ERROR(logFile,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath);
    LOG_INFO(logFile,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    // Model input: the first parameter reported by MIGraphX. get_parameter_shapes()
    // returns a temporary container, so it is kept in a local before use.
    const auto parameterShapes=net.get_parameter_shapes();
    if(parameterShapes.empty())
    {
        LOG_ERROR(logFile,"model has no input: %s\n",GetFileName(modelPath).c_str());
        return MODEL_NOT_EXIST;
    }
    inputName=parameterShapes.begin()->first;
    inputShape=parameterShapes.begin()->second;

    // The input is read as (N,C,H,W); anything else would index out of range.
    const std::vector<std::size_t> &inputLens=inputShape.lens();
    if(inputLens.size()!=4)
    {
        LOG_ERROR(logFile,"model input is not 4-D: %s\n",GetFileName(modelPath).c_str());
        return MODEL_NOT_EXIST;
    }
    inputSize=cv::Size(static_cast<int>(inputLens[3]),static_cast<int>(inputLens[2]));

    // Target the GPU
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Compile the model
    migraphx::compile_options options;
    options.device_id=0;                          // GPU device index (device 0)
    options.offload_copy=true;                    // MIGraphX handles host<->device copies of inputs/outputs
    net.compile(gpuTarget,options);
    LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());

    // Warm-up run on generated input (presumably to absorb first-run overhead)
    migraphx::parameter_map inputData;
    inputData[inputName]=migraphx::generate_argument(inputShape);
    net.eval(inputData);

    // Log the input properties
    LOG_INFO(logFile,"InputSize:%dx%d\n",inputSize.width,inputSize.height);
    LOG_INFO(logFile,"InputName:%s\n",inputName.c_str());

    return SUCCESS;
}

// Shared setup used by Initialize():
//   1. Store the initialization parameters.
//   2. Fetch the named log file from LogManager.
//   3. Open the configuration file for reading.
//   4. Make sure parentPath ends with a path separator, so file names can be appended directly.
// Returns CONFIG_FILE_NOT_EXIST if the configuration file is missing,
// FAIL_TO_OPEN_CONFIG_FILE if it cannot be opened, and SUCCESS otherwise.
ErrorCode Unet::DoCommonInitialization(InitializationParameterOfSegmentation initParamOfSegmentationUnet)
{
    initializationParameter = initParamOfSegmentationUnet;
    logFile = LogManager::GetInstance()->GetLogFile(initializationParameter.logName);

    std::string configPath(initializationParameter.configFilePath);
    if(!Exists(configPath))
    {
        LOG_ERROR(logFile, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }

    const bool opened = configurationFile.open(configPath, cv::FileStorage::READ);
    if(!opened)
    {
        LOG_ERROR(logFile, "fail to open configuration file\n");
        return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(logFile, "succeed to open configuration file\n");

    // Normalize the parent path in place so that it ends with a separator.
    std::string &parent = initializationParameter.parentPath;
    const bool needsSeparator = !parent.empty() && !IsPathSeparator(parent.back());
    if(needsSeparator)
    {
        parent += PATH_SEPARATOR;
    }

    return SUCCESS;
}

// Run the compiled UNet on srcImage and produce a binary mask.
//
// srcImage  : input image; must be non-empty and of type CV_8UC3. R and B are
//             swapped during preprocessing, so BGR input is presumably expected.
// maskImage : receives a CV_8UC1 image of size (output W) x (output H). Pixels are
//             255 where sigmoid(output) > 0.996 and 0 elsewhere. As with
//             cv::Mat::create, an existing buffer of matching size/type is reused.
// Returns IMAGE_ERROR for an empty or non-CV_8UC3 image, otherwise SUCCESS.
//
// NOTE(review): the output is assumed to be 4-D (N,C,H,W), and only its first H*W
// values (first batch, first channel) are used — confirm the model has a single
// output channel.
ErrorCode Unet::Segmentation(const cv::Mat &srcImage, cv::Mat &maskImage)
{
    if(srcImage.empty()||srcImage.type()!=CV_8UC3)
    {
        LOG_ERROR(logFile, "image error!\n");
        return IMAGE_ERROR;
    }

    // Preprocess into an NCHW float blob with these settings:
    //   - scale by 1/255
    //   - resize to the model input size
    //   - zero mean
    //   - swap R/B
    //   - no center crop
    cv::Mat inputBlob;
    cv::dnn::blobFromImage(srcImage,
                           inputBlob,
                           1 / 255.0,
                           inputSize,
                           cv::Scalar(0, 0, 0),
                           true,
                           false);

    // The argument points at the blob's buffer without copying it, so inputBlob
    // must stay alive until eval() returns.
    migraphx::parameter_map inputData;
    inputData[inputName]=migraphx::argument{inputShape, inputBlob.ptr<float>()};

    // Inference
    std::vector<migraphx::argument> results = net.eval(inputData);

    // First output. It is read on the host, relying on offload_copy=true at compile time.
    migraphx::argument result = results[0];
    const migraphx::shape outputShape=result.get_shape();
    const std::vector<std::size_t> &outputSize=outputShape.lens(); // (N,C,H,W)
    const int outputHeight=static_cast<int>(outputSize[2]);
    const int outputWidth=static_cast<int>(outputSize[3]);
    const float *logits=reinterpret_cast<const float *>(result.data());

    // Threshold sigmoid(logit) straight into the 8-bit mask, one row at a time.
    // Row i of the output plane starts at logits + i*outputWidth; writing rows
    // through ptr() also copes with a non-continuous caller-supplied maskImage.
    const double maskThreshold=0.996;
    maskImage.create(outputHeight, outputWidth, CV_8UC1);
    for(int i=0;i<outputHeight;++i)
    {
        const float *srcRow=logits+static_cast<std::size_t>(i)*static_cast<std::size_t>(outputWidth);
        unsigned char *dstRow=maskImage.ptr<unsigned char>(i);
        for(int j=0;j<outputWidth;++j)
        {
            dstRow[j]=static_cast<unsigned char>((Sigmoid(srcRow[j])>maskThreshold)?255:0);
        }
    }

    return SUCCESS;
}

// Logistic function 1 / (1 + e^(-x)); maps a raw model output to the range (0, 1).
float Unet::Sigmoid(float x)
{
    const auto denominator = 1 + exp(-x);
    return 1 / denominator;
}

}
