#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <cstdio>
#include <fstream>
#include </usr/include/opencv2/opencv.hpp>
#include </usr/include/opencv2/highgui/highgui.hpp>  
#include </usr/include/opencv2/imgproc/imgproc.hpp> 
#include <iostream>
#include <typeinfo>
#include <algorithm>
#include <vector>
#include <algorithm>
#include <Decoder.h>
#include <Queuethread.h>
#include <thread>

using namespace cv;
using namespace std;

// Copies `rows` rows of `row_bytes` bytes each from a strided source plane
// into a tightly packed destination buffer.
static void CopyPlaneRows(unsigned char* dst, const unsigned char* src, int src_stride, int row_bytes, int rows)
{
    for (int r = 0; r < rows; ++r)
    {
        memcpy(dst, src, row_bytes);
        dst += row_bytes;
        src += src_stride;
    }
}

// Decoder thread entry point.
//
// Reads packets from the demuxer owned by a local Decoder, decodes the video
// stream, converts every decoded frame to RGB and pushes it into `queue`.
// When the input is exhausted the decoder is drained, `queue->DecodeEnd` is
// set and `queue->finish()` is called. Any send/receive error ends the thread
// through the same epilogue.
//
// @param queue  Frame queue shared with the consumer; must outlive the thread.
static void DecoderThreadFunc(Queue* queue)
{
    int frame_cnt = 0;
    // 0: packets are still being read.
    // 1: av_read_frame failed, the decoder is being drained.
    // 2: the decoder reported AVERROR_EOF, nothing is left to receive.
    int end = 0;
    bool decode_failed = false;

    Decoder decoder;
    decoder.DecoderInit();  // NOTE(review): result is not checked here - confirm DecoderInit reports failures itself.

    while (!decode_failed)
    {
        if (av_read_frame(decoder.fmt_ctx, decoder.pkt) < 0)
        {
            if (end == 2)
            {
                break;
            }
            end = 1;
        }

        // Once reading has failed the packet carries no valid stream index, so
        // the drain step must not depend on it; otherwise the decoder is never
        // flushed when the video stream index differs from the blank packet's.
        if (end || decoder.pkt->stream_index == decoder.video_stream_idx)
        {
            // A null packet puts the decoder into draining mode.
            int ret = avcodec_send_packet(decoder.video_dec_ctx, end ? nullptr : decoder.pkt);
            if (ret < 0 && ret != AVERROR_EOF)
            {
                fprintf(stderr, "Error submitting a packet for decoding\n");
                decode_failed = true;
            }

            while (!decode_failed && (ret >= 0 || end == 1))
            {
                ret = avcodec_receive_frame(decoder.video_dec_ctx, decoder.frame);
                if (ret == AVERROR(EAGAIN))
                {
                    break;  // decoder needs more input
                }
                if (ret == AVERROR_EOF)
                {
                    end = 2;
                    break;
                }
                if (ret < 0)
                {
                    av_log(NULL, AV_LOG_ERROR, "Error while receiving a frame from the decoder\n");
                    decode_failed = true;
                    break;
                }
                decoder.frame->pts = decoder.frame->best_effort_timestamp;
                frame_cnt++;

                // Pack the three planes into one contiguous I420 buffer
                // (Y, then U, then V) as expected by COLOR_YUV420p2RGB.
                // Rows are copied one by one because the decoder may pad each
                // row (linesize can be larger than the visible width).
                // NOTE(review): assumes a planar 4:2:0 8-bit frame with even
                // width and height - confirm against the decoder's pix_fmt.
                const int width = decoder.frame->width;
                const int height = decoder.frame->height;
                const int chroma_width = width / 2;
                const int chroma_height = height / 2;

                cv::Mat srcImage = cv::Mat::zeros(height * 3 / 2, width, CV_8UC1);
                unsigned char* y_dst = srcImage.data;
                unsigned char* u_dst = y_dst + width * height;
                unsigned char* v_dst = u_dst + chroma_width * chroma_height;
                CopyPlaneRows(y_dst, decoder.frame->data[0], decoder.frame->linesize[0], width, height);
                CopyPlaneRows(u_dst, decoder.frame->data[1], decoder.frame->linesize[1], chroma_width, chroma_height);
                CopyPlaneRows(v_dst, decoder.frame->data[2], decoder.frame->linesize[2], chroma_width, chroma_height);

                cv::cvtColor(srcImage, srcImage, cv::COLOR_YUV420p2RGB);
                queue->enQueue(srcImage);
                av_frame_unref(decoder.frame);
            }
        }
        av_packet_unref(decoder.pkt);
    }

    // Every exit path (normal end of stream or error) signals the consumer.
    queue->DecodeEnd = true;
    fprintf(stdout, "Decoder: ####### frame count: %d\n", frame_cnt);
    queue->finish();
}

// Converts interleaved 8-bit pixel data into planar float data scaled by 1/255.
//
// Reads 416*416 pixels of 3 interleaved bytes from `frame.data` and writes
// 3 planes of 416*416 floats to `img_data` (plane c holds byte c of each pixel).
//
// @param img_data  Destination buffer with room for at least 3*416*416 floats.
// @param frame     Source image; must be non-empty. The code indexes `data`
//                  linearly, so it presumably expects a continuous 416x416
//                  3-channel 8-bit Mat - verify against the caller.
void Mat_to_CHW(float *img_data, cv::Mat &frame)
{
    assert(img_data && !frame.empty());

    const unsigned int plane_size = 416 * 416;
    const unsigned char* pixels = frame.data;
    float* dst = img_data;

    // Planes are written back to back, so a running output pointer visits
    // exactly the positions channel * plane_size + pixel.
    for (int channel = 0; channel < 3; ++channel)
    {
        for (unsigned int px = 0; px < plane_size; ++px)
        {
            *dst++ = static_cast<float>(float(pixels[px * 3 + channel]) / 255.0);
        }
    }
}

// One detection: corner coordinates (x1, y1) - (x2, y2), its score and the
// class label index.
struct BoxInfo
{
    float x1;
    float y1;
    float x2;
    float y2;
    float score;
    int label;
};

// Per-class non-maximum suppression, performed in place.
//
// Boxes are sorted by descending score. A box is dropped when a
// higher-scoring box with the same label overlaps it with IoU >= 0.45.
// On return `input_boxes` holds only the surviving boxes, still ordered by
// descending score.
//
// @param input_boxes  Candidate boxes; replaced by the kept boxes.
void nms(std::vector<BoxInfo>& input_boxes)
{
    const float nmsThreshold = 0.45f;

    std::sort(input_boxes.begin(), input_boxes.end(),
              [](const BoxInfo& a, const BoxInfo& b) { return a.score > b.score; });

    const std::size_t count = input_boxes.size();

    // The +1 treats the corner coordinates as inclusive pixel indices.
    std::vector<float> areas(count);
    for (std::size_t i = 0; i < count; ++i)
    {
        const BoxInfo& box = input_boxes[i];
        areas[i] = (box.x2 - box.x1 + 1) * (box.y2 - box.y1 + 1);
    }

    // Survivors are collected explicitly instead of erasing through a
    // predicate that counts its own invocations; the standard algorithms do
    // not promise the call order such a predicate would rely on.
    std::vector<char> suppressed(count, 0);
    std::vector<BoxInfo> kept;
    kept.reserve(count);

    for (std::size_t i = 0; i < count; ++i)
    {
        if (suppressed[i]) { continue; }
        const BoxInfo& keeper = input_boxes[i];
        kept.push_back(keeper);

        for (std::size_t j = i + 1; j < count; ++j)
        {
            if (suppressed[j]) { continue; }
            const BoxInfo& candidate = input_boxes[j];
            // Boxes of different classes never suppress each other.
            if (candidate.label != keeper.label) { continue; }

            const float xx1 = std::max(keeper.x1, candidate.x1);
            const float yy1 = std::max(keeper.y1, candidate.y1);
            const float xx2 = std::min(keeper.x2, candidate.x2);
            const float yy2 = std::min(keeper.y2, candidate.y2);
            const float w = std::max(0.0f, xx2 - xx1 + 1);
            const float h = std::max(0.0f, yy2 - yy1 + 1);
            const float inter = w * h;

            const float ovr = inter / (areas[i] + areas[j] - inter);
            if (ovr >= nmsThreshold)
            {
                suppressed[j] = 1;
            }
        }
    }

    input_boxes.swap(kept);
}

// Runs the compiled detector on one image and appends the detections that
// survive thresholding and NMS to `generate_boxes`.
//
// Steps: load the compiled module and create its graph executor, resize the
// image to 416x416 and convert it to planar float data, upload it as input
// "images", run, download output 0 (1 x 2535 x 85 floats), decode each
// proposal (cx, cy, w, h, objectness, 80 class scores), scale the boxes back
// to the source image size and apply nms().
//
// If a tensor cannot be allocated or copied, the TVM error is printed to
// stderr and the function returns without adding boxes.
//
// NOTE(review): the module is loaded from disk and the executor re-created on
// every call; hoisting that out of the per-frame path looks worthwhile but
// needs a decision on object lifetime, so it is left unchanged here.
//
// @param srcImage        Source image; only read (resized into a local copy).
// @param generate_boxes  Output vector; detections are appended, then the
//                        whole vector is passed through nms().
void DeployGraphExecutor(cv::Mat &srcImage, std::vector<BoxInfo> &generate_boxes)
{
    const int kInputSize = 416;      // network input width and height
    const int kNumProposals = 2535;  // rows of the output tensor
    const int kNumAttrs = 85;        // 4 box values + objectness + 80 class scores
    const int kNumClasses = 80;

    // Releases a TVM-allocated tensor on every exit path, including exceptions
    // thrown by the packed functions.
    struct TensorGuard
    {
        DLTensor* handle;
        TensorGuard() : handle(nullptr) {}
        ~TensorGuard()
        {
            if (handle != nullptr)
            {
                TVMArrayFree(handle);
            }
        }
    };

    // load in the library
    DLDevice dev{kDLROCM, 0};
    tvm::runtime::Module mod_factory = tvm::runtime::Module::LoadFromFile("lib/yolov3t_miopen_rocblas.so");

    // create the graph executor module
    tvm::runtime::Module gmod = mod_factory.GetFunction("default")(dev);

    tvm::runtime::PackedFunc set_input = gmod.GetFunction("set_input");
    tvm::runtime::PackedFunc get_output = gmod.GetFunction("get_output");
    tvm::runtime::PackedFunc run = gmod.GetFunction("run");

    // Host-side input: resized image as 3 planes of 416*416 floats. Kept on the
    // heap because the buffer is about 2 MB.
    cv::Mat in_put;
    cv::resize(srcImage, in_put, cv::Size(kInputSize, kInputSize));
    std::vector<float> host_input(3 * kInputSize * kInputSize);
    Mat_to_CHW(host_input.data(), in_put);

    const int dtype_code = kDLFloat;
    const int dtype_bits = 32;
    const int dtype_lanes = 1;
    const int device_type = kDLROCM;
    const int device_id = 0;

    TensorGuard output;
    int64_t out_shape[3] = {1, kNumProposals, kNumAttrs};
    if (TVMArrayAlloc(out_shape, 3, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &output.handle) != 0)
    {
        fprintf(stderr, "TVMArrayAlloc (output) failed: %s\n", TVMGetLastError());
        return;
    }

    TensorGuard input;
    int64_t in_shape[4] = {1, 3, kInputSize, kInputSize};
    if (TVMArrayAlloc(in_shape, 4, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &input.handle) != 0)
    {
        fprintf(stderr, "TVMArrayAlloc (input) failed: %s\n", TVMGetLastError());
        return;
    }

    // The tensor lives on the ROCm device, so the data must go through the
    // runtime's copy routine rather than a host memcpy into `data`.
    if (TVMArrayCopyFromBytes(input.handle, host_input.data(), host_input.size() * sizeof(float)) != 0)
    {
        fprintf(stderr, "TVMArrayCopyFromBytes failed: %s\n", TVMGetLastError());
        return;
    }

    // set the right input
    set_input("images", input.handle);
    // run the code
    run();
    // get the output
    get_output(0, output.handle);

    std::vector<float> host_output(static_cast<size_t>(kNumProposals) * kNumAttrs);
    if (TVMArrayCopyToBytes(output.handle, host_output.data(), host_output.size() * sizeof(float)) != 0)
    {
        fprintf(stderr, "TVMArrayCopyToBytes failed: %s\n", TVMGetLastError());
        return;
    }

    const float* pdata = host_output.data();
    const float ratioh = (float)srcImage.rows / kInputSize;
    const float ratiow = (float)srcImage.cols / kInputSize;
    const float objThreshold = 0.5f;
    const float confThreshold = 0.5f;

    for (int i = 0; i < kNumProposals; i++)
    {
        const int index = i * kNumAttrs;
        const float obj_conf = pdata[index + 4];  // objectness (confidence) score
        if (obj_conf > objThreshold)
        {
            // Pick the class with the highest score.
            int class_idx = 0;
            float max_class_score = 0;
            for (int k = 0; k < kNumClasses; ++k)
            {
                if (pdata[k + index + 5] > max_class_score)
                {
                    max_class_score = pdata[k + index + 5];
                    class_idx = k;
                }
            }
            if (max_class_score > confThreshold)
            {
                const float cx = pdata[index];
                const float cy = pdata[index + 1];
                const float w = pdata[index + 2];
                const float h = pdata[index + 3];
                // Centre/size to corners, scaled back to the source image size.
                float xmin = (cx - 0.5 * w) * ratiow;
                float ymin = (cy - 0.5 * h) * ratioh;
                float xmax = (cx + 0.5 * w) * ratiow;
                float ymax = (cy + 0.5 * h) * ratioh;
                generate_boxes.push_back(BoxInfo{ xmin, ymin, xmax, ymax, max_class_score, class_idx });
            }
        }
    }
    nms(generate_boxes);
}

int main()
{
    Queue* queue = new Queue(1);
    std::thread ThreadDecoder(DecoderThreadFunc, queue);

    int frame_cnt = 0;
    double start_time = getTickCount();
    while (!queue->DecodeEnd)
    {
        cv::Mat srcImage;
        queue->deQueue(&srcImage);
	if (srcImage.empty()) {
	    continue;
	}

        std::vector<BoxInfo> generate_boxes;
        double time1 = getTickCount();
        DeployGraphExecutor(srcImage, generate_boxes);
        double time2 = getTickCount();
	double elapsedTime = (time2 - time1)*1000 / getTickFrequency();
	fprintf(stdout, "inference time:%f ms\n", elapsedTime);
	frame_cnt++;

        // postprocess
        fprintf(stdout,"////////////////Detection Results////////////////\n");
        for(size_t i=0;i<generate_boxes.size();i++)
        {
            BoxInfo result = generate_boxes[i];
            rectangle(srcImage, Point(result.x1, result.y1), Point(int(result.x2), int(result.y2)), Scalar(0, 0, 255), 2);
            string score = format("score:%.2f", result.score);
            string label = format("label:%d", result.label);
            putText(srcImage, score, Point(result.x1, result.y1 - 5), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 255, 0), 1);
            putText(srcImage, label, Point(result.x1 + 130, result.y1 - 5), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 255, 0), 1);

            fprintf(stdout, "detector result[%d] box:%.1f %.1f %.1f %.1f, label:%d, confidence:%.2f\n", i,
            result.x1, result.x2, result.y1, result.y2, result.label, result.score);
        }
        // save result
        /*char out[20];
        snprintf(out, sizeof(out), "out/Frame%d.jpg", frame_cnt);
        imwrite(out, srcImage);*/
    }
    double end_time = getTickCount();
    fprintf(stdout, "Finish ####### frame_cnt: %d, Inference fps: %.2f, all time: %.2f ms\n", frame_cnt, float(frame_cnt/((end_time - start_time)/getTickFrequency())), (end_time - start_time)/getTickFrequency()*1000);

    ThreadDecoder.join();
    delete queue;
    queue = NULL;
    return 0;
}
