import yaml
import json
import argparse
import re
import os
import sys
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from utils.general import non_max_suppression, scale_coords  # tag > 2.0
from utils.datasets import create_dataloader

def coco80_to_coco91_class():
    """Map contiguous 80-class indices (val2014/val2017) to 91-id (paper) category ids.

    Returns:
        A new list of 80 ints whose k-th entry is the COCO ``category_id``
        of contiguous class k.
    """
    # Ids in 1..90 that have no corresponding class in the 80-class numbering.
    unused_ids = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
    return [cat_id for cat_id in range(1, 91) if cat_id not in unused_ids]
    
def correct_bbox(result, anchors, stride, cls_num, out):
    """Decode one raw detection-head output and append it to *out*.

    Args:
        result: raw head output, converted with ``torch.tensor`` and unpacked
            as a 5-D shape (bs, _, ny, nx, _) — presumably (bs, anchors, ny,
            nx, cls_num + 5); TODO confirm against the exported model.
        anchors: anchor sizes, forwarded to ``make_grid``.
        stride: scale factor applied to the decoded xy offsets.
        cls_num: number of classes; each decoded row has ``cls_num + 5`` values.
        out: list that receives the decoded tensor, viewed as (bs, -1, cls_num + 5).

    NOTE(review): ``make_grid`` is not defined or imported in the visible
    source; this raises NameError unless it is provided elsewhere.
    """
    head = torch.tensor(result)
    batch, _, rows, cols, _ = head.shape
    grid, anchor_grid = make_grid(anchors, cols, rows)

    pred = head.float().sigmoid()
    # xy and wh occupy disjoint slices, so both can be computed before writing back.
    centre = pred[..., 0:2] * 2. - 0.5 + grid
    size = (pred[..., 2:4] * 2) ** 2
    pred[..., 0:2] = centre * stride
    pred[..., 2:4] = size * anchor_grid

    out.append(pred.view(batch, -1, cls_num + 5))

def xyxy2xywh(x):
    """Convert (n, 4) boxes from corner form [x1, y1, x2, y2] to centre form [cx, cy, w, h].

    (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner.
    Accepts a torch tensor or a numpy array. The input is left untouched, and
    a converted copy of the same kind is returned.
    """
    boxes = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top, right, bottom = (x[:, k] for k in range(4))
    boxes[:, 0] = (left + right) / 2   # centre x
    boxes[:, 1] = (top + bottom) / 2   # centre y
    boxes[:, 2] = right - left         # width
    boxes[:, 3] = bottom - top         # height
    return boxes


def save_coco_json(predn, pred_dict, image_id, class_map):
    """Append one COCO result record per detection row of *predn* to the list *pred_dict*.

    Each record looks like
    {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}.

    Args:
        predn: detections; columns 0-3 are an [x1, y1, x2, y2] box, column 4
            the score and column 5 the 0-based class index.
        pred_dict: list that the records are appended to (despite its name).
        image_id: value stored as ``image_id`` in every record.
        class_map: sequence mapping a class index to a COCO ``category_id``.
    """
    corner_boxes = xyxy2xywh(predn[:, :4])
    # Shift centre -> top-left so each box becomes [x_min, y_min, w, h].
    corner_boxes[:, :2] -= corner_boxes[:, 2:] / 2
    for row, bbox in zip(predn.tolist(), corner_boxes.tolist()):
        label = class_map[int(row[5])]
        record = {
            'image_id': image_id,
            'category_id': label,
            'bbox': [round(coord, 3) for coord in bbox],
            'score': round(row[4], 5),
        }
        pred_dict.append(record)

def evaluate(cocoGt_file, cocoDt_file):
    """Run COCO bounding-box evaluation and print the standard summary table.

    Args:
        cocoGt_file: path to the ground-truth annotation JSON.
        cocoDt_file: path to a COCO-format detection results JSON.

    Returns:
        ``cocoEval.stats`` (the array filled in by ``COCOeval.summarize``).
        Returning it lets callers use the metrics programmatically; callers
        that ignore the return value behave exactly as before.
    """
    cocoGt = COCO(cocoGt_file)
    cocoDt = cocoGt.loadRes(cocoDt_file)
    cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()
    return cocoEval.stats


def _find_output_extension(output_dir):
    """Return 'npy' or 'bin', the extension of the inference result files in *output_dir*.

    Raises:
        ValueError: if the directory contains no ``.npy`` or ``.bin`` file.
    """
    # splitext is robust to dotless and multi-dot names, unlike split(".")[1].
    for name in sorted(os.listdir(output_dir)):
        ext = os.path.splitext(name)[1][1:]
        if ext in ("npy", "bin"):
            return ext
    raise ValueError(f"The file extension is wrong: no .npy or .bin files in {output_dir}")


def _load_batch_output(output_dir, index, ext):
    """Load the raw output of batch *index* (file ``{index}_0.{ext}``) as a numpy array."""
    out_filepath = f"{output_dir}/{index}_0.{ext}"
    if ext == "npy":
        return np.load(out_filepath)
    # Raw .bin dumps carry no dtype header; float32 is assumed here
    # (a model with fp16 outputs would need dtype=np.float16 instead).
    return np.fromfile(out_filepath, dtype=np.float32)


def postprocess(opt, cfg):
    """Decode offline inference outputs, write COCO-JSON predictions and evaluate mAP.

    For batch ``i`` of the val2017 dataloader, the function reads
    ``{opt.output}/{i}_0.npy`` (or ``.bin``) and reshapes it to (batch, -1, 85).
    It then applies NMS, rescales the boxes from the network input size to each
    original image, and collects the detections in COCO result format. The
    results are written to ``yolov5m_predictions.json`` and evaluated against
    ``opt.ground_truth_json``.

    Args:
        opt: parsed CLI namespace (``data_path``, ``img_size``, ``batch_size``,
            ``output``, ``ground_truth_json``).
        cfg: model config dict providing ``stride``, ``conf_thres`` and ``iou_thres``.

    Raises:
        ValueError: if ``opt.output`` contains no ``.npy``/``.bin`` result files.
    """
    # Determine the result file format once, before building the dataloader.
    ext = _find_output_extension(opt.output)
    class_map = coco80_to_coco91_class()  # constant; no need to rebuild it per image
    pred_results = []

    # load dataset
    single_cls = False
    dataloader = create_dataloader(f"{opt.data_path}/val2017.txt", opt.img_size, opt.batch_size,
                                   max(cfg["stride"]), single_cls, pad=0.5)[0]

    for i, (img, _, paths, shapes) in enumerate(tqdm(dataloader)):
        data = _load_batch_output(opt.output, i, ext)
        # 85 = cls_num + 5 values per candidate row for the 80 COCO classes.
        box_out = torch.tensor(data.reshape(len(shapes), -1, 85))

        # non_max_suppression
        boxout = nms(box_out, conf_thres=cfg["conf_thres"], iou_thres=cfg["iou_thres"])

        for idx, pred in enumerate(boxout):
            # Images without detections contribute nothing. There is no longer a
            # fake zero box, and real scale_coords errors are no longer swallowed.
            if len(pred):
                # Use the actual network input size instead of a hard-coded 640.
                # Assumes img is (N, C, H, W) and shapes[idx][0] is the original
                # (h0, w0) as produced by create_dataloader.
                scale_coords(img.shape[2:], pred[:, :4], shapes[idx][0])
            # append to COCO-JSON dictionary
            path = Path(paths[idx])
            image_id = int(path.stem) if path.stem.isnumeric() else path.stem
            save_coco_json(pred, pred_results, image_id, class_map)

    pred_json_file = "yolov5m_predictions.json"
    with open(pred_json_file, 'w') as f:
        json.dump(pred_results, f)
    print(f"saving results to {pred_json_file}")

    if not pred_results:
        # pycocotools' COCO.loadRes inspects the first result, so an empty list cannot be evaluated.
        print("No detections were produced; skipping COCO evaluation")
        return

    # evaluate mAP
    evaluate(opt.ground_truth_json, pred_json_file)


def nms(box_out, conf_thres=0.4, iou_thres=0.5):
    """Run ``non_max_suppression`` on *box_out*, requesting multi-label output when supported.

    Args:
        box_out: raw candidate predictions, passed straight to ``non_max_suppression``.
        conf_thres: confidence threshold forwarded to NMS.
        iou_thres: IoU threshold forwarded to NMS.

    Returns:
        Whatever ``non_max_suppression`` returns. ``postprocess`` iterates it
        as one detection tensor per image.
    """
    try:
        return non_max_suppression(box_out, conf_thres=conf_thres, iou_thres=iou_thres, multi_label=True)
    except TypeError:
        # Fallback for non_max_suppression variants that do not accept the
        # ``multi_label`` keyword. Only TypeError is caught; the previous bare
        # ``except`` also hid genuine failures (and even KeyboardInterrupt).
        return non_max_suppression(box_out, conf_thres=conf_thres, iou_thres=iou_thres)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='YOLOv5 offline model inference.')
    parser.add_argument('--ground_truth_json', type=str, default="coco/instances_val2017.json",
                        help='annotation file path')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    parser.add_argument('--cfg_file', type=str, default='model.yaml', help='model parameters config file')
    # Required: without it opt.output would be None and the result-file lookup meaningless.
    parser.add_argument('--output', type=str, required=True, help='ais_bench inference output path')
    parser.add_argument('--img_info', type=str, default="img_info.json",
                        help='image info json file (not read by this script)')
    parser.add_argument('--data_path', type=str, default="coco", help='root dir for val images and annotations')
    parser.add_argument('--img_size', type=int, default=640, help='inference size (pixels)')

    opt = parser.parse_args()

    # The config is plain data; safe_load avoids constructing arbitrary Python objects.
    with open(opt.cfg_file, encoding='utf-8') as f:
        cfg = yaml.safe_load(f)

    postprocess(opt, cfg)
