import os
import os.path

import torch.utils.data as data
from PIL import Image


class _CocoDataset(data.Dataset):
    """Shared implementation for the COCO datasets defined in this module.

    Holds the pycocotools index, the list of image ids, and the logic for
    loading one image plus its target. Subclasses customise the target by
    overriding :meth:`_load_target`.

    Args:
        root: Directory that each image's ``file_name`` is joined onto.
        annFile: Annotation file, passed as-is to ``pycocotools.coco.COCO``.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the target.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        # Imported here rather than at module level, so importing this module
        # does not itself require pycocotools to be installed.
        from pycocotools.coco import COCO

        self.root = root
        self.coco = COCO(annFile)
        # Position in this list is the dataset index; the order is whatever
        # order ``coco.imgs`` yields its keys in.
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def _load_image(self, img_id):
        """Open the image registered under ``img_id`` and return it as RGB."""
        file_name = self.coco.loadImgs(img_id)[0]['file_name']
        # ``convert`` hands back a separate image object, so the handle opened
        # by ``Image.open`` can be released as soon as the conversion is done
        # (and is released even if decoding raises).
        with Image.open(os.path.join(self.root, file_name)) as img:
            return img.convert('RGB')

    def _load_target(self, img_id):
        """Return the annotation records that pycocotools has for ``img_id``."""
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        return self.coco.loadAnns(ann_ids)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the image at position ``index``.

        Both optional transforms are applied (when set) before returning.

        Raises:
            IndexError: If ``index`` is outside ``self.ids``.
        """
        img_id = self.ids[index]

        # Annotations are looked up before the image file is touched.
        target = self._load_target(img_id)
        img = self._load_image(img_id)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of image ids in the annotation index."""
        return len(self.ids)


class CocoCaptions(_CocoDataset):
    """COCO dataset whose target is the list of captions for each image.

    ``dataset[index]`` returns ``(image, target)`` where ``image`` is the RGB
    PIL image (after ``transform``, if given) and ``target`` is a list with
    the ``'caption'`` field of every annotation attached to that image (after
    ``target_transform``, if given).

    Args:
        root: Directory that each image's ``file_name`` is joined onto.
        annFile: Annotation file, passed as-is to ``pycocotools.coco.COCO``.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the caption list.
    """

    def _load_target(self, img_id):
        """Return the ``'caption'`` value of each annotation for ``img_id``."""
        anns = super(CocoCaptions, self)._load_target(img_id)
        return [ann['caption'] for ann in anns]


class CocoDetection(_CocoDataset):
    """COCO dataset whose target is the raw annotation list for each image.

    ``dataset[index]`` returns ``(image, target)`` where ``image`` is the RGB
    PIL image (after ``transform``, if given) and ``target`` is whatever
    ``coco.loadAnns`` returns for the image's annotation ids (after
    ``target_transform``, if given).

    Args:
        root: Directory that each image's ``file_name`` is joined onto.
        annFile: Annotation file, passed as-is to ``pycocotools.coco.COCO``.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the annotation list.
    """