# -*- coding: utf-8 -*-
import cv2
import numpy as np
import os
import argparse
import time
import migraphx


class YOLOv5:
    """YOLOv5 object detector running an ONNX model through MIGraphX.

    The model is parsed and compiled for the GPU once in ``__init__``;
    ``detect`` then runs pre-processing, inference and post-processing
    (confidence filtering + non-maximum suppression) on a single BGR image.
    """

    def __init__(self, path, dynamic=False, obj_thres=0.5, conf_thres=0.25, iou_thres=0.5,
                 names_path='../Resource/Models/coco.names'):
        """Load class names, parse the ONNX model and compile it for the GPU.

        Args:
            path: Path to the ONNX model file.
            dynamic: If True, the model is parsed with a maximum shape of
                [1, 3, 800, 800] for its "images" input and the actual input
                shape is passed to each ``detect`` call. If False, the input
                size is read from the model.
            obj_thres: Minimum objectness score for an anchor to be kept.
            conf_thres: Minimum (objectness * class score) for a detection.
            iou_thres: IoU threshold used by non-maximum suppression.
            names_path: Text file with one class name per line.
        """
        self.objectThreshold = obj_thres
        self.confThreshold = conf_thres
        self.nmsThreshold = iou_thres
        self.isDynamic = dynamic

        # Class names, one per line; the context manager closes the file.
        with open(names_path, 'r') as names_file:
            self.classNames = [line.strip() for line in names_file]

        # Parse the inference model.
        if self.isDynamic:
            maxInput = {"images": [1, 3, 800, 800]}
            self.model = migraphx.parse_onnx(path, map_input_dims=maxInput)
        else:
            self.model = migraphx.parse_onnx(path)

        self.inputName = self.model.get_parameter_names()[0]
        inputShape = self.model.get_parameter_shapes()[self.inputName].lens()
        if self.isDynamic:
            print("inputName:{0} \ninputMaxShape:{1}".format(self.inputName, inputShape))
        else:
            print("inputName:{0} \ninputShape:{1}".format(self.inputName, inputShape))
            # Static model: the inference size is fixed by the model (NCHW).
            self.inputWidth = inputShape[3]
            self.inputHeight = inputShape[2]

        # Compile the model; device_id selects the GPU (device 0 by default).
        self.model.compile(t=migraphx.get_target("gpu"), device_id=0)
        print("Success to compile")

    def detect(self, image, input_shape=None):
        """Detect objects in a BGR image.

        Args:
            image: BGR image, e.g. as returned by ``cv2.imread(path, 1)``.
            input_shape: NCHW shape to run the model at. Required for a
                dynamic model; ignored for a static one.

        Returns:
            ``(boxes, scores, class_ids)``; each box is ``[x, y, w, h]`` with
            ``(x, y)`` the top-left corner in original-image pixels.

        Raises:
            ValueError: If the model is dynamic and ``input_shape`` is None.
        """
        if self.isDynamic:
            if input_shape is None:
                raise ValueError("input_shape is required for a dynamic model")
            self.inputWidth = input_shape[3]
            self.inputHeight = input_shape[2]

        # Pre-process the input image.
        input_img = self.prepare_input(image)

        # Run inference.
        start = time.time()
        result = self.model.run({self.inputName: input_img})
        print('net forward time: {:.4f}'.format(time.time() - start))

        # Post-process the model output.
        return self.process_output(result)

    def prepare_input(self, image):
        """Convert a BGR image into a 1x3xHxW float32 tensor scaled by 1/255.

        Also records the original image size, which ``rescale_boxes`` uses
        to map boxes back to image coordinates.
        """
        self.img_height, self.img_width = image.shape[:2]
        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        input_img = cv2.resize(input_img, (self.inputWidth, self.inputHeight))
        # HWC -> CHW, then add the batch dimension.
        input_img = input_img.transpose(2, 0, 1)
        input_img = np.expand_dims(input_img, 0)
        input_img = np.ascontiguousarray(input_img)
        # For 8-bit images this maps pixel values into [0, 1].
        input_img = input_img.astype(np.float32)
        input_img = input_img / 255

        return input_img

    def process_output(self, output):
        """Filter raw predictions by confidence and apply NMS.

        Each prediction row is indexed as ``[cx, cy, w, h, objectness,
        class scores...]`` with box values in input-tensor pixels.

        Returns:
            ``(boxes, scores, class_ids)`` of the detections kept by NMS;
            empty arrays when nothing passes the thresholds.
        """
        predictions = np.asarray(output[0])
        # Flatten to (num_anchors, row_length). Unlike np.squeeze, this keeps
        # the 2-D layout even when there is a single anchor row.
        predictions = predictions.reshape(-1, predictions.shape[-1])

        # Keep anchors whose objectness exceeds the object threshold.
        obj_conf = predictions[:, 4]
        has_object = obj_conf > self.objectThreshold
        predictions = predictions[has_object]
        obj_conf = obj_conf[has_object]

        # Class confidence = objectness * class score; keep anchors whose best
        # class confidence exceeds the confidence threshold.
        predictions[:, 5:] *= obj_conf[:, np.newaxis]
        scores = np.max(predictions[:, 5:], axis=1)
        valid_scores = scores > self.confThreshold
        predictions = predictions[valid_scores]
        scores = scores[valid_scores]

        # Class ID with the highest confidence for each remaining anchor.
        class_ids = np.argmax(predictions[:, 5:], axis=1)

        # Boxes in original-image pixels, [x, y, w, h] with top-left (x, y).
        boxes = self.extract_boxes(predictions)

        # Nothing to suppress; NMSBoxes would return an empty tuple here.
        if len(scores) == 0:
            return boxes, scores, class_ids

        # Non-maximum suppression removes redundant overlapping boxes.
        indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), self.confThreshold, self.nmsThreshold)
        # NMSBoxes may return a tuple, an Nx1 array or a flat array depending
        # on the OpenCV version; normalise to a flat integer index array.
        indices = np.asarray(indices, dtype=np.int64).reshape(-1)

        return boxes[indices], scores[indices], class_ids[indices]

    def extract_boxes(self, predictions):
        """Return ``[x, y, w, h]`` boxes (top-left corner) in image pixels."""
        boxes = self.rescale_boxes(predictions[:, :4])
        # Convert the centre point to the top-left corner; w and h unchanged.
        boxes_ = np.copy(boxes)
        boxes_[..., 0] = boxes[..., 0] - boxes[..., 2] * 0.5
        boxes_[..., 1] = boxes[..., 1] - boxes[..., 3] * 0.5
        return boxes_

    def rescale_boxes(self, boxes):
        """Scale boxes from input-tensor pixels to original-image pixels."""
        input_shape = np.array([self.inputWidth, self.inputHeight, self.inputWidth, self.inputHeight])
        boxes = np.divide(boxes, input_shape, dtype=np.float32)
        boxes *= np.array([self.img_width, self.img_height, self.img_width, self.img_height])
        return boxes

    def draw_detections(self, image, boxes, scores, class_ids):
        """Draw each box and its ``"<class> <score>"`` label onto ``image``.

        ``image`` is modified in place and also returned.
        """
        for box, score, class_id in zip(boxes, scores, class_ids):
            # Boxes are [x, y, w, h] with (x, y) the top-left corner.
            x, y, w, h = box.astype(int)

            cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 255), thickness=2)
            label = f'{self.classNames[class_id]} {score:.2f}'
            cv2.putText(image, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), thickness=2)
        return image
    
def read_images(image_path):
    """Read every image file in the folder ``image_path``.

    Files are visited in sorted name order so the result is reproducible
    (``os.listdir`` returns entries in arbitrary order). Entries that are not
    regular files, or that ``cv2.imread`` cannot decode (returns None), are
    skipped.

    Returns:
        A list of images as loaded by ``cv2.imread(path, 1)``.
    """
    image_lists = []

    for image_name in sorted(os.listdir(image_path)):
        full_path = os.path.join(image_path, image_name)
        if not os.path.isfile(full_path):
            continue
        image = cv2.imread(full_path, 1)
        if image is None:
            print("Skip unreadable image: {}".format(full_path))
            continue
        image_lists.append(image)

    return image_lists

def yolov5_Static(imgpath, modelpath, objectThreshold, confThreshold, nmsThreshold):
    """Run the static-shape model on one image and save the result as ./Result.jpg.

    Raises:
        FileNotFoundError: If ``imgpath`` cannot be read as an image.
    """
    # Read the image first so a bad path fails fast, before the model is compiled.
    srcimg = cv2.imread(imgpath, 1)
    if srcimg is None:
        raise FileNotFoundError("Failed to read image: {}".format(imgpath))

    yolov5_detector = YOLOv5(modelpath, False, obj_thres=objectThreshold, conf_thres=confThreshold,
                             iou_thres=nmsThreshold)

    boxes, scores, class_ids = yolov5_detector.detect(srcimg)

    dstimg = yolov5_detector.draw_detections(srcimg, boxes, scores, class_ids)

    # Save the detection result.
    cv2.imwrite("./Result.jpg", dstimg)
    print("Success to save result")


def yolov5_dynamic(imgpath, modelpath, objectThreshold, confThreshold, nmsThreshold):
    """Run the dynamic-shape model over every image in the folder ``imgpath``.

    Images are inferred at the sizes in ``input_shapes``, taken in turn and
    cycling when there are more images than sizes. Results are saved as
    ``Result{i}.jpg`` in the current directory.
    """
    # NCHW input shapes to run the dynamic model at.
    input_shapes = [
        [1, 3, 416, 416],
        [1, 3, 608, 608],
    ]

    # Read the test images.
    image_lists = read_images(imgpath)

    # Inference.
    yolov5_detector = YOLOv5(modelpath, True, obj_thres=objectThreshold,
                             conf_thres=confThreshold, iou_thres=nmsThreshold)
    for i, image in enumerate(image_lists):
        print("Start to inference image{}".format(i))
        # Cycle through the shapes; indexing by i directly raised IndexError
        # for folders with more images than shapes.
        input_shape = input_shapes[i % len(input_shapes)]
        boxes, scores, class_ids = yolov5_detector.detect(image, input_shape)
        dstimg = yolov5_detector.draw_detections(image, boxes, scores, class_ids)

        # Save the detection result.
        result_name = "Result{}.jpg".format(i)
        cv2.imwrite(result_name, dstimg)

    print("Success to save results")
    
    

if __name__ == '__main__':
    # Command-line entry point: choose static and/or dynamic inference.
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgPath', type=str, default='../Resource/Images/DynamicPics/image1.jpg', help="image path")
    parser.add_argument('--imgFolderPath', type=str, default='../Resource/Images/DynamicPics', help="image folder path")
    parser.add_argument('--staticModelPath', type=str, default='../Resource/Models/yolov5s.onnx', help="static onnx filepath")
    parser.add_argument('--dynamicModelPath', type=str, default='../Resource/Models/yolov5s_Nx3xNxN.onnx', help="dynamic onnx filepath")
    parser.add_argument('--objectThreshold', default=0.5, type=float, help='object confidence threshold')
    parser.add_argument('--confThreshold', default=0.25, type=float, help='class confidence threshold')
    parser.add_argument('--nmsThreshold', default=0.5, type=float, help='nms iou thresh')
    parser.add_argument("--staticInfer", action="store_true", default=False, help="Performing static inference")
    parser.add_argument("--dynamicInfer", action="store_true", default=False, help="Performing dynamic inference")
    args = parser.parse_args()

    # Without a mode flag nothing would run; show usage instead of exiting silently.
    if not (args.staticInfer or args.dynamicInfer):
        parser.print_help()

    # Static inference.
    if args.staticInfer:
        yolov5_Static(args.imgPath, args.staticModelPath, args.objectThreshold, args.confThreshold, args.nmsThreshold)
    # Dynamic inference.
    if args.dynamicInfer:
        yolov5_dynamic(args.imgFolderPath, args.dynamicModelPath, args.objectThreshold, args.confThreshold, args.nmsThreshold)

    












