import json
import logging
import os
import time

import cv2
import numpy as np
from pycocotools.cocoeval import COCOeval
import pycoco
import dataset

# Configure the root logger at INFO level when this module is imported.
logging.basicConfig(level=logging.INFO)
# Module-wide logger used by the classes below; it is registered as "coco".
log = logging.getLogger("coco")


class OpenImages(dataset.Dataset):
    """OpenImages detection dataset described by a COCO-format annotation file.

    Original images are read from ``<data_path>/data``. Pre-processed images
    are stored as ``<cache_dir>/<file_name>.npy`` and that is the location
    ``get_item`` loads from.
    """

    def __init__(
        self,
        data_path,
        image_list,
        name,
        use_cache=0,
        image_size=None,
        image_format="NHWC",
        pre_process=None,
        count=None,
        cache_dir=None,
        preprocessed_dir=None,
        use_label_map=False,
        threads=os.cpu_count(),
    ):
        """Index the dataset and, if requested, pre-process every image.

        Args:
            data_path: Dataset root; images are expected in its ``data``
                sub-directory.
            image_list: Path of the COCO-format annotation JSON. When ``None``
                it defaults to ``<data_path>/labels/openimages-mlperf.json``.
            name: Dataset name, used to build the default pre-processed
                cache directory.
            use_cache: Stored on the instance and reported in the summary log.
            image_size: Forwarded to ``pre_process`` as ``dims``.
            image_format: ``"NCHW"`` asks ``pre_process`` to transpose; any
                other value does not.
            pre_process: Callable ``(img, need_transpose=..., dims=...)``
                returning the array to cache. When falsy, the ``.npy`` files
                must already exist in ``cache_dir``.
            count: Optional upper bound on the number of images kept.
            cache_dir: Base directory for cached files (defaults to the
                current working directory).
            preprocessed_dir: Overrides the computed cache directory when
                ``pre_process`` is given.
            use_label_map: When true, category ids are remapped to their
                1-based position in the annotation file's ``categories`` list
                and images without any annotation are dropped.
            threads: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If no usable image remains after filtering.
        """
        super().__init__()
        self.image_size = image_size
        self.image_list = []
        self.label_list = []
        self.image_ids = []
        self.image_sizes = []
        self.count = count
        self.use_cache = use_cache
        self.data_path = data_path
        self.pre_process = pre_process
        self.use_label_map = use_label_map

        # Original images always live in the "data" sub-directory.
        self.images_dir = os.path.join(self.data_path, "data")

        if not cache_dir:
            cache_dir = os.getcwd()
        if pre_process:
            if preprocessed_dir:
                self.cache_dir = preprocessed_dir
            else:
                self.cache_dir = os.path.join(
                    cache_dir, "preprocessed", name, image_format
                )
        else:
            self.cache_dir = cache_dir
        self.need_transpose = image_format == "NCHW"
        not_found = 0  # images that are missing or unreadable
        empty_80catageories = 0  # images dropped for having no annotation

        # Default annotation file: <data_path>/labels/openimages-mlperf.json
        if image_list is None:
            image_list = os.path.join(
                self.data_path, "labels", "openimages-mlperf.json"
            )
        self.annotation_file = image_list

        os.makedirs(self.cache_dir, exist_ok=True)
        start = time.time()

        # Parse the annotation file once; it feeds both the label map and
        # the image / bounding-box index.
        with open(self.annotation_file, "r") as f:
            openimages = json.load(f)

        label_map = {}
        if self.use_label_map:
            for cnt, cat in enumerate(openimages["categories"]):
                label_map[cat["id"]] = cnt + 1

        # Per-image metadata keyed by image id.
        images = {}
        for entry in openimages["images"]:
            images[entry["id"]] = {
                "file_name": entry["file_name"],
                "height": entry["height"],
                "width": entry["width"],
                "bbox": [],
                "category": [],
            }

        # Attach every annotation to its image; annotations that point to an
        # unknown image are ignored.
        for ann in openimages["annotations"]:
            record = images.get(ann["image_id"])
            if record is None:
                continue
            category_id = ann.get("category_id")
            if self.use_label_map:
                category_id = label_map[category_id]
            record["category"].append(category_id)
            record["bbox"].append(ann.get("bbox"))

        for image_id, img in images.items():
            image_filename = img["file_name"]
            img_abs_path = os.path.join(self.images_dir, image_filename)

            # Skip entries whose original image is not on disk.
            if not os.path.exists(img_abs_path):
                not_found += 1
                log.warning(f"图片不存在，跳过: {img_abs_path}")
                continue

            # With a label map, images without annotations are dropped.
            if self.use_label_map and not img["category"]:
                empty_80catageories += 1
                continue

            if not self.pre_process:
                # No pre-processing requested: the cached array must already
                # exist at the exact location get_item() loads from.
                npy_file = os.path.join(self.cache_dir, image_filename) + ".npy"
                if not os.path.exists(npy_file):
                    not_found += 1
                    log.warning(f"预处理文件不存在，跳过: {npy_file}")
                    continue
            else:
                src = img_abs_path
                dst = os.path.join(self.cache_dir, image_filename)
                os.makedirs(os.path.dirname(dst), exist_ok=True)

                # np.save appends ".npy" to dst, so test for that name.
                if not os.path.exists(dst + ".npy"):
                    img_org = cv2.imread(src)
                    if img_org is None:
                        not_found += 1
                        log.warning(f"图片无法读取，跳过: {src}")
                        continue
                    processed = self.pre_process(
                        img_org,
                        need_transpose=self.need_transpose,
                        dims=self.image_size,
                    )
                    np.save(dst, processed)

            self.image_ids.append(image_id)
            # Only the file name is stored; paths are joined on access.
            self.image_list.append(image_filename)
            self.image_sizes.append((img["height"], img["width"]))
            self.label_list.append((img["category"], img["bbox"]))

            # Stop early once the requested number of images is reached.
            if self.count and len(self.image_list) >= self.count:
                break

        time_taken = time.time() - start

        if not self.image_list:
            log.error("未找到任何有效图片，请检查以下路径：")
            log.error(f"图片目录: {self.images_dir}")
            log.error(f"标注文件: {self.annotation_file}")
            raise ValueError("no valid images found in image list")

        if not_found > 0:
            log.info(f"已跳过 {not_found} 张缺失或无法读取的图片")
        if empty_80catageories > 0:
            log.info(f"已过滤 {empty_80catageories} 张不包含目标类别的图片")

        log.info(
            f"成功加载 {len(self.image_list)} 张图片，缓存={use_cache}，预处理完成={pre_process is None}，耗时={time_taken:.1f}秒"
        )

        self.label_list = np.array(self.label_list, dtype=list)

    def get_item(self, nr):
        """Return ``(image, label)`` for list index ``nr``.

        The image is loaded from ``<cache_dir>/<file_name>.npy``.
        """
        dst = os.path.join(self.cache_dir, self.image_list[nr])
        img = np.load(dst + ".npy")
        return img, self.label_list[nr]

    def get_item_loc(self, nr):
        """Return the path of the original image for list index ``nr``."""
        return os.path.join(self.images_dir, self.image_list[nr])


class PostProcessOpenImages:
    """
    Post processing for open images dataset. Annotations should
    be exported into coco format.

    ``__call__`` converts raw model output into per-image detection rows,
    ``add_results`` accumulates those rows, and ``finalize`` evaluates the
    accumulated rows with COCOeval.
    """

    def __init__(self):
        self.results = []
        self.good = 0
        self.total = 0
        self.content_ids = []
        self.use_inv_map = False

    def add_results(self, results):
        """Append the per-image detection lists returned by ``__call__``."""
        self.results.extend(results)

    def __call__(
        self,
        results,
        ids,
        expected=None,
        result_dict=None,
    ):
        """Convert one batch of model output into detection rows.

        Args:
            results: Sequence indexed as
                ``[num_detections, detection_boxes, detection_scores,
                detection_classes]``, each indexed by batch position.
            ids: Content id for every batch position.
            expected: Optional labels per batch position; element ``[0]`` is
                the collection of expected class ids. When ``None`` the
                ``good`` counter is not updated.
            result_dict: Unused here.

        Returns:
            One list per batch position, holding rows of
            ``[id, box0, box1, box2, box3, score, class]``.
        """
        # results come as:
        # tensorflow, ssd-mobilenet:
        # num_detections,detection_boxes,detection_scores,detection_classes
        processed_results = []
        # batch size
        bs = len(results[0])
        for idx in range(bs):
            # keep the content_id from loadgen to handle content_id's without
            # results
            self.content_ids.append(ids[idx])
            detection_num = int(results[0][idx])
            detection_boxes = results[1][idx]
            detection_scores = results[2][idx]
            detection_classes = results[3][idx]
            # Without ground truth nothing can be counted as "good".
            expected_classes = expected[idx][0] if expected is not None else ()
            rows = []
            for detection in range(detection_num):
                detection_class = int(detection_classes[detection])
                if detection_class in expected_classes:
                    self.good += 1
                box = detection_boxes[detection]
                rows.append(
                    [
                        float(ids[idx]),
                        box[0],
                        box[1],
                        box[2],
                        box[3],
                        detection_scores[detection],
                        float(detection_class),
                    ]
                )
                self.total += 1
            processed_results.append(rows)
        return processed_results

    def start(self):
        """Reset all per-run state before a new run."""
        self.results = []
        self.good = 0
        self.total = 0
        # finalize() pairs content_ids[i] with results[i]; ids left over from
        # a previous run would shift that pairing, so they are cleared too.
        self.content_ids = []

    def finalize(self, result_dict, ds=None, output_dir=None):
        """Evaluate the accumulated detections and store the mAP.

        Args:
            result_dict: Updated in place: ``good`` and ``total`` are
                incremented and ``mAP`` is set from ``cocoEval.stats[0]``.
            ds: Dataset providing ``annotation_file``, ``image_ids`` and
                ``image_sizes``.
            output_dir: Unused here.
        """
        result_dict["good"] += self.good
        result_dict["total"] += self.total

        if self.use_inv_map:
            # for pytorch
            label_map = {}
            with open(ds.annotation_file) as fin:
                annotations = json.load(fin)
            for cnt, cat in enumerate(annotations["categories"]):
                label_map[cat["id"]] = cnt + 1
            inv_map = {v: k for k, v in label_map.items()}

        detections = []
        image_indices = []
        for batch in range(0, len(self.results)):
            image_indices.append(self.content_ids[batch])
            for idx in range(0, len(self.results[batch])):
                detection = self.results[batch][idx]
                # this is the index of the coco image
                image_idx = int(detection[0])
                if image_idx != self.content_ids[batch]:
                    log.error(
                        "image_idx missmatch, lg={} / result={}".format(
                            image_idx, self.content_ids[batch]
                        )
                    )
                # map the index to the coco image id
                detection[0] = ds.image_ids[image_idx]
                height, width = ds.image_sizes[image_idx]
                # box comes from model as: ymin, xmin, ymax, xmax
                ymin = detection[1] * height
                xmin = detection[2] * width
                ymax = detection[3] * height
                xmax = detection[4] * width
                # pycoco wants {imageID,x1,y1,w,h,score,class}
                detection[1] = xmin
                detection[2] = ymin
                detection[3] = xmax - xmin
                detection[4] = ymax - ymin
                if self.use_inv_map:
                    cat_id = inv_map.get(int(detection[6]), -1)
                    if cat_id == -1:
                        log.info(
                            "finalize can't map category {}".format(
                                int(detection[6]))
                        )
                    detection[6] = cat_id
                detections.append(np.array(detection))

        # map indices to coco image id's
        image_ids = [ds.image_ids[i] for i in image_indices]
        self.results = []
        cocoGt = pycoco.COCO(ds.annotation_file)
        cocoDt = cocoGt.loadRes(np.array(detections))
        cocoEval = COCOeval(cocoGt, cocoDt, iouType="bbox")
        cocoEval.params.imgIds = image_ids
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        result_dict["mAP"] = cocoEval.stats[0]


class PostProcessOpenImagesRetinanet(PostProcessOpenImages):
    """
    Post processing required by retinanet / pytorch & onnx
    """

    def __init__(self, use_inv_map, score_threshold,
                 height, width, dict_format=True):
        """
        Args:
            use_inv_map (bool): When true, finalize maps predicted class
                indices back to annotation category ids.
            score_threshold (float): Detections scoring below this value
                are discarded.
            height (int): Height of the input image
            width (int): Width of the input image
            dict_format (bool): True if the model outputs a dictionary.
                        False otherwise. Defaults to True.
        """
        super().__init__()
        self.use_inv_map = use_inv_map
        self.score_threshold = score_threshold
        self.height = height
        self.width = width
        self.dict_format = dict_format

    def __call__(self, results, ids, expected=None, result_dict=None):
        """Convert one batch of model output into detection rows.

        Args:
            results: With ``dict_format`` a sequence of dicts holding
                ``boxes``, ``labels`` and ``scores``; otherwise a sequence
                ``[boxes, labels, scores]`` that is treated as one image.
            ids: Content id for every batch position.
            expected: Optional labels per batch position; element ``[0]`` is
                the collection of expected class ids. When ``None`` the
                ``good`` counter is not updated.
            result_dict: Unused here.

        Returns:
            One list per batch position, holding rows of
            ``[id, ymin, xmin, ymax, xmax, score, class]`` with the box
            coordinates divided by the configured height / width.
        """
        if self.dict_format:
            # If the output of the model is in dictionary format. This happens
            # for the model retinanet-pytorch
            batch_boxes = [e["boxes"].cpu() for e in results]
            batch_labels = [e["labels"].cpu() for e in results]
            batch_scores = [e["scores"].cpu() for e in results]
        else:
            # Plain outputs are wrapped so they look like a batch of one.
            batch_boxes = [results[0]]
            batch_labels = [results[1]]
            batch_scores = [results[2]]

        processed_results = []
        content_ids = []
        # batch size
        bs = len(batch_boxes)
        for idx in range(bs):
            content_ids.append(ids[idx])
            detection_boxes = batch_boxes[idx]
            detection_classes = batch_labels[idx]
            # Without ground truth nothing can be counted as "good".
            expected_classes = expected[idx][0] if expected is not None else ()
            scores = batch_scores[idx]
            rows = []
            for detection in range(len(scores)):
                # The first score below the threshold ends the scan, which
                # presumably relies on scores being sorted in descending
                # order - verify against the model output.
                if scores[detection] < self.score_threshold:
                    break
                detection_class = int(detection_classes[detection])
                if detection_class in expected_classes:
                    self.good += 1
                box = detection_boxes[detection]
                # box comes from model as: xmin, ymin, xmax, ymax
                # box comes with dimensions in the range of [0, height]
                # and [0, width] respectively. It is necessary to scale
                # them in the range [0, 1]
                rows.append(
                    [
                        float(ids[idx]),
                        box[1] / self.height,
                        box[0] / self.width,
                        box[3] / self.height,
                        box[2] / self.width,
                        scores[detection],
                        float(detection_class),
                    ]
                )
                self.total += 1
            processed_results.append(rows)
        self.content_ids.extend(content_ids)
        return processed_results
