import os
import sys
import cv2

# Put the project's ``src/lib`` directory (resolved relative to this file) at
# the front of ``sys.path``. This must run before the ``opts_pose`` and
# ``detectors`` imports below, which presumably live in that directory.
path = os.path.dirname(__file__)
CENTERNET_PATH = os.path.join(path, '../src/lib')
sys.path.insert(0, CENTERNET_PATH)

from opts_pose import opts
from detectors.detector_factory import detector_factory

import scipy.io as sio


def test_img(model_path, debug, threshold=0.4):
    """Run the detector on one sample image and save an annotated copy.

    The image ``../test_img/000388.jpg`` (relative to the current working
    directory) is read, passed to the detector, and every detection whose
    score is at least ``threshold`` is drawn as a red box with a score label.
    The annotated image is written to ``./draw_img.jpg``.

    Args:
        model_path: Path of the model checkpoint, forwarded as ``--load_model``.
        debug: Debug mode forwarded to the option parser as ``--debug``.
        threshold: Minimum score a detection needs in order to be drawn.

    Raises:
        FileNotFoundError: If the sample image cannot be read.
    """
    TASK = 'multi_pose'
    input_h, input_w = 800, 800

    # Fail fast, before the model is loaded, if the image is unreadable:
    # cv2.imread signals failure by returning None instead of raising.
    img_path = '../test_img/000388.jpg'
    ori_img = cv2.imread(img_path, -1)
    if ori_img is None:
        raise FileNotFoundError(
            'cannot read image: {}'.format(img_path))

    # Build the argument list directly instead of formatting one string and
    # splitting on spaces: a model path containing a space stays a single
    # argument, and height/width are bound to their matching flags.
    opt = opts().init([
        '--task', TASK,
        '--load_model', str(model_path),
        '--debug', str(debug),
        '--input_h', str(input_h),
        '--input_w', str(input_w),
    ])
    detector = detector_factory[opt.task](opt)

    res = detector.run(ori_img)['results']
    draw_img = ori_img.copy()

    # res[1] holds the detections that are drawn (presumably the single
    # "face" class); each row starts with x1, y1, x2, y2, score.
    for b in res[1]:
        x1, y1, x2, y2, s = b[0], b[1], b[2], b[3], b[4]
        if s < threshold:
            continue
        top_left = (int(x1), int(y1))
        bottom_right = (int(x2), int(y2))
        cv2.rectangle(draw_img, top_left, bottom_right, (0, 0, 255))
        # Only the first three characters of the score are shown (e.g. "0.9").
        label = "Face:" + str(s)[:3]
        cv2.putText(draw_img, label, (int(x1) - 2, int(y1) - 2),
                    0, 0.5, (255, 255, 255), 1)
    cv2.imwrite("./draw_img.jpg", draw_img)
    print("end.")


def test_vedio(model_path, debug, vedio_path=None):
    """Run the detector on a video stream and display the annotated frames.

    Frames are read from ``vedio_path`` (or from capture device 0 when no path
    is given), passed through the detector, and the detector's ``plot_img``
    output is shown in a window titled ``face detect``. The loop stops when
    the stream ends or a frame cannot be read, or when ``q`` is pressed.

    Args:
        model_path: Path of the model checkpoint, forwarded as ``--load_model``.
        debug: Ignored; the function always uses debug mode ``-1`` because it
            displays the image drawn by the detector.
        vedio_path: Optional path of a video file. A falsy value selects
            capture device 0.
    """
    debug = -1  # return the result image with draw
    TASK = 'multi_pose'
    vis_thresh = 0.45
    input_h, input_w = 800, 800
    # Pass the arguments as a list (not a space-split string) so a model path
    # containing spaces stays one argument and height/width match their flags.
    opt = opts().init([
        '--task', TASK,
        '--load_model', str(model_path),
        '--debug', str(debug),
        '--input_h', str(input_h),
        '--input_w', str(input_w),
        '--vis_thresh', str(vis_thresh),
    ])
    detector = detector_factory[opt.task](opt)

    source = vedio_path if vedio_path else 0
    cap = cv2.VideoCapture(source)
    try:
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                # End of stream or read failure: isOpened() can stay true, so
                # without this break the loop would spin forever.
                break
            res = detector.run(frame)
            cv2.imshow('face detect', res['plot_img'])

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture and close the window even if the detector fails.
        cap.release()
        cv2.destroyAllWindows()


def test_wider_Face(model_path, debug, threshold=0.05):
    """Run the detector over the WIDER FACE validation list and dump results.

    The event and file lists are read from
    ``../evaluate/ground_truth/wider_face_val.mat``. For every listed image
    found under ``../datasets/images/val`` a text file is written to
    ``../output/widerface/<event>/<image>.txt`` containing the image name,
    the number of detections, and one ``x y w h score`` line per detection.
    Images that cannot be read are reported and skipped.

    Args:
        model_path: Path of the model checkpoint, forwarded as ``--load_model``.
        debug: Debug mode forwarded to the option parser as ``--debug``.
        threshold: Value forwarded to the option parser as ``--vis_thresh``.
    """
    from progress.bar import Bar
    image_root = '../datasets/images/val'  # WIDER_val/images path
    wider_face_mat = sio.loadmat('../evaluate/ground_truth/wider_face_val.mat')
    event_list = wider_face_mat['event_list']
    file_list = wider_face_mat['file_list']
    print("*** event_list", event_list)

    TASK = 'multi_pose'
    input_h, input_w = 800, 800
    # Pass the arguments as a list (not a space-split string) so a model path
    # containing spaces stays a single argument.
    opt = opts().init([
        '--task', TASK,
        '--load_model', str(model_path),
        '--debug', str(debug),
        '--vis_thresh', str(threshold),
        '--input_h', str(input_h),
        '--input_w', str(input_w),
    ])
    detector = detector_factory[opt.task](opt)

    save_path = '../output/widerface/'
    for index, event in enumerate(event_list):
        file_list_item = file_list[index][0]
        im_file_dir = event[0][0]

        event_dir = save_path + im_file_dir
        os.makedirs(event_dir, exist_ok=True)

        bar = Bar("Testing", max=len(file_list_item))
        for num, file in enumerate(file_list_item):
            im_name = file[0][0]
            # Set the suffix on this bar instance rather than on the Bar class
            # so other progress bars in the process are not affected.
            bar.suffix = 'event:%d num:%d' % (index + 1, num + 1)

            im_zip_name = '{}/{}.jpg'.format(im_file_dir, im_name)
            img_path = os.path.join(image_root, im_zip_name)
            ori_img = cv2.imread(img_path)
            if ori_img is None:
                print("*** img_path {} is empty!".format(img_path))
                bar.next()  # keep the bar in step with the file list
                continue

            # The boxes written below come from entry 1 of the results, so the
            # count line must be the length of that entry, not of the whole
            # results container.
            faces = detector.run(ori_img)['results'][1]
            result_file = event_dir + '/' + im_name + '.txt'
            with open(result_file, 'w') as f:
                f.write('{:s}\n'.format('%s/%s.jpg' % (im_file_dir, im_name)))
                f.write('{:d}\n'.format(len(faces)))
                for b in faces:
                    x1, y1, x2, y2, s = b[0], b[1], b[2], b[3], b[4]
                    # Corners are converted to width/height (inclusive, +1).
                    f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(
                        x1, y1, (x2 - x1 + 1), (y2 - y1 + 1), s))
            bar.next()
        bar.finish()


if __name__ == '__main__':
    '''
    debug = 0 # return the detect result without show
    debug = 1 # draw and show the result image
    debug = -1  # return the result image with draw
    '''
    # Script entry point: choose the checkpoint and debug mode, then run one
    # of the test routines (only the single-image test is enabled).
    model_path = '../models/model_best.pth'  # or your model path
    debug = 0
    # Single-image test
    test_img(model_path=model_path, debug=debug)
    # Video test
    # test_vedio(model_path, debug)
    # WIDER_val dataset test
    # test_wider_Face(model_path, debug)
