Unverified Commit 3245b10d authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

coco (#2524)

parent bf7994b0
...@@ -2,6 +2,7 @@ from .vision import VisionDataset ...@@ -2,6 +2,7 @@ from .vision import VisionDataset
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
from typing import Any, Callable, Optional, Tuple
class CocoCaptions(VisionDataset): class CocoCaptions(VisionDataset):
...@@ -45,13 +46,20 @@ class CocoCaptions(VisionDataset): ...@@ -45,13 +46,20 @@ class CocoCaptions(VisionDataset):
""" """
def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None): def __init__(
self,
root: str,
annFile: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super(CocoCaptions, self).__init__(root, transforms, transform, target_transform) super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
from pycocotools.coco import COCO from pycocotools.coco import COCO
self.coco = COCO(annFile) self.coco = COCO(annFile)
self.ids = list(sorted(self.coco.imgs.keys())) self.ids = list(sorted(self.coco.imgs.keys()))
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -74,7 +82,7 @@ class CocoCaptions(VisionDataset): ...@@ -74,7 +82,7 @@ class CocoCaptions(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return len(self.ids) return len(self.ids)
...@@ -92,13 +100,20 @@ class CocoDetection(VisionDataset): ...@@ -92,13 +100,20 @@ class CocoDetection(VisionDataset):
and returns a transformed version. and returns a transformed version.
""" """
def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None): def __init__(
self,
root: str,
annFile: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super(CocoDetection, self).__init__(root, transforms, transform, target_transform) super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
from pycocotools.coco import COCO from pycocotools.coco import COCO
self.coco = COCO(annFile) self.coco = COCO(annFile)
self.ids = list(sorted(self.coco.imgs.keys())) self.ids = list(sorted(self.coco.imgs.keys()))
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -119,5 +134,5 @@ class CocoDetection(VisionDataset): ...@@ -119,5 +134,5 @@ class CocoDetection(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return len(self.ids) return len(self.ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment