coco.py 4.83 KB
Newer Older
1
from .vision import VisionDataset
soumith's avatar
soumith committed
2
3
4
from PIL import Image
import os
import os.path
Philip Meier's avatar
Philip Meier committed
5
from typing import Any, Callable, Optional, Tuple
soumith's avatar
soumith committed
6

7

8
class CocoCaptions(VisionDataset):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.ToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def __init__(
            self,
            root: str,
            annFile: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level, so pycocotools is only
        # required when a dataset is actually instantiated.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorted image ids give a stable index -> image mapping.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, img_id: int) -> Any:
        """Open the image registered under ``img_id`` and convert it to RGB."""
        file_name = self.coco.loadImgs(img_id)[0]['file_name']
        return Image.open(os.path.join(self.root, file_name)).convert('RGB')

    def _load_target(self, img_id: int) -> Any:
        """Return the list of caption strings annotated for ``img_id``."""
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        return [ann['caption'] for ann in self.coco.loadAnns(ann_ids)]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]
        # Target is resolved before the image, matching the original lookup order.
        target = self._load_target(img_id)
        img = self._load_image(img_id)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images listed in the annotation file."""
        return len(self.ids)

88

89
class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
            self,
            root: str,
            annFile: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level, so pycocotools is only
        # required when a dataset is actually instantiated.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorted image ids give a stable index -> image mapping.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, img_id: int) -> Any:
        """Open the image registered under ``img_id`` and convert it to RGB."""
        file_name = self.coco.loadImgs(img_id)[0]['file_name']
        return Image.open(os.path.join(self.root, file_name)).convert('RGB')

    def _load_target(self, img_id: int) -> Any:
        """Return whatever ``coco.loadAnns`` yields for the annotations of ``img_id``."""
        return self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
        """
        img_id = self.ids[index]
        # Target is resolved before the image, matching the original lookup order.
        target = self._load_target(img_id)
        img = self._load_image(img_id)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images listed in the annotation file."""
        return len(self.ids)