#include <DeepLabV3.h>

#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>

#include <Filesystem.h>
#include <SimpleLog.h>
#include <algorithm>

namespace migraphxSamples
{

/// Logistic sigmoid 1 / (1 + e^-x), evaluated entirely in single precision.
/// Uses std::exp(float) so the argument is not promoted to double and the name
/// does not depend on <math.h> leaking `::exp` into the global namespace.
/// NOTE(review): this helper is not referenced anywhere in this file.
static float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

/// Numerically stable softmax over a vector of logits.
///
/// The maximum logit is subtracted before exponentiating so that large inputs
/// do not overflow to infinity.
///
/// @param v  Logits; may be empty.
/// @return   Probabilities with the same length as v that sum to 1
///           (an empty vector when v is empty).
static std::vector<float> softmax(const std::vector<float>& v)
{
    std::vector<float> probs(v.size());
    if(v.empty())
    {
        // max_element on an empty range returns end(); dereferencing it is UB.
        return probs;
    }

    const float maxVal = *std::max_element(v.begin(), v.end());
    float sumExp       = 0.0f;
    for(std::size_t i = 0; i < v.size(); ++i)
    {
        probs[i] = std::exp(v[i] - maxVal);
        sumExp += probs[i];
    }
    // sumExp >= 1 because the maximum element contributes exp(0) == 1.
    for(float& p : probs)
    {
        p /= sumExp;
    }

    return probs;
}

// 定义21个类别的颜色映射表（BGR格式）
std::vector<cv::Scalar> create_color_map() {
    return {
        cv::Scalar(0, 0, 0),       // 0: 黑色（背景）
        cv::Scalar(255, 0, 0),     // 1: 蓝色
        cv::Scalar(0, 255, 0),     // 2: 绿色
        cv::Scalar(0, 0, 255),     // 3: 红色
        cv::Scalar(255, 255, 0),   // 4: 青色
        cv::Scalar(255, 0, 255),   // 5: 品红
        cv::Scalar(0, 255, 255),   // 6: 黄色
        cv::Scalar(128, 0, 0),     // 7: 深蓝
        cv::Scalar(0, 128, 0),     // 8: 深绿
        cv::Scalar(0, 0, 128),     // 9: 深红
        cv::Scalar(128, 128, 0),   // 10: 深青
        cv::Scalar(128, 0, 128),   // 11: 深品红
        cv::Scalar(0, 128, 128),   // 12: 深黄
        cv::Scalar(192, 192, 192), // 13: 灰色
        cv::Scalar(128, 128, 128), // 14: 深灰
        cv::Scalar(64, 0, 0),      // 15: 暗蓝
        cv::Scalar(0, 64, 0),      // 16: 暗绿
        cv::Scalar(0, 0, 64),      // 17: 暗红
        cv::Scalar(64, 64, 0),     // 18: 暗青
        cv::Scalar(64, 0, 64),     // 19: 暗品红
        cv::Scalar(0, 64, 64)      // 20: 暗黄
    };
}



// Constructor: the body is intentionally empty; the configuration file and the
// network are set up later by Initialize().
DeepLabV3::DeepLabV3()
{
}

// Destructor: releases configurationFile (which Initialize() opens).
DeepLabV3::~DeepLabV3()
{
    configurationFile.release();
}


/// Loads the configuration, parses and compiles the DeepLabV3 ONNX model for the
/// GPU, and runs one warm-up inference.
///
/// The model path is read from the "ModelPath" entry of the "DeepLabV3" node of
/// the configuration file. The ONNX input named "inputs" is fixed to {1, 3, 256, 256}.
///
/// @param initParamOfSegmentationUnet  Provides configFilePath, the configuration file path.
/// @return SUCCESS on success; CONFIG_FILE_NOT_EXIST, FAIL_TO_OPEN_CONFIG_FILE or
///         MODEL_NOT_EXIST when the corresponding step fails.
ErrorCode DeepLabV3::Initialize(InitializationParameterOfSegmentation initParamOfSegmentationUnet)
{
    // Open the configuration file.
    const std::string configFilePath = initParamOfSegmentationUnet.configFilePath;
    if(!Exists(configFilePath))
    {
        LOG_ERROR(stdout, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }
    if(!configurationFile.open(configFilePath, cv::FileStorage::READ))
    {
        LOG_ERROR(stdout, "fail to open configuration file\n");
        return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(stdout, "succeed to open configuration file\n");

    // Read parameters from the configuration file.
    cv::FileNode netNode        = configurationFile["DeepLabV3"];
    const std::string modelPath = static_cast<std::string>(netNode["ModelPath"]);

    // Set the (maximum) input shape for the ONNX input named "inputs".
    migraphx::onnx_options onnx_options;
    onnx_options.map_input_dims["inputs"] = {1, 3, 256, 256};

    // Load the model.
    if(!Exists(modelPath))
    {
        LOG_ERROR(stdout, "%s not exist!\n", modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath, onnx_options);
    LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(modelPath).c_str());

    // Record the model input. Only the first input is used.
    // NOTE(review): assumes the model has at least one input and that its shape is
    // NCHW (rank 4); dimensions 2 and 3 are taken as H and W. Confirm for new models.
    const std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs();
    inputName  = inputs.begin()->first;
    inputShape = inputs.begin()->second;
    const std::vector<std::size_t>& inputDims = inputShape.lens();
    inputSize = cv::Size(static_cast<int>(inputDims[3]), static_cast<int>(inputDims[2]));

    // Compile for the GPU target.
    migraphx::target gpuTarget = migraphx::gpu::target{};
    migraphx::compile_options options;
    options.device_id    = 0; // GPU device to use; defaults to device 0
    options.offload_copy = true;
    net.compile(gpuTarget, options);
    LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(modelPath).c_str());

    // Warm up: one inference on a buffer of the input shape.
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName] = migraphx::argument{inputShape};
    net.eval(inputData);

    // Log the input information.
    LOG_INFO(stdout, "InputSize:%dx%d\n", inputSize.width, inputSize.height);
    LOG_INFO(stdout, "InputName:%s\n", inputName.c_str());

    return SUCCESS;
}


// Returns the index of the largest of numChannels scores that lie channelStride
// floats apart, starting at scores[0]. Ties resolve to the lowest index.
// Returns 0 without reading anything when numChannels <= 0.
static int ArgmaxOverChannels(const float* scores, int numChannels, std::size_t channelStride)
{
    int best = 0;
    for(int k = 1; k < numChannels; ++k)
    {
        if(scores[k * channelStride] > scores[best * channelStride])
        {
            best = k;
        }
    }
    return best;
}

/// Runs DeepLabV3 on one BGR image and produces a colour-coded class mask.
///
/// Preprocessing uses blobFromImage:
/// - resize to inputSize;
/// - scale by 1/255;
/// - swap R and B;
/// - no crop.
///
/// Each output pixel is painted with the palette colour of its highest-scoring
/// class. The mask has the resolution of the model output (W x H), not that of
/// srcImage.
///
/// @param srcImage   Input image; must be non-empty and of type CV_8UC3.
/// @param maskImage  Receives a newly allocated CV_8UC3 colour mask.
/// @return SUCCESS, or IMAGE_ERROR if srcImage is empty or not CV_8UC3.
ErrorCode DeepLabV3::Segmentation(const cv::Mat& srcImage, cv::Mat& maskImage)
{
    if(srcImage.empty() || srcImage.type() != CV_8UC3)
    {
        LOG_ERROR(stdout, "image error!\n");
        return IMAGE_ERROR;
    }

    // Preprocess and convert to an NCHW float blob.
    cv::Mat inputBlob;
    cv::dnn::blobFromImage(
        srcImage, inputBlob, 1 / 255.0, inputSize, cv::Scalar(0, 0, 0), true, false);

    // Wrap the blob as the model input.
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName] = migraphx::argument{inputShape, inputBlob.ptr<float>()};

    // Inference.
    std::vector<migraphx::argument> results = net.eval(inputData);

    // First output node.
    // NOTE(review): assumed to hold float class scores laid out as contiguous
    // (N, C, H, W); only batch 0 is read. Confirm against the exported model.
    migraphx::argument result               = results[0];
    const migraphx::shape outputShape       = result.get_shape();
    const std::vector<std::size_t>& outDims = outputShape.lens();
    const float* scores                     = reinterpret_cast<const float*>(result.data());

    const int C                 = static_cast<int>(outDims[1]);
    const int H                 = static_cast<int>(outDims[2]);
    const int W                 = static_cast<int>(outDims[3]);
    const std::size_t planeSize = static_cast<std::size_t>(H) * static_cast<std::size_t>(W);

    const std::vector<cv::Scalar> colorMap = create_color_map();
    cv::Mat outputImage(cv::Size(W, H), CV_8UC3);

    for(int i = 0; i < H; ++i)
    {
        cv::Vec3b* row = outputImage.ptr<cv::Vec3b>(i);
        for(int j = 0; j < W; ++j)
        {
            // Softmax is strictly monotonic, so the argmax of the raw scores is the
            // argmax of the class probabilities; normalising per pixel is unnecessary.
            const std::size_t pixel = static_cast<std::size_t>(i) * W + j;
            const int classIndex    = ArgmaxOverChannels(scores + pixel, C, planeSize);

            // Wrap around so models with more classes than palette entries do not
            // index past the end of colorMap.
            const cv::Scalar& color =
                colorMap[static_cast<std::size_t>(classIndex) % colorMap.size()];
            row[j] = cv::Vec3b(static_cast<unsigned char>(color.val[0]),
                               static_cast<unsigned char>(color.val[1]),
                               static_cast<unsigned char>(color.val[2]));
        }
    }

    // outputImage owns a freshly allocated buffer, so handing it over needs no deep copy.
    maskImage = outputImage;
    return SUCCESS;
}

}