voc.py 9.93 KB
Newer Older
1
2
3
import os
import tarfile
import collections
4
from .vision import VisionDataset
5
import xml.etree.ElementTree as ET
6
from PIL import Image
7
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8
from .utils import download_url, check_integrity, verify_str_arg
9
10
11
12
13
14

# Per-release download metadata for the Pascal VOC archives.
#
# Keys are the dataset year as a string; the extra '2007-test' key describes
# the separately distributed 2007 test archive. Each entry holds:
#   url      -- location of the tar archive
#   filename -- name the archive is saved under (basename of ``url``)
#   md5      -- checksum recorded for the archive
#   base_dir -- path, relative to the dataset root, of the extracted VOC<year> folder
DATASET_YEAR_DICT = {
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': os.path.join('VOCdevkit', 'VOC2012')
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        # the 2011 archive nests its devkit one level deeper than the other years
        'base_dir': os.path.join('TrainVal', 'VOCdevkit', 'VOC2011')
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': os.path.join('VOCdevkit', 'VOC2010')
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': os.path.join('VOCdevkit', 'VOC2009')
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        # Must match the URL basename: reusing the 2012 archive name here would
        # make the 2008 download collide with a 2012 archive in the same root.
        'filename': 'VOCtrainval_14-Jul-2008.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': os.path.join('VOCdevkit', 'VOC2008')
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': os.path.join('VOCdevkit', 'VOC2007')
    },
    '2007-test': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
        'filename': 'VOCtest_06-Nov-2007.tar',
        'md5': 'b6e924de25625d8de591ea690078ad9f',
        'base_dir': os.path.join('VOCdevkit', 'VOC2007')
    }
}


56
class VOCSegmentation(VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year; must be a key of
            ``DATASET_YEAR_DICT`` (``"2007"`` to ``"2012"``).
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``. ``test`` is additionally accepted when
            ``year`` is ``"2007"``.
        download (bool, optional): If true, the archive for the chosen year is
            fetched and extracted into the root directory via ``download_extract``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.

    Raises:
        KeyError: if ``year`` is not a key of ``DATASET_YEAR_DICT``.
        RuntimeError: if the extracted dataset directory does not exist under ``root``.
    """

    def __init__(
            self,
            root: str,
            year: str = "2012",
            image_set: str = "train",
            download: bool = False,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)
        # keep the year exactly as the caller gave it, before any remapping below
        self.year = year

        # the 2007 test split lives in its own archive, keyed as "2007-test"
        if image_set == "test" and year == "2007":
            year = "2007-test"

        meta = DATASET_YEAR_DICT[year]
        self.url = meta['url']
        self.filename = meta['filename']
        self.md5 = meta['md5']

        extra_sets = ["test"] if year == "2007-test" else []
        valid_sets = ["train", "trainval", "val"] + extra_sets
        self.image_set = verify_str_arg(image_set, "image_set", valid_sets)

        voc_root = os.path.join(self.root, meta['base_dir'])
        image_dir = os.path.join(voc_root, 'JPEGImages')
        mask_dir = os.path.join(voc_root, 'SegmentationClass')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        # the split file lists one image id per line
        split_path = os.path.join(
            voc_root, 'ImageSets/Segmentation', image_set.rstrip('\n') + '.txt')
        with open(split_path, "r") as split_file:
            file_names = [line.strip() for line in split_file]

        self.images = [os.path.join(image_dir, name + ".jpg") for name in file_names]
        self.masks = [os.path.join(mask_dir, name + ".png") for name in file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image and its segmentation mask at position ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index])
        img = img.convert('RGB')
        target = Image.open(self.masks[index])

        joint_transform = self.transforms
        if joint_transform is not None:
            img, target = joint_transform(img, target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images listed in the selected split."""
        return len(self.images)


138
class VOCDetection(VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year; must be a key of
            ``DATASET_YEAR_DICT`` (``"2007"`` to ``"2012"``).
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``. ``test`` is additionally accepted when
            ``year`` is ``"2007"``.
        download (bool, optional): If true, the archive for the chosen year is
            fetched and extracted into the root directory via ``download_extract``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.

    Raises:
        KeyError: if ``year`` is not a key of ``DATASET_YEAR_DICT``.
        RuntimeError: if the extracted dataset directory does not exist under ``root``.
    """

    def __init__(
            self,
            root: str,
            year: str = "2012",
            image_set: str = "train",
            download: bool = False,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)
        # keep the year exactly as the caller gave it, before any remapping below
        self.year = year

        # the 2007 test split lives in its own archive, keyed as "2007-test"
        if image_set == "test" and year == "2007":
            year = "2007-test"

        meta = DATASET_YEAR_DICT[year]
        self.url = meta['url']
        self.filename = meta['filename']
        self.md5 = meta['md5']

        extra_sets = ["test"] if year == "2007-test" else []
        valid_sets = ["train", "trainval", "val"] + extra_sets
        self.image_set = verify_str_arg(image_set, "image_set", valid_sets)

        voc_root = os.path.join(self.root, meta['base_dir'])
        image_dir = os.path.join(voc_root, 'JPEGImages')
        annotation_dir = os.path.join(voc_root, 'Annotations')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        # the split file lists one image id per line
        split_path = os.path.join(
            voc_root, 'ImageSets/Main', image_set.rstrip('\n') + '.txt')
        with open(split_path, "r") as split_file:
            file_names = [line.strip() for line in split_file]

        self.images = [os.path.join(image_dir, name + ".jpg") for name in file_names]
        self.annotations = [os.path.join(annotation_dir, name + ".xml") for name in file_names]
        assert (len(self.images) == len(self.annotations))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image and its parsed XML annotation at position ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        img = Image.open(self.images[index])
        img = img.convert('RGB')
        annotation_root = ET.parse(self.annotations[index]).getroot()
        target = self.parse_voc_xml(annotation_root)

        joint_transform = self.transforms
        if joint_transform is not None:
            img, target = joint_transform(img, target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images listed in the selected split."""
        return len(self.images)

    def parse_voc_xml(self, node: ET.Element) -> Dict[str, Any]:
        """Recursively convert an XML element into nested dictionaries.

        A leaf element becomes ``{tag: stripped_text}``, or an empty dict when
        it has no text. An element with children becomes ``{tag: {...}}`` where
        the inner dict merges the children's results: a child tag that occurs
        once maps straight to its value, a repeated tag maps to a list of
        values. Under an ``annotation`` element the ``object`` entry is always
        a list, even with zero or one object.
        """
        children = list(node)

        if not children:
            # leaf: only the text content matters
            if node.text:
                return {node.tag: node.text.strip()}
            return {}

        grouped: Dict[str, Any] = collections.defaultdict(list)
        for child in children:
            for key, value in self.parse_voc_xml(child).items():
                grouped[key].append(value)

        if node.tag == 'annotation':
            # extra wrapping survives the single-item collapse below,
            # so 'object' stays a list regardless of how many there are
            grouped['object'] = [grouped['object']]

        collapsed = {}
        for key, values in grouped.items():
            collapsed[key] = values[0] if len(values) == 1 else values
        return {node.tag: collapsed}


244
def download_extract(url: str, root: str, filename: str, md5: str) -> None:
    """Fetch a tar archive into ``root`` and unpack it in place.

    Args:
        url: Location of the archive, forwarded to ``download_url``.
        root: Directory that receives both the archive and its extracted contents.
        filename: Name of the archive file inside ``root``.
        md5: Checksum forwarded to ``download_url`` (presumably used there to
            verify the download -- confirm against that helper).
    """
    download_url(url, root, filename, md5)

    archive_path = os.path.join(root, filename)
    with tarfile.open(archive_path, "r") as archive:
        archive.extractall(path=root)