coco.py 4.42 KB
Newer Older
1
from .vision import VisionDataset
soumith's avatar
soumith committed
2
3
4
5
from PIL import Image
import os
import os.path

6

7
class CocoCaptions(VisionDataset):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.ToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
        super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level, so pycocotools is only
        # needed once a dataset is actually constructed.
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        # Sorting the image ids gives a deterministic index -> image mapping.
        self.ids = sorted(self.coco.imgs.keys())

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        coco = self.coco
        img_id = self.ids[index]

        # Collect every caption annotated for this image.
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)
        target = [ann['caption'] for ann in anns]

        path = coco.loadImgs(img_id)[0]['file_name']
        # Force 3-channel RGB so grayscale/palette images yield a uniform mode.
        img = Image.open(os.path.join(self.root, path)).convert('RGB')

        # NOTE(review): self.transforms is assumed to be built by VisionDataset
        # from transform/target_transform/transforms — confirm in vision.py.
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of images in the annotation file."""
        return len(self.ids)

80

81
class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
        super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level, so pycocotools is only
        # needed once a dataset is actually constructed.
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        # Sorting the image ids gives a deterministic index -> image mapping.
        self.ids = sorted(self.coco.imgs.keys())

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
        """
        coco = self.coco
        img_id = self.ids[index]

        # All annotations attached to this image, exactly as pycocotools returns them.
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]['file_name']
        # Force 3-channel RGB so grayscale/palette images yield a uniform mode.
        img = Image.open(os.path.join(self.root, path)).convert('RGB')

        # NOTE(review): self.transforms is assumed to be built by VisionDataset
        # from transform/target_transform/transforms — confirm in vision.py.
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of images in the annotation file."""
        return len(self.ids)