#include <algorithm>
#include <cmath>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <opencv2/opencv.hpp>

#include "Helpers.h"


// Numerically stable softmax: converts raw logits into probabilities that sum to 1.
// The maximum logit is subtracted before exponentiating so that large logits
// cannot overflow std::exp.
//
// @param logits  unnormalized scores; may be empty
// @return        probabilities with the same length as `logits`
//                (an empty vector when `logits` is empty)
std::vector<float> softmax(const std::vector<float>& logits) {
    if (logits.empty()) {
        // std::max_element would return end(); dereferencing it is undefined behavior.
        return {};
    }

    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum_exp = 0.0f;

    for (size_t i = 0; i < logits.size(); ++i) {
        // std::exp selects the float overload; unqualified ::exp would promote to double.
        probs[i] = std::exp(logits[i] - max_logit);
        sum_exp += probs[i];
    }

    // For finite inputs the maximum element contributes exp(0) == 1, so sum_exp >= 1.
    for (float& p : probs) {
        p /= sum_exp;
    }

    return probs;
}

int main(int argc, char *argv[])
{
    // 设置最大输入shape: input表示输入节点名，{8,3,224,224}表示最大输入shape
    migraphx::onnx_options onnx_options;
    onnx_options.map_input_dims["input"] = {8, 3, 224, 224};
    // 加载模型
    migraphx::program net = migraphx::parse_onnx("../models/ResNet50.onnx", onnx_options);
    
    // 加载标签
    const std::string labelFile = "../models/imagenet_classes.txt";
    std::vector<std::string> labels = Helpers::loadLabels(labelFile);
    if (labels.empty()) {
        std::cout << "Failed to load labels: " << labelFile << std::endl;
        return 1;
    }

    // 获取模型输入/输出节点信息
    std::cout << "inputs:" << std::endl;
    std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs();
    for (auto i : inputs)
    {
        std::cout << i.first << ":" << i.second << std::endl;
    }
    std::cout << "outputs:" << std::endl;
    std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
    for (auto i : outputs)
    {
        std::cout << i.first << ":" << i.second << std::endl;
    }
    std::string inputName = inputs.begin()->first;
    migraphx::shape inputShape = inputs.begin()->second;
    int N = inputShape.lens()[0];
    int C = inputShape.lens()[1];
    int H = inputShape.lens()[2];
    int W = inputShape.lens()[3];
    
    // 编译模型
    migraphx::compile_options options;
    options.device_id = 0; // 设置GPU设备，默认为0号设备
    options.offload_copy = true;
    net.compile(migraphx::gpu::target{}, options);
    
    // 设置动态输入
    std::vector<std::vector<std::size_t>> inputShapes;
    int batch=4;
    inputShapes.push_back({batch, 3, 224, 224});

    // 测试图片的目录
    const std::string imagePath = "../img_in/";
    const std::string _strPattern = imagePath + "*.jpg";
    std::vector<std::string> ImageNames;
    cv::glob(_strPattern, ImageNames);

    for (int i = 0; i < inputShapes.size(); ++i)
    {
        // 数据预处理并转换为NCHW格式
        std::vector<cv::Mat> srcImages;
        for (int j = 0; j < inputShapes[i][0]; ++j)
        {
            cv::Mat srcImage = cv::imread(ImageNames[j], 1); 
            srcImages.push_back(srcImage);
        }
        cv::Mat inputBlob;
        cv::dnn::blobFromImages(srcImages,
                    inputBlob,
                    0.0078125,
                    cv::Size(inputShapes[i][3], inputShapes[i][2]),
                    cv::Scalar(127.5, 127.5, 127.5),
                    false, false);

        // 创建输入数据
        std::unordered_map<std::string, migraphx::argument> inputData;
        inputData[inputName] = migraphx::argument{migraphx::shape(inputShape.type(), inputShapes[i]), (float*)inputBlob.data};
        
        // 推理
        std::vector<migraphx::argument> results = net.eval(inputData);
        
        // 获取输出节点的属性
        migraphx::argument result = results[0];          // 获取第一个输出节点的数据

        migraphx::shape outputShape = result.get_shape();     // 输出节点的shape
        std::vector<std::size_t> outputSize = outputShape.lens(); 
        // 每一维大小，维度顺序为(N,C,H,W)
        int numberOfOutput = outputShape.elements();      // 输出节点元素的个数
        float *resultData = (float *)result.data();        // 输出节点数据指针
        
        // 每张图像的标签输出
        int numClasses = outputSize[1]; // output shape = [batch_size, num_classes, ...]
        for (int imgIndex = 0; imgIndex < batch; ++imgIndex) {
            // Extract logits for the current image
            std::vector<float> logits(resultData + imgIndex * numClasses, resultData + (imgIndex + 1) * numClasses);
            
            // Convert logits to probabilities
            std::vector<float> probs = softmax(logits);

            // Find the index of the maximum probability
            auto max_it = std::max_element(probs.begin(), probs.end());
            int top_label_index = std::distance(probs.begin(), max_it);
            
            // Print the top label for the current image
            printf("Image %d: Top label index: %d, Label: %s\n", imgIndex, top_label_index, labels[top_label_index].c_str());
        }
    }
    return 0;
}