#include <DeepLabV3.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/quantization.hpp>
#include <hip/hip_runtime_api.h>

#include <Filesystem.h>
#include <SimpleLog.h>

namespace migraphxSamples{

// Numerically stable softmax over a vector of logits:
//   probs[i] = exp(v[i] - max(v)) / sum_j exp(v[j] - max(v))
// Subtracting the maximum keeps std::exp from overflowing for large logits and
// does not change the result. An empty input yields an empty output instead of
// dereferencing max_element(end).
static std::vector<float> softmax(const std::vector<float>& v){

    std::vector<float> probs(v.size());
    if (v.empty()) {
        return probs;
    }

    const float max_val = *std::max_element(v.begin(), v.end());
    float sum_exp = 0.0f;
    for (std::size_t i = 0; i < v.size(); ++i) {
        probs[i] = std::exp(v[i] - max_val);
        sum_exp += probs[i];
    }
    // sum_exp >= 1 because the maximum element contributes exp(0) == 1.
    for (float& p : probs) {
        p /= sum_exp;
    }

    return probs;
}

// 定义21个类别的颜色映射表（BGR格式）
std::vector<cv::Scalar> create_color_map() {
    return {
        cv::Scalar(0, 0, 0),       // 0: 黑色（背景）
        cv::Scalar(255, 0, 0),     // 1: 蓝色
        cv::Scalar(0, 255, 0),     // 2: 绿色
        cv::Scalar(0, 0, 255),     // 3: 红色
        cv::Scalar(255, 255, 0),   // 4: 青色
        cv::Scalar(255, 0, 255),   // 5: 品红
        cv::Scalar(0, 255, 255),   // 6: 黄色
        cv::Scalar(128, 0, 0),     // 7: 深蓝
        cv::Scalar(0, 128, 0),     // 8: 深绿
        cv::Scalar(0, 0, 128),     // 9: 深红
        cv::Scalar(128, 128, 0),   // 10: 深青
        cv::Scalar(128, 0, 128),   // 11: 深品红
        cv::Scalar(0, 128, 128),   // 12: 深黄
        cv::Scalar(192, 192, 192), // 13: 灰色
        cv::Scalar(128, 128, 128), // 14: 深灰
        cv::Scalar(64, 0, 0),      // 15: 暗蓝
        cv::Scalar(0, 64, 0),      // 16: 暗绿
        cv::Scalar(0, 0, 64),      // 17: 暗红
        cv::Scalar(64, 64, 0),     // 18: 暗青
        cv::Scalar(64, 0, 64),     // 19: 暗品红
        cv::Scalar(0, 64, 64)      // 20: 暗黄
    };
}



// No member needs custom construction; set-up is done in Initialize().
DeepLabV3::DeepLabV3() = default;

// Releases the configuration file opened by Initialize().
// NOTE(review): the hipMalloc'd device buffers and malloc'd host buffers created in
// Initialize() are not released here. Freeing them here is only safe if no other
// member function frees them; confirm before adding hipFree/free calls.
DeepLabV3::~DeepLabV3()
{
    configurationFile.release();
}


// Loads and compiles the DeepLabV3 model described by the configuration file.
//
// Steps: read the "DeepLabV3" node of the configuration file (ModelPath, UseInt8,
// UseFP16, UseOffloadCopy), parse the ONNX model with a static input shape
// (batch 1 when loadMode is set, batch 3 otherwise), optionally quantize
// (INT8 with calibration images, or FP16 when UseFP16 is set), compile for GPU 0,
// allocate device/host buffers when offload copy is disabled, and run one
// warm-up inference.
//
// Returns SUCCESS, or CONFIG_FILE_NOT_EXIST / FAIL_TO_OPEN_CONFIG_FILE /
// MODEL_NOT_EXIST when the configuration or model cannot be found or opened.
ErrorCode DeepLabV3::Initialize(InitializationParameterOfSegmentation initParamOfSegmentationUnet){

    // Read the configuration file
    std::string configFilePath = initParamOfSegmentationUnet.configFilePath;
    if(!Exists(configFilePath))
    {
        LOG_ERROR(stdout, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }
    if(!configurationFile.open(configFilePath, cv::FileStorage::READ))
    {
        LOG_ERROR(stdout, "fail to open configuration file\n");
        return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(stdout, "succeed to open configuration file\n");

    // Read the model parameters
    cv::FileNode netNode  = configurationFile["DeepLabV3"];
    std::string modelPath = (std::string)netNode["ModelPath"];
    useInt8               = (bool)(int)netNode["UseInt8"];
    useFP16               = (bool)(int)netNode["UseFP16"];
    useOffloadCopy        = (bool)(int)netNode["UseOffloadCopy"];

    // Load the model
    if(!Exists(modelPath))
    {
        LOG_ERROR(stdout, "%s not exist!\n", modelPath.c_str());
        return MODEL_NOT_EXIST;
    }

    // loadMode selects the static batch size of the "input" tensor.
    migraphx::onnx_options onnx_options;
    if(initParamOfSegmentationUnet.loadMode){
        onnx_options.map_input_dims["input"] = {1, 3, 513, 513};
    }else{
        onnx_options.map_input_dims["input"] = {3, 3, 513, 513};
    }
    net = migraphx::parse_onnx(modelPath,onnx_options);
    LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(modelPath).c_str());

    // Query model input/output names and shapes.
    // NOTE(review): these are unordered_maps, so which output is "first" is
    // unspecified; confirm the binding order matches what Segmentation expects.
    std::unordered_map<std::string, migraphx::shape> inputs  = net.get_inputs();
    std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
    inputName  = inputs.begin()->first;
    inputShape = inputs.begin()->second;

    auto outputIt = outputs.begin();
    outputName    = outputIt->first;
    outputShape   = outputIt->second;
    // Dereferencing past the only element would be undefined behaviour, so a
    // single-output model reuses its first output for the second slot.
    ++outputIt;
    if(outputIt == outputs.end()){
        outputIt = outputs.begin();
    }
    outputName2  = outputIt->first;
    outputShape2 = outputIt->second;

    const int H = inputShape.lens()[2];
    const int W = inputShape.lens()[3];
    inputSize = cv::Size(W, H);

    // Target the GPU
    migraphx::target gpuTarget = migraphx::gpu::target{};

    bool quantizedToInt8 = false;
    if(useInt8){
        // Load calibration images; files that fail to decode are skipped rather
        // than handed to blobFromImages as empty Mats.
        std::vector<cv::Mat> calibrateImages;
        std::string folderPath = "../Resource/Images/calibrateImages/";
        std::string calibrateImageExt = "*.jpg";
        std::vector<cv::String> calibrateImagePaths;
        cv::glob(folderPath + calibrateImageExt, calibrateImagePaths, false);
        for(const auto& path : calibrateImagePaths){
            cv::Mat image = cv::imread(path, 1);
            if(!image.empty()){
                calibrateImages.push_back(image);
            }
        }

        // Each calibration argument spans exactly inputShape.bytes(), i.e. one full
        // batch of the compiled batch size. Group the images into full batches; a
        // trailing partial batch is dropped because it would be read past its blob.
        const std::size_t calibrateBatchSize = inputShape.lens()[0];
        // Owns the pixel data that the (non-owning) arguments below point to; it must
        // outlive quantize_int8.
        std::vector<cv::Mat> calibrateBlobs;
        for(std::size_t start = 0;
            calibrateBatchSize > 0 && start + calibrateBatchSize <= calibrateImages.size();
            start += calibrateBatchSize){
            std::vector<cv::Mat> batch(calibrateImages.begin() + start,
                                       calibrateImages.begin() + start + calibrateBatchSize);
            cv::Mat blob;
            cv::dnn::blobFromImages(batch, blob, 1 / 255.0, inputSize, cv::Scalar(0, 0, 0), true, false);
            calibrateBlobs.push_back(blob);
        }

        std::vector<std::unordered_map<std::string, migraphx::argument>> calibrationData;
        for(cv::Mat& blob : calibrateBlobs){
            std::unordered_map<std::string, migraphx::argument> inputData;
            inputData[inputName] = migraphx::argument{inputShape, (float *)blob.data};
            calibrationData.push_back(inputData);
        }

        if(calibrationData.empty()){
            LOG_ERROR(stdout, "no usable calibration batch in %s, skip INT8 quantization\n", folderPath.c_str());
        }else{
            // INT8 quantization
            migraphx::quantize_int8(net, gpuTarget, calibrationData);
            quantizedToInt8 = true;
        }
    }
    // FP16 is applied only when requested by UseFP16 (it used to be applied
    // unconditionally whenever INT8 was off).
    if(!quantizedToInt8 && useFP16){
        migraphx::quantize_fp16(net);
    }

    // Compile the model
    migraphx::compile_options options;
    options.device_id    = 0; // GPU device to use (device 0 by default)
    options.offload_copy = useOffloadCopy;

    net.compile(gpuTarget, options);
    LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(modelPath).c_str());
    if(!useOffloadCopy){
        // Without offload copy, inputs/outputs are bound to device buffers that are
        // allocated once here and reused by every inference.
        inputBufferDevice = nullptr;
        hipMalloc(&inputBufferDevice, inputShape.bytes());
        modalDataMap[inputName] = migraphx::argument{inputShape, inputBufferDevice};

        outputBufferDevice = nullptr;
        hipMalloc(&outputBufferDevice, outputShape.bytes());
        outputBufferDevice2 = nullptr;
        hipMalloc(&outputBufferDevice2, outputShape2.bytes());
        modalDataMap[outputName] = migraphx::argument{outputShape, outputBufferDevice};
        modalDataMap[outputName2] = migraphx::argument{outputShape2, outputBufferDevice2};
        outputBufferHost             = nullptr; // host memory
        outputBufferHost             = malloc(outputShape.bytes());
        outputBufferHost2            = nullptr; // host memory
        outputBufferHost2            = malloc(outputShape2.bytes());
    }
    // Warm up
    if(useOffloadCopy){
        std::unordered_map<std::string, migraphx::argument> inputData;
        inputData[inputName] = migraphx::argument{inputShape};
        net.eval(inputData);
    }else{
        migraphx::argument inputData = migraphx::argument{inputShape};                                  // create host data
        hipMemcpy(inputBufferDevice, inputData.data(), inputShape.bytes(), hipMemcpyHostToDevice);      // copy it to the device
        net.eval(modalDataMap);
    }

    // Log the configuration
    LOG_INFO(stdout, "InputSize:%dx%d\n", inputSize.width, inputSize.height);
    LOG_INFO(stdout, "InputName:%s\n", inputName.c_str());
    LOG_INFO(stdout, "UseInt8:%d\n", (int)useInt8);
    LOG_INFO(stdout, "UseFP16:%d\n", (int)useFP16);
    LOG_INFO(stdout, "useOffloadCopy:%d\n", (int)useOffloadCopy);


    return SUCCESS;
}


ErrorCode DeepLabV3::Segmentation(std::vector<cv::Mat> srcImages, std::vector<cv::Mat> & maskImages){

    if(srcImages.size()==0 || srcImages[0].empty() || srcImages[0].type() != CV_8UC3)
    {
        LOG_ERROR(stdout, "image error!\n");
        return IMAGE_ERROR;
    }

    // 数据预处理并转换为NCHW格式
    cv::Mat inputBatchBlob;
    cv::dnn::blobFromImages(srcImages, inputBatchBlob, 1 / 255.0, inputSize, cv::Scalar(0, 0, 0), true, false);

    // 创建颜色映射表
    std::vector<cv::Scalar> color_map = create_color_map();

    if(useOffloadCopy){
        // 创建输入数据
        std::unordered_map<std::string, migraphx::argument> inputData;
        inputData[inputName] = migraphx::argument{inputShape, (float*)inputBatchBlob.data};

        // 推理
        std::vector<migraphx::argument> results = net.eval(inputData);
        
        // 获取输出节点的属性
        migraphx::argument result   = results[0];                 // 获取第一个输出节点的数据
        migraphx::shape outputShape = result.get_shape();         // 输出节点的shape
        std::vector<std::size_t> outputSize = outputShape.lens(); // 每一维大小，维度顺序为(N,C,H,W)

        int numberOfOutput = outputShape.elements();              // 输出节点元素的个数
        float* data        = (float*)result.data();               // 输出节点数据指针

        int N = outputShape.lens()[0];
        int C = outputShape.lens()[1];
        int H = outputShape.lens()[2];
        int W = outputShape.lens()[3];

        for(int m = 0;m < N;m++){
            cv::Mat outputImage(cv::Size(W, H), CV_8UC3);
            for(int i = 0;i < H; i++){
                for(int j = 0;j < W;j++){
                    std::vector<float> channel_value;
                    for(int k = 0;k < C;k++){
                        channel_value.push_back(data[m*C*H*W+k*(H*W)+i*W+j]);
                    }
                    std::vector<float> probs = softmax(channel_value);
                    // 找到概率最高的类别索引
                    int max_index = std::max_element(probs.begin(),probs.end())-probs.begin();
                    cv::Scalar sc = color_map[max_index];
                    outputImage.at<cv::Vec3b>(i, j)[0]= sc.val[0]; 
                    outputImage.at<cv::Vec3b>(i, j)[1]= sc.val[1];
                    outputImage.at<cv::Vec3b>(i, j)[2]= sc.val[2];
                }
            }
            maskImages.push_back(outputImage);
        }
        
    }else{

        migraphx::argument inputData = migraphx::argument{inputShape, (float*)inputBatchBlob.data};
        // 拷贝到device输入内存
        hipMemcpy(inputBufferDevice, inputData.data(), inputShape.bytes(), hipMemcpyHostToDevice);
        // 推理
        std::vector<migraphx::argument> results = net.eval(modalDataMap);

        // 获取输出节点的属性
        migraphx::argument result    = results[0];                                      // 获取第一个输出节点的数据
        migraphx::shape outputShapes = result.get_shape();                              // 输出节点的shape
        std::vector<std::size_t> outputSize = outputShapes.lens();                      // 每一维大小，维度顺序为(N,C,H,W)
        int numberOfOutput = outputShapes.elements();                                   // 输出节点元素的个数
        // 将device输出数据拷贝到分配好的host输出内存
        hipMemcpy(outputBufferHost,outputBufferDevice, outputShapes.bytes(),hipMemcpyDeviceToHost); // 直接使用事先分配好的输出内存拷贝
        int N = outputSize[0];
        int C = outputSize[1];
        int H = outputSize[2];
        int W = outputSize[3];
       


        // 获取输出节点的属性
        migraphx::argument result2    = results[1];                                         // 获取第2个输出节点的数据
        migraphx::shape outputShapes2 = result2.get_shape();                                // 输出节点的shape
        std::vector<std::size_t> outputSize2 = outputShapes2.lens();                        // 每一维大小，维度顺序为(N,C,H,W)
         // 将device输出数据拷贝到分配好的host输出内存
        hipMemcpy(outputBufferHost2,outputBufferDevice2, outputShapes2.bytes(),hipMemcpyDeviceToHost); // 直接使用事先分配好的输出内存拷贝
        for(int m = 0;m < N;m++){
            cv::Mat outputImage(cv::Size(W, H), CV_8UC3);
            for(int i = 0;i < H; i++){
                for(int j = 0;j < W;j++){
                    std::vector<float> channel_value;
                    for(int k = 0;k < C;k++){
                        channel_value.push_back(((float *)outputBufferDevice2)[m*C*H*W+k*(H*W)+i*W+j]);
                    }
                    std::vector<float> probs = softmax(channel_value);
                    // 找到概率最高的类别索引
                    int max_index = std::max_element(probs.begin(),probs.end())-probs.begin();
                    cv::Scalar sc = color_map[max_index];
                    outputImage.at<cv::Vec3b>(i, j)[0]= sc.val[0]; 
                    outputImage.at<cv::Vec3b>(i, j)[1]= sc.val[1];
                    outputImage.at<cv::Vec3b>(i, j)[2]= sc.val[2];
                }
            }
            maskImages.push_back(outputImage);
        }
        // 释放
        hipFree(inputBufferDevice);
        hipFree(outputBufferDevice);
        hipFree(outputBufferDevice2);
        free(outputBufferHost);
        free(outputBufferHost2);
    }    

    return SUCCESS;
}

}