#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <string>
#include <vector>

#include <SimpleLog.h>
#include <Filesystem.h>
#include <YOLOV9.h>

/**
 * Print the command-line help for the sample launcher to stdout.
 *
 * @param programName  Name the program was invoked as (normally argv[0]);
 *                     it is echoed in the "Usage" line.
 */
void MIGraphXSamplesUsage(char *programName)
{
    // One line per selectable sample, in index order.
    static const char *const kIndexLines[] = {
        "\t 0) YOLOV9 sample.\n",
        "\t 1) YOLOV9 Dynamic sample.\n",
    };

    printf("Usage : %s <index> \n", programName);
    fputs("index:\n", stdout);
    for (const char *line : kIndexLines)
    {
        fputs(line, stdout);
    }
}

void Sample_YOLOV9();
void Sample_YOLOV9_Dynamic();

/**
 * Entry point: runs the sample selected by the single command-line argument.
 *
 * Accepted arguments:
 *   "0"                      -> Sample_YOLOV9 (static input shape)
 *   "1"                      -> Sample_YOLOV9_Dynamic (dynamic input shapes)
 *   "-h..." / "--help"       -> print usage and exit successfully
 *
 * @return 0 on success or when help was requested; -1 on a usage error
 *         (wrong argument count or an unrecognized index).
 */
int main(int argc, char *argv[])
{
    // Exactly one argument (a sample index or a help flag) is expected.
    if (argc != 2)
    {
        MIGraphXSamplesUsage(argv[0]);
        return -1;
    }

    const char *option = argv[1];

    // Any argument starting with "-h" (e.g. "-h", "-help") or "--help" prints usage.
    if (strncmp(option, "-h", 2) == 0 || strcmp(option, "--help") == 0)
    {
        MIGraphXSamplesUsage(argv[0]);
        return 0;
    }

    // Compare the whole argument so that inputs such as "01" or "1abc" are not
    // silently accepted as a valid index.
    if (strcmp(option, "0") == 0)
    {
        Sample_YOLOV9();
    }
    else if (strcmp(option, "1") == 0)
    {
        Sample_YOLOV9_Dynamic();
    }
    else
    {
        // An unknown index is a usage error, just like a wrong argument count.
        MIGraphXSamplesUsage(argv[0]);
        return -1;
    }
    return 0;
}

void Sample_YOLOV9()
{
    // 创建YOLOV9检测器
    migraphxSamples::DetectorYOLOV9 detector;
    migraphxSamples::InitializationParameterOfDetector initParamOfDetectorYOLOV9;
    initParamOfDetectorYOLOV9.configFilePath = CONFIG_FILE;
    migraphxSamples::ErrorCode errorCode = detector.Initialize(initParamOfDetectorYOLOV9, false);
    if (errorCode != migraphxSamples::SUCCESS)
    {
        LOG_ERROR(stdout, "fail to initialize detector!\n");
        exit(-1);
    }
    LOG_INFO(stdout, "succeed to initialize detector\n");

    // 读取测试图片
    cv::Mat srcImage = cv::imread("../Resource/Images/image_test.jpg", 1);

    // 静态推理固定尺寸
    std::vector<std::size_t> inputShape = {1, 3, 640, 640};

    // 推理
    std::vector<migraphxSamples::ResultOfDetection> predictions;
    double time1 = cv::getTickCount();
    detector.Detect(srcImage, inputShape, predictions, false);
    double time2 = cv::getTickCount();
    double elapsedTime = (time2 - time1) * 1000 / cv::getTickFrequency();
    LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime);

    // 获取推理结果
    LOG_INFO(stdout, "========== Detection Results ==========\n");
    for (int i = 0; i < predictions.size(); ++i)
    {
        migraphxSamples::ResultOfDetection result = predictions[i];
        cv::rectangle(srcImage, result.boundingBox, cv::Scalar(0, 255, 255), 2);

        std::string label = cv::format("%.2f", result.confidence);
        label = result.className + " " + label;
        int left = predictions[i].boundingBox.x;
        int top = predictions[i].boundingBox.y;
        int baseLine;
        cv::Size labelSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
        top = max(top, labelSize.height);
        cv::putText(srcImage, label, cv::Point(left, top - 10), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(0, 255, 255), 2);

        LOG_INFO(stdout, "box:%d %d %d %d,label:%d,confidence:%f\n", predictions[i].boundingBox.x,
                 predictions[i].boundingBox.y, predictions[i].boundingBox.width, predictions[i].boundingBox.height, predictions[i].classID, predictions[i].confidence);
    }
    cv::imwrite("Result.jpg", srcImage);
    LOG_INFO(stdout, "Detection results have been saved to ./Result.jpg\n");
}

void Sample_YOLOV9_Dynamic()
{
    // 创建YOLOV9检测器
    migraphxSamples::DetectorYOLOV9 detector;
    migraphxSamples::InitializationParameterOfDetector initParamOfDetectorYOLOV9;
    initParamOfDetectorYOLOV9.configFilePath = CONFIG_FILE;
    migraphxSamples::ErrorCode errorCode = detector.Initialize(initParamOfDetectorYOLOV9, true);
    if (errorCode != migraphxSamples::SUCCESS)
    {
        LOG_ERROR(stdout, "fail to initialize detector!\n");
        exit(-1);
    }
    LOG_INFO(stdout, "succeed to initialize detector\n");

    // 读取测试图像
    std::vector<cv::Mat> srcImages;
    cv::String folder = "../Resource/Images/DynamicPics";
    std::vector<cv::String> imagePathList;
    cv::glob(folder, imagePathList);
    for (int i = 0; i < imagePathList.size(); ++i)
    {
        cv::Mat srcImage = cv::imread(imagePathList[i], 1);
        srcImages.push_back(srcImage);
    }

    // 设置动态推理shape
    std::vector<std::vector<std::size_t>> inputShapes;
    inputShapes.push_back({1, 3, 416, 416});
    inputShapes.push_back({1, 3, 608, 608});

    for (int i = 0; i < srcImages.size(); ++i)
    {
        // 推理
        std::vector<migraphxSamples::ResultOfDetection> predictions;
        double time1 = cv::getTickCount();
        detector.Detect(srcImages[i], inputShapes[i], predictions, true);
        double time2 = cv::getTickCount();
        double elapsedTime = (time2 - time1) * 1000 / cv::getTickFrequency();
        LOG_INFO(stdout, "inference image%d time:%f ms\n", i, elapsedTime);

        // 获取推理结果
        LOG_INFO(stdout, "========== Detection Image%d Results ==========\n", i);
        for (int j = 0; j < predictions.size(); ++j)
        {
            migraphxSamples::ResultOfDetection result = predictions[j];
            cv::rectangle(srcImages[i], result.boundingBox, cv::Scalar(0, 255, 255), 2);

            std::string label = cv::format("%.2f", result.confidence);
            label = result.className + " " + label;
            int left = predictions[j].boundingBox.x;
            int top = predictions[j].boundingBox.y;
            int baseLine;
            cv::Size labelSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
            top = max(top, labelSize.height);
            cv::putText(srcImages[i], label, cv::Point(left, top - 10), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(0, 255, 255), 2);

            LOG_INFO(stdout, "box:%d %d %d %d,label:%d,confidence:%f\n", predictions[j].boundingBox.x,
                     predictions[j].boundingBox.y, predictions[j].boundingBox.width, predictions[j].boundingBox.height, predictions[j].classID, predictions[j].confidence);
        }
        std::string imgName = cv::format("Result%d.jpg", i);
        cv::imwrite(imgName, srcImages[i]);
        LOG_INFO(stdout, "Detection results have been saved to ./Result%d.jpg\n", i);
    }
}