import copy

import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
from tqdm import tqdm


class FilterAndRemapCocoCategories(object):
    """Transform that keeps only annotations whose category is whitelisted.

    Args:
        categories: Sequence of category ids to keep. When ``remap`` is true,
            the position of an id in this sequence becomes the new
            ``category_id``.
        remap: If true (default), replace each kept annotation's
            ``category_id`` with its index in ``categories``.
    """

    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, target):
        """Filter (and optionally remap) ``target["annotations"]``.

        Args:
            image: Passed through untouched.
            target: Dict with an ``"annotations"`` list of dicts, each
                carrying a ``"category_id"`` key. The ``"annotations"``
                entry of this dict is replaced in place.

        Returns:
            The ``(image, target)`` pair.
        """
        anno = target["annotations"]
        anno = [obj for obj in anno if obj["category_id"] in self.categories]
        if not self.remap:
            target["annotations"] = anno
            return image, target
        # Deep-copy before rewriting ids so the annotation dicts that the
        # caller handed in are not modified.
        anno = copy.deepcopy(anno)
        for obj in anno:
            obj["category_id"] = self.categories.index(obj["category_id"])
        target["annotations"] = anno
        return image, target


def convert_to_coco_api(ds):
    """Build an in-memory ``COCO`` object from a detection dataset.

    Every sample of ``ds`` is read once; its image size and per-object
    targets are collected into a COCO-style ``dataset`` dict, which is then
    attached to a fresh ``COCO`` instance and indexed.

    Args:
        ds: Indexable dataset with ``len()``. ``ds[i]`` must return
            ``(img, targets)`` where ``img`` exposes ``.shape`` (height and
            width are read from the last two dimensions) and ``targets`` is a
            dict holding tensors under ``"image_id"``, ``"boxes"``,
            ``"labels"``, ``"area"`` and ``"iscrowd"``, and optionally
            ``"masks"`` and ``"keypoints"``.

    Returns:
        A ``COCO`` object whose index has been created from the collected
        images, annotations and categories.
    """
    coco_ds = COCO()
    ann_id = 0
    dataset = {"images": [], "categories": [], "annotations": []}
    categories = set()
    for img_idx in tqdm(range(len(ds))):
        # TODO: find a better way to get the targets than indexing the
        # dataset, which also fetches the image.
        img, targets = ds[img_idx]
        image_id = targets["image_id"].item()
        img_dict = {}
        img_dict["id"] = image_id
        img_dict["height"] = img.shape[-2]
        img_dict["width"] = img.shape[-1]
        dataset["images"].append(img_dict)
        # Clone first: the subtraction below is in place, and without the
        # copy it would rewrite the tensor owned by the dataset sample.
        bboxes = targets["boxes"].clone()
        # Turn the last two columns into offsets from the first two
        # (presumably (x1, y1, x2, y2) -> (x, y, w, h); confirm against the
        # dataset's box convention).
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()
        if "masks" in targets:
            masks = targets["masks"]
            # make masks Fortran contiguous for coco_mask
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if "keypoints" in targets:
            keypoints = targets["keypoints"]
            # Flatten each object's keypoints into a single flat list.
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
        num_objs = len(bboxes)
        for i in range(num_objs):
            ann = {}
            ann["image_id"] = image_id
            ann["bbox"] = bboxes[i]
            ann["category_id"] = labels[i]
            categories.add(labels[i])
            ann["area"] = areas[i]
            ann["iscrowd"] = iscrowd[i]
            ann["id"] = ann_id
            if "masks" in targets:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if "keypoints" in targets:
                ann["keypoints"] = keypoints[i]
                # Count keypoints whose every-third value (starting at
                # offset 2) is non-zero.
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1
    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds


def get_coco_api_from_dataset(dataset):
    """Return a ``COCO`` API object for ``dataset``.

    Up to 10 layers of ``torch.utils.data.Subset`` wrapping are peeled off.
    If a ``torchvision.datasets.CocoDetection`` is found, its existing
    ``coco`` attribute is returned; otherwise the (possibly unwrapped)
    dataset is converted with ``convert_to_coco_api``.
    """
    for _ in range(10):
        if isinstance(dataset, torchvision.datasets.CocoDetection):
            break
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return convert_to_coco_api(dataset)