import os.path
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image

from .vision import VisionDataset


class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Imported inside the constructor rather than at module level, so
        # pycocotools is only required once a dataset is instantiated.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorting the image ids gives a stable index -> image id mapping.
        # ``sorted`` already returns a list, so no extra ``list(...)`` is needed.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, id: int) -> Image.Image:
        """Open the image registered under COCO image ``id`` and return it as RGB.

        The file name is taken from the ``"file_name"`` entry of the image
        record and joined onto ``self.root``.
        """
        path = self.coco.loadImgs(id)[0]["file_name"]
        # ``convert`` returns a new image object, so the handle opened by
        # ``Image.open`` can be closed as soon as the conversion is done.
        with Image.open(os.path.join(self.root, path)) as img:
            return img.convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        """Return the annotation records attached to COCO image ``id``."""
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at position ``index``.

        Args:
            index (int): Position into the sorted list of image ids.

        Returns:
            tuple: ``(image, target)``, passed through ``self.transforms``
            when that attribute is not ``None``.

        Raises:
            ValueError: If ``index`` is not an ``int``.
        """

        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")

        image_id = self.ids[index]
        image = self._load_image(image_id)
        target = self._load_target(image_id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of images listed in the annotation file."""
        return len(self.ids)

64

65
66
class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        """Return only the ``"caption"`` field of each annotation for image ``id``."""
        return [ann["caption"] for ann in super()._load_target(id)]