import copy
import io
import json
import os

import torch
import torch.utils.data
import torchvision
from PIL import Image, ImageFile
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
from torchvision.datasets.vision import VisionDataset

import transforms as T

# import boto3
# import brainpp

# Tell PIL to load truncated image files instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class FilterAndRemapCocoCategories(object):
    """Keep only annotations of the given categories, optionally remapping ids.

    Args:
        categories: sequence of category ids to keep. It must support ``in``
            and ``.index``.
        remap: when true, every kept ``category_id`` is replaced by its
            position inside ``categories``.
    """

    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, target):
        """Filter ``target["annotations"]`` and return ``(image, target)``."""
        anno = target["annotations"]
        anno = [obj for obj in anno if obj["category_id"] in self.categories]
        if not self.remap:
            target["annotations"] = anno
            return image, target
        # Deep-copy before rewriting ids so the source annotation dicts are
        # left untouched.
        anno = copy.deepcopy(anno)
        for obj in anno:
            obj["category_id"] = self.categories.index(obj["category_id"])
        target["annotations"] = anno
        return image, target


def convert_coco_poly_to_mask(segmentations, height, width):
    """Decode COCO polygon segmentations into a stacked uint8/bool mask tensor.

    Args:
        segmentations: iterable with one polygon list per object.
        height: mask height passed to ``coco_mask.frPyObjects``.
        width: mask width passed to ``coco_mask.frPyObjects``.

    Returns:
        The per-object masks stacked along dim 0, or an empty
        ``(0, height, width)`` uint8 tensor when there are no segmentations.
    """
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        # Merge all parts of one object into a single 2-D mask.
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask(object):
    """Turn a raw COCO target into tensors (boxes, labels, keypoints, ...).

    The input target is ``{"image_id": ..., "annotations": [...]}``; the
    output target holds ``boxes`` (x1, y1, x2, y2), ``labels``, ``image_id``,
    ``area``, ``iscrowd`` and, when present, ``keypoints``. Crowd annotations
    and boxes with non-positive width or height are dropped.
    """

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]
        anno = [obj for obj in anno if obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # (x, y, w, h) -> (x1, y1, x2, y2), then clip to the image.
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # Drop boxes that have no positive extent after clipping.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        # Apply the same `keep` filter as for the boxes: otherwise area and
        # iscrowd are longer than (and misaligned with) boxes/labels whenever
        # a degenerate box was dropped above.
        area = torch.tensor([obj["area"] for obj in anno])[keep]
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])[keep]
        target["area"] = area
        target["iscrowd"] = iscrowd

        return image, target


# Minimum number of visible keypoints an image needs to count as annotated
# in keypoint tasks (see _has_valid_annotation).
min_keypoints_per_image = 10


def _has_only_empty_bbox(anno):
    """Return True when every annotation has a box width or height <= 1."""
    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)


def _count_visible_keypoints(anno):
    """Count keypoints whose visibility flag (every third value) is > 0."""
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)


def _has_valid_annotation(anno):
    """Decide whether an image's annotation list is usable for training."""
    # if it's empty, there is no annotation
    if len(anno) == 0:
        return False
    # if all boxes have close to zero area, there is no annotation
    if _has_only_empty_bbox(anno):
        return False
    # keypoint tasks have slightly different criteria for considering
    # whether an annotation is valid
    if "keypoints" not in anno[0]:
        return True
    # for keypoint detection tasks, only consider valid images those
    # containing at least min_keypoints_per_image
    if _count_visible_keypoints(anno) >= min_keypoints_per_image:
        return True
    return False


def _coco_remove_images_without_annotations(dataset, cat_list=None):
    """Wrap ``dataset`` in a ``Subset`` that skips images without valid annos.

    Args:
        dataset: a ``torchvision.datasets.CocoDetection`` or the
            ``CocoDetection`` defined in this file.
        cat_list: optional collection of category ids; when given, only
            annotations of these categories are considered.

    Returns:
        ``torch.utils.data.Subset`` over the indices of the kept images.
    """
    assert isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(dataset, CocoDetection)
    ids = []
    empty = 0
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)
        else:
            empty += 1
    print("remove {} empty imgs without annos".format(empty))

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset


def convert_to_coco_api(ds):
    """Build an in-memory ``COCO`` object by iterating over dataset ``ds``.

    Every sample ``(img, targets)`` contributes one image entry and one
    annotation per box. Category entries are the sorted set of labels seen.
    """
    coco_ds = COCO()
    ann_id = 0
    dataset = {'images': [], 'categories': [], 'annotations': []}
    categories = set()
    for img_idx in range(len(ds)):
        # find better way to get target
        # targets = ds.get_annotations(img_idx)
        img, targets = ds[img_idx]
        image_id = targets["image_id"].item()
        img_dict = {}
        img_dict['id'] = image_id
        img_dict['height'] = img.shape[-2]
        img_dict['width'] = img.shape[-1]
        dataset['images'].append(img_dict)
        # Clone first: the in-place conversion back to (x, y, w, h) must not
        # modify the tensor owned by the dataset's target.
        bboxes = targets["boxes"].clone()
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets['labels'].tolist()
        areas = targets['area'].tolist()
        iscrowd = targets['iscrowd'].tolist()
        if 'masks' in targets:
            masks = targets['masks']
            # make masks Fortran contiguous for coco_mask
            # NOTE(review): `masks` is not written into the annotations
            # below; confirm whether a segmentation export is intended.
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if 'keypoints' in targets:
            keypoints = targets['keypoints']
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
        num_objs = len(bboxes)
        for i in range(num_objs):
            ann = {}
            ann['image_id'] = image_id
            ann['bbox'] = bboxes[i]
            ann['category_id'] = labels[i]
            categories.add(labels[i])
            ann['area'] = areas[i]
            ann['iscrowd'] = iscrowd[i]
            ann['id'] = ann_id
            if 'keypoints' in targets:
                ann['keypoints'] = keypoints[i]
                ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
            dataset['annotations'].append(ann)
            ann_id += 1
    dataset['categories'] = [{'id': i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds


def get_coco_api_from_dataset(dataset):
    """Return a ``COCO`` object describing ``dataset``.

    Up to 10 levels of ``torch.utils.data.Subset`` wrapping are removed. A
    ``torchvision.datasets.CocoDetection`` yields its own ``coco`` attribute;
    anything else (including the ``CocoDetection`` class of this file, which
    derives from ``VisionDataset``) is converted sample by sample through
    ``convert_to_coco_api``.
    """
    for _ in range(10):
        if isinstance(dataset, torchvision.datasets.CocoDetection):
            break
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return convert_to_coco_api(dataset)


class CocoDetection(VisionDataset):
    """COCO-format detection dataset reading images from a local directory.

    Images are filtered at construction time: an image is kept only if it has
    a valid annotation (``_has_valid_annotation``) and its non-crowd boxes are
    not all (near) empty. Labels are shifted so that the smallest category id
    in the annotation file becomes 1.

    Args:
        root: directory containing the image files.
        annFile: path to the COCO annotation JSON file.
        transforms: callable ``(img, target) -> (img, target)`` or ``None``.
    """

    def __init__(self, root, annFile, transforms):
        super().__init__(root, transforms=None, transform=None, target_transform=None)
        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))
        self._transforms = transforms

        # COCO() has already parsed the annotation file into `.dataset`;
        # reuse it instead of loading the (possibly very large) JSON again.
        catids = [k['id'] for k in self.coco.dataset['categories']]
        self.catid_inf = min(catids)

        kept_ids = []
        num_illegal = 0
        for img_id in self.ids:
            ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
            anno = self.coco.loadAnns(ann_ids)
            # "Illegal": every non-crowd box is (near) empty; this is also
            # true when the image has no non-crowd annotation at all.
            illegal = _has_only_empty_bbox(
                [obj for obj in anno if obj["iscrowd"] == 0]
            )
            valid = _has_valid_annotation(anno)
            if illegal:
                num_illegal += 1
            elif valid:
                kept_ids.append(img_id)
        print("remove {} illegal image".format(num_illegal))
        self.ids = kept_ids

    def __getitem__(self, idx):
        """Load image ``idx`` and its target, applying the transforms."""
        coco = self.coco
        img_id = self.ids[idx]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)
        target = dict(image_id=img_id, annotations=target)

        path = coco.loadImgs(img_id)[0]['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        # img, target = self.remapper(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
            # 'labels' only exists once the transforms have converted the raw
            # annotations; shift ids so the smallest category id maps to 1.
            target['labels'] = (target['labels'] - self.catid_inf + 1).long()
        return img, target

    def __len__(self):
        return len(self.ids)


# The string literal below holds disabled legacy implementations (a
# torchvision-based CocoDetection and an OSS/S3-backed variant). It is a bare
# string expression and is never executed.
'''
class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        with open(ann_file, "r") as f:
            result = json.load(f)
        catids = [k['id'] for k in result['categories']]
        self.catid_inf = min(catids)
        self.num_classes = len(catids)
        print(self.num_classes)
        ids_to_remove = []
        ids = []
        for img_id in self.ids:
            ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
            anno = self.coco.loadAnns(ann_ids)
            if all(
                any(o <= 1 for o in obj["bbox"][2:])
                for obj in anno
                if obj["iscrowd"] == 0
            ):
                ids_to_remove.append(img_id)
            if _has_valid_annotation(anno):
                ids.append(img_id)
        print("remove {} illegal image".format(len(ids_to_remove)))
        self.ids = [img_id for img_id in ids if img_id not in ids_to_remove]

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = dict(image_id=image_id, annotations=target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        target['labels'] = (target['labels'] - self.catid_inf + 1).long()
        return img, target


class OssCocoDetection(VisionDataset):
    def __init__(self, root, annFile, transforms,
                 host='http://oss.{}.brainpp.cn'.format(brainpp.current_vm.site)):
        super(OssCocoDetection, self).__init__(root, transforms=None, transform=None, target_transform=None)
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))
        self.s3_client = boto3.client('s3', endpoint_url=host)
        self._transforms = transforms
        with open(annFile, "r") as f:
            result = json.load(f)
        catids = [k['id'] for k in result['categories']]
        self.catid_inf = min(catids)
        ids_to_remove = []
        ids = []
        for img_id in self.ids:
            ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
            anno = self.coco.loadAnns(ann_ids)
            if all(
                any(o <= 1 for o in obj["bbox"][2:])
                for obj in anno
                if obj["iscrowd"] == 0
            ):
                ids_to_remove.append(img_id)
            if _has_valid_annotation(anno):
                ids.append(img_id)
        print("remove {} illegal image".format(len(ids_to_remove)))
        self.ids = [img_id for img_id in ids if img_id not in ids_to_remove]

    def __getitem__(self, idx):
        coco = self.coco
        img_id = self.ids[idx]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)
        target = dict(image_id=img_id, annotations=target)

        path = coco.loadImgs(img_id)[0]['file_name']
        img_obj = self.s3_client.get_object(
            Bucket="generalDetection", Key=os.path.join(self.root, path))
        img = Image.open(io.BytesIO(img_obj['Body'].read())).convert('RGB')
        #img, target = self.remapper(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        target['labels'] = (target['labels'] - self.catid_inf + 1).long()
        return img, target

    def __len__(self):
        return len(self.ids)


def get_oss_coco(root, image_set, transforms, mode='instances'):
    t = [ConvertCocoPolysToMask()]

    if transforms is not None:
        t.append(transforms)
    transforms = T.Compose(t)

    datasets = list()
    for i_key, i_val in root.items():
        dataset = OssCocoDetection(
            i_val['img_dir'], i_val['ann_file'], transforms=transforms)
        if image_set == "train":
            dataset = _coco_remove_images_without_annotations(dataset)
        datasets.append(dataset)
    dataset = datasets[0]  # if len(datasets) == 1 else ConcatDataset(datasets)

    # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])

    return dataset
'''


def get_coco(root, image_set, transforms, mode='instances'):
    """Build the COCO detection dataset for one split.

    Args:
        root: mapping from split name to a dict with ``'img_dir'`` and
            ``'ann_file'`` entries.
        image_set: split name used as key into ``root``; for ``"train"`` the
            images without valid annotations are filtered out.
        transforms: optional extra transform, applied after
            ``ConvertCocoPolysToMask``.
        mode: accepted for interface compatibility; not used by this function.

    Returns:
        A ``CocoDetection`` dataset, wrapped in a ``Subset`` for ``"train"``.
    """
    t = [ConvertCocoPolysToMask()]

    if transforms is not None:
        t.append(transforms)
    transforms = T.Compose(t)

    img_folder = root[image_set]['img_dir']
    ann_file = root[image_set]['ann_file']

    dataset = CocoDetection(img_folder, ann_file, transforms=transforms)

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset)

    # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])

    return dataset


def get_coco_kp(root, image_set, transforms):
    """Keypoint variant of ``get_coco``; passes ``mode="person_keypoints"``."""
    return get_coco(root, image_set, transforms, mode="person_keypoints")