/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Example code that loads and runs TVM modules.
 * \file cpp_deploy.cc
 */
#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <cstdio>
#include <fstream>
#include </usr/include/opencv2/opencv.hpp>
#include </usr/include/opencv2/highgui/highgui.hpp>  
#include </usr/include/opencv2/imgproc/imgproc.hpp> 
#include <iostream>
#include <typeinfo>
#include <algorithm>
#include<vector>
#include<algorithm>


using namespace cv;

/*!
 * \brief Smoke-test an "add one" kernel exported by a TVM module.
 *
 * Looks up \p fname in \p mod and feeds it a 10-element float32 CPU tensor
 * holding 0, 1, ..., 9. It then ICHECKs that every output element equals the
 * matching input plus one.
 *
 * Both tensors are allocated through the TVM runtime (TVMArrayAlloc). TVM
 * accepts any DLPack-compatible DLTensor, but DLPack carries an alignment
 * requirement on the data pointer that TVM takes advantage of. A custom
 * container passed in as a DLTensor must therefore satisfy the same alignment.
 *
 * \param mod   Module expected to export \p fname.
 * \param fname Name of the PackedFunc to look up and invoke as f(in, out).
 */
void Verify(tvm::runtime::Module mod, std::string fname) {
  tvm::runtime::PackedFunc addone = mod.GetFunction(fname);
  ICHECK(addone != nullptr);

  int64_t extent[1] = {10};
  DLTensor* in_tensor;
  DLTensor* out_tensor;
  TVMArrayAlloc(extent, /*ndim=*/1, kDLFloat, /*dtype_bits=*/32, /*dtype_lanes=*/1, kDLCPU,
                /*device_id=*/0, &in_tensor);
  TVMArrayAlloc(extent, /*ndim=*/1, kDLFloat, /*dtype_bits=*/32, /*dtype_lanes=*/1, kDLCPU,
                /*device_id=*/0, &out_tensor);

  float* src = static_cast<float*>(in_tensor->data);
  for (int k = 0; k < extent[0]; ++k) {
    src[k] = k;
  }

  // A PackedFunc is called with positional arguments; the signature was fixed
  // when the function was built with tvm.build.
  addone(in_tensor, out_tensor);

  const float* dst = static_cast<const float*>(out_tensor->data);
  for (int k = 0; k < extent[0]; ++k) {
    ICHECK_EQ(dst[k], k + 1.0f);
  }
  LOG(INFO) << "Finish verification...";
  TVMArrayFree(in_tensor);
  TVMArrayFree(out_tensor);
}

void DeploySingleOp() {
  // Normally we can directly
  tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile("lib/test_addone_dll.so");
  LOG(INFO) << "Verify dynamic loading from test_addone_dll.so";
  Verify(mod_dylib, "addone");
  // For libraries that are directly packed as system lib and linked together with the app
  // We can directly use GetSystemLib to get the system wide library.
  LOG(INFO) << "Verify load function from system lib";
  tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("runtime.SystemLib"))();
  Verify(mod_syslib, "addonesys");
}

void PreProcess(const Mat& image, Mat& image_blob)
{
	Mat input;
	image.copyTo(input);

	std::vector<Mat> channels, channel_p;
	split(input, channels);
	Mat R, G, B;
	B = channels.at(0);
	G = channels.at(1);
	R = channels.at(2);

	B = (B / 255. - 0.408) / 0.242;
	G = (G / 255. - 0.448) / 0.239;
	R = (R / 255. - 0.471) / 0.234;

	channel_p.push_back(R);
	channel_p.push_back(G);
	channel_p.push_back(B);

	Mat outt;
	merge(channel_p, outt);
	image_blob = outt;
}

void Mat_to_CHW(float *img_data, cv::Mat &frame)
{
    assert(img_data && !frame.empty());
    unsigned int volChl = 640 * 640;

    for(int c = 0; c < 3; ++c)
    {
        for (unsigned j = 0; j < volChl; ++j)
            img_data[c*volChl + j] = static_cast<float>(float(frame.data[j * 3 + c])/255.0);
    }

}


// A single detection in corner form, plus its score and class index.
// The decoder fills x1/y1 with xmin/ymin and x2/y2 with xmax/ymax.
struct BoxInfo
{
	float x1;     // left edge (xmin)
	float y1;     // top edge (ymin)
	float x2;     // right edge (xmax)
	float y2;     // bottom edge (ymax)
	float score;  // confidence used for ranking in nms()
	int label;    // class index; nms() only compares boxes sharing a label
};

void nms(vector<BoxInfo>& input_boxes)
{
        float nmsThreshold = 0.45;
	sort(input_boxes.begin(), input_boxes.end(), [](BoxInfo a, BoxInfo b) { return a.score > b.score; }); 
	vector<float> vArea(input_boxes.size());  
	for (int i = 0; i < input_boxes.size(); ++i)
	{
		vArea[i] = (input_boxes[i].x2 - input_boxes[i].x1 + 1)* (input_boxes[i].y2 - input_boxes[i].y1 + 1);
	}
	
	vector<bool> isSuppressed(input_boxes.size(), false);  
	for (int i = 0; i < input_boxes.size(); ++i)
	{
		if (isSuppressed[i]) { continue; }
		for (int j = i + 1; j < input_boxes.size(); ++j)
		{
			if (isSuppressed[j]) { continue; }
			float xx1 = max(input_boxes[i].x1, input_boxes[j].x1);
			float yy1 = max(input_boxes[i].y1, input_boxes[j].y1);
			float xx2 = min(input_boxes[i].x2, input_boxes[j].x2);
			float yy2 = min(input_boxes[i].y2, input_boxes[j].y2);
			float w = max(0.0f, xx2 - xx1 + 1);
			float h = max(0.0f, yy2 - yy1 + 1);
			float inter = w * h;	
			
			if(input_boxes[i].label == input_boxes[j].label)  
			{
				float ovr = inter / (vArea[i] + vArea[j] - inter);  
				if (ovr >= nmsThreshold)
				{
					isSuppressed[j] = true;
				}
			}	
		}
	}
	int idx_t = 0;   
	input_boxes.erase(remove_if(input_boxes.begin(), input_boxes.end(), [&idx_t, &isSuppressed](const BoxInfo& f) { return isSuppressed[idx_t++]; }), input_boxes.end());
}

void DeployGraphExecutor() {
  LOG(INFO) << "Running graph executor...";
  // load in the libr
  DLDevice dev{kDLROCM, 0};
  tvm::runtime::Module mod_factory = tvm::runtime::Module::LoadFromFile("lib/yolov5s_miopen_rocblas.so");
  
  // create the graph executor module
  using namespace std;
  tvm::runtime::Module gmod = mod_factory.GetFunction("default")(dev);
  cout<<"---------------"<<endl;

  tvm::runtime::PackedFunc set_input = gmod.GetFunction("set_input");
  tvm::runtime::PackedFunc get_output = gmod.GetFunction("get_output");
  tvm::runtime::PackedFunc run = gmod.GetFunction("run");
  cv::Mat image = cv::imread("./cow.jpg");
  //cv::Mat image = cv::imread("./bear.jpg");
  cv::Mat in_put;
  cv::Mat img_in;
  cv::resize(image, in_put, cv::Size(640, 640));
  //cv::cvtColor(frame, in_put, cv::COLOR_BGR2RGB);  

  float img_data[640*640*3];
//  PreProcess(in_put, img_in);
//  Mat_to_CHW(img_data, img_in);
  Mat_to_CHW(img_data, in_put);
  //int input_dtype_code = kDLFloat;
  //int input_dtype_bits = 32;
  //int input_dtype_lanes = 1;
  //DLDataType input_dtype = {input_dtype_code, input_dtype_bits, input_dtype_lanes};
  // Use the C++ API
  //tvm::runtime::NDArray x = tvm::runtime::NDArray::Empty({1, 3, 224, 224}, input_dtype, {kDLROCM, 0});
  //tvm::runtime::NDArray input_data = tvm::runtime::NDArray::Empty({1, 3, 224, 224}, DLDataType{kDLFloat, 32, 1}, dev);
  //tvm::runtime::NDArray y = tvm::runtime::NDArray::Empty({1, 1000}, DLDataType{kDLFloat, 32, 1}, {kDLROCM,0});


  
  DLTensor* y;
  int out_ndim = 3;
  int64_t out_shape[3] = {1, 25200, 85};
  int dtype_code = kDLFloat;
  int dtype_bits = 32;
  int dtype_lanes = 1;
  int device_type = kDLROCM;
  int device_id = 0;
  TVMArrayAlloc(out_shape, out_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &y);
  


  
  DLTensor* x;
  int ndim = 4;
  //int dtype_code = kDLFloat;
  //int dtype_bits = 32;
  //int dtype_lanes = 1;
  //int device_type = kDLROCM;
  //int device_id = 0;
  int64_t shape[4] = {1, 3 ,640, 640};
  TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &x);
  memcpy(x->data,&img_data,3*640*640*sizeof(float));
  

  //TVMArrayCopyFromBytes(x,&img_data,1*sizeof(float));
  //DLTensor* x;
  //TVMArrayAlloc({1, 3, 224, 224}, 4, kDLFloat, 32, 1, kDLROCM, 0, &x);
 /* 
  for (int c = 0; c < 3; ++c) {
    for (int h=0; h<224; ++h){
      for(int w=0; w<224; ++w){
        static_cast<float*>(x->data)[w] = 0;
   }
   }
  }
  */
  // set the right input
  set_input("images", x);
  // run the code
  run();
  //get the output
  
  get_output(0, y);
  /*
  for (int i = 0; i < 1; ++i) {
    for (int j = 0; j < 1000; ++j) {
      ICHECK_EQ(static_cast<float*>(y->data)[i * 1 + j], i * 1 + j + 1);
    }
  }
  */
 
 static float result[25200][85] = {0};
 TVMArrayCopyToBytes(y, result, 25200 * 85 * sizeof(float));
 int num_proposal = sizeof(result)/sizeof(result[0]); //25200
 int box_classes = sizeof(result[0])/sizeof(result[0][0]);//85
 cout<<"num_proposal:"<<num_proposal<<endl;
 cout<<"box_classes："<<box_classes<<endl;
 vector<BoxInfo> generate_boxes;  // BoxInfo自定义的结构体
 float* pdata = result[0];
 //YOLOv5 detect(pdata);
 float ratioh=1,ratiow=1;
 //float ratioh = (float)image.rows / 640, ratiow = (float)image.cols / 640;
 cout<<"ratioh:"<<ratioh<<"\nratiow:"<<ratiow<<endl;
 float objThreshold=0.2, confThreshold=0.6;
 //vector<float> confidences;
 //vector<Rect> boxes;
 //vector<int> classIds;
 float padw=0,padh=0;
 for(int i=0;i<num_proposal;i++)
 {
  int index = i*box_classes;
  float obj_conf = pdata[index+4];  //置信度分数
  //cout<<pdata[i]<<endl;
  //cout<<"obj_conf:"<<obj_conf<<endl;
  //cout<<"+"<<endl;
  if(obj_conf>objThreshold)
  {
    cout<<"obj_conf"<<obj_conf<<endl;
    //Mat scores(1, box_classes-5, CV_32FC1, pdata+index + 5);
    //Point classIdPoint; //定义点
    int class_idx = 0;
    float max_class_socre = 0;
    for (int k = 0; k < 80; ++k)
	{
	  if (pdata[k + index + 5] > max_class_socre)
		{
		   max_class_socre = pdata[k + index + 5];
		   class_idx = k;
		}
	}
    //double max_class_socre; // 定义一个double类型的变量保存预测中类别分数最大值
    //minMaxLoc(scores, 0, &max_class_socre, 0, &classIdPoint);  // 求每类类别分数最大的值和索引
    //cout<<"max_score"<<max_class_socre<<endl;
    //max_class_socre *= obj_conf;
    //cout<<"max_class_socre: "<<max_class_socre<<endl;
    if (max_class_socre > confThreshold){
      //const int class_idx = classIdPoint.x;
      float cx = pdata[index];
      float cy = pdata[index+1];
      float w = pdata[index+2];
      float h = pdata[index+3];
      float xmin = ((cx - padw - 0.5 * w)*ratiow);  // *ratiow，变回原图尺寸
      float ymin = ((cy - padh - 0.5 * h)*ratioh);    
      float xmax = (cx - padw + 0.5 * w)*ratiow;
      float ymax = (cy - padh + 0.5 * h)*ratioh;
      generate_boxes.push_back(BoxInfo{ xmin, ymin, xmax, ymax, max_class_socre, class_idx });
      //confidences.push_back((float)max_class_socre);
      //boxes.push_back(Rect(left, top, (int)(w*ratiow), (int)(h*ratioh)));  //（x,y,w,h）
      //classIds.push_back(class_idx);  //     
      //cout<<"cx:"<<cx<<endl;
      
    }
  }
 }
 //vector<int> indices;
 //float nmsThreshold = 0.1;
 nms(generate_boxes);
 cout<<generate_boxes.size()<<endl;
 for(size_t i=0;i<generate_boxes.size();i++){
 float xmin = generate_boxes[i].x1;
 float xmax = generate_boxes[i].x2;
 float ymin = generate_boxes[i].y1;
 float ymax = generate_boxes[i].y2;
 float score = generate_boxes[i].score;
 int classes = generate_boxes[i].label;
 rectangle(in_put, Point(xmin, ymin), Point(int(generate_boxes[i].x2), int(generate_boxes[i].y2)), Scalar(0, 0, 255), 2);
 string label = format("%.2f", generate_boxes[i].score);
 putText(in_put, label, Point(xmin, ymin - 5), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 255, 0), 1);
 //imwrite("result.jpg",in_put)
 cout<<"xmin:"<<xmin<<endl;
 cout<<"xmax:"<<xmax<<endl;
 cout<<"ymin:"<<ymin<<endl;
 cout<<"ymax:"<<ymax<<endl;
 cout<<"score:"<<score<<endl;
 cout<<"classes:"<<classes<<endl;
 }
 
  cout<<"----------"<<endl;
  imwrite("result.jpg",in_put);

}

// Entry point: runs only the YOLOv5 graph-executor demo.
int main() {
  // DeploySingleOp();  // add-one module demo; re-enable to run it as well
  DeployGraphExecutor();
  // Falling off the end of main returns 0.
}
