Unverified commit 62e185c7, authored by Philip Meier, committed by GitHub
Browse files

improve Coco implementation (#3417)


Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent a6f3f95a
...@@ -2,11 +2,11 @@ from .vision import VisionDataset ...@@ -2,11 +2,11 @@ from .vision import VisionDataset
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple, List
class CocoCaptions(VisionDataset): class CocoDetection(VisionDataset):
"""`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset. """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
Args: Args:
root (string): Root directory where images are downloaded to. root (string): Root directory where images are downloaded to.
...@@ -17,77 +17,45 @@ class CocoCaptions(VisionDataset): ...@@ -17,77 +17,45 @@ class CocoCaptions(VisionDataset):
target and transforms it. target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry transforms (callable, optional): A function/transform that takes input sample and its target as entry
and returns a transformed version. and returns a transformed version.
Example:
.. code:: python
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
annFile = 'json annotation file',
transform=transforms.ToTensor())
print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample
print("Image Size: ", img.size())
print(target)
Output: ::
Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
""" """
def __init__( def __init__(
self, self,
root: str, root: str,
annFile: str, annFile: str,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
) -> None: ):
super(CocoCaptions, self).__init__(root, transforms, transform, target_transform) super().__init__(root, transforms, transform, target_transform)
from pycocotools.coco import COCO from pycocotools.coco import COCO
self.coco = COCO(annFile) self.coco = COCO(annFile)
self.ids = list(sorted(self.coco.imgs.keys())) self.ids = list(sorted(self.coco.imgs.keys()))
def __getitem__(self, index: int) -> Tuple[Any, Any]: def _load_image(self, id: int) -> Image.Image:
""" path = self.coco.loadImgs(id)[0]["file_name"]
Args: return Image.open(os.path.join(self.root, path)).convert("RGB")
index (int): Index
Returns:
tuple: Tuple (image, target). target is a list of captions for the image.
"""
coco = self.coco
img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds=img_id)
anns = coco.loadAnns(ann_ids)
target = [ann['caption'] for ann in anns]
path = coco.loadImgs(img_id)[0]['file_name'] def _load_target(self, id) -> List[Any]:
return self.coco.loadAnns(self.coco.getAnnIds(id))
img = Image.open(os.path.join(self.root, path)).convert('RGB') def __getitem__(self, index: int) -> Tuple[Any, Any]:
id = self.ids[index]
image = self._load_image(id)
target = self._load_target(id)
if self.transforms is not None: if self.transforms is not None:
img, target = self.transforms(img, target) image, target = self.transforms(image, target)
return img, target return image, target
def __len__(self) -> int: def __len__(self) -> int:
return len(self.ids) return len(self.ids)
class CocoDetection(VisionDataset): class CocoCaptions(CocoDetection):
"""`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset. """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
Args: Args:
root (string): Root directory where images are downloaded to. root (string): Root directory where images are downloaded to.
...@@ -98,41 +66,34 @@ class CocoDetection(VisionDataset): ...@@ -98,41 +66,34 @@ class CocoDetection(VisionDataset):
target and transforms it. target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry transforms (callable, optional): A function/transform that takes input sample and its target as entry
and returns a transformed version. and returns a transformed version.
"""
def __init__( Example:
self,
root: str,
annFile: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
from pycocotools.coco import COCO
self.coco = COCO(annFile)
self.ids = list(sorted(self.coco.imgs.keys()))
def __getitem__(self, index: int) -> Tuple[Any, Any]: .. code:: python
"""
Args: import torchvision.datasets as dset
index (int): Index import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
annFile = 'json annotation file',
transform=transforms.ToTensor())
Returns: print('Number of samples: ', len(cap))
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. img, target = cap[3] # load 4th sample
"""
coco = self.coco
img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds=img_id)
target = coco.loadAnns(ann_ids)
path = coco.loadImgs(img_id)[0]['file_name'] print("Image Size: ", img.size())
print(target)
img = Image.open(os.path.join(self.root, path)).convert('RGB') Output: ::
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
def __len__(self) -> int: """
return len(self.ids)
def _load_target(self, id) -> List[str]:
return [ann["caption"] for ann in super()._load_target(id)]
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment