cityscapes.py 10.2 KB
Newer Older
Michael Kösel's avatar
Michael Kösel committed
1
2
import json
import os
3
from collections import namedtuple
4
import zipfile
Philip Meier's avatar
Philip Meier committed
5
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
Michael Kösel's avatar
Michael Kösel committed
6

7
from .utils import extract_archive, verify_str_arg, iterable_to_str
8
from .vision import VisionDataset
Michael Kösel's avatar
Michael Kösel committed
9
10
11
from PIL import Image


12
class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list (or tuple) to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Raises:
        RuntimeError: If the image or target directory for the requested ``split``/``mode`` is missing
            and the corresponding zip archives are not both present inside ``root``.

    Examples:

        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')

            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])

            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')

            img, smnt = dataset[0]
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])

    classes = [
        CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    def __init__(
            self,
            root: str,
            split: str = "train",
            mode: str = "fine",
            target_type: Union[List[str], Tuple[str, ...], str] = "instance",
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ) -> None:
        """Validate the arguments, extract archives if needed and index all samples.

        See the class docstring for the meaning of each argument.
        """
        super().__init__(root, transforms, transform, target_transform)

        # Validate the arguments before deriving any paths from them.
        verify_str_arg(mode, "mode", ("fine", "coarse"))
        if mode == "fine":
            valid_splits = ("train", "test", "val")
        else:
            valid_splits = ("train", "train_extra", "val")
        msg = ("Unknown value '{}' for argument split if mode is '{}'. "
               "Valid values are {{{}}}.")
        msg = msg.format(split, mode, iterable_to_str(valid_splits))
        verify_str_arg(split, "split", valid_splits, msg)

        # Normalise target_type to a list so the rest of the class can iterate it.
        # Tuples are accepted as well; previously they were wrapped as a single
        # element and then rejected by the string check below.
        if isinstance(target_type, tuple):
            self.target_type = list(target_type)
        elif isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]
        for value in self.target_type:
            verify_str_arg(value, "target_type",
                           ("instance", "semantic", "polygon", "color"))

        # Name of the annotation folder on disk: "gtFine" or "gtCoarse".
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.split = split
        self.images = []
        self.targets = []

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            # Fall back to the zip archives; both must be present in root.
            if split == 'train_extra':
                image_dir_zip = os.path.join(self.root, 'leftImg8bit_trainextra.zip')
            else:
                image_dir_zip = os.path.join(self.root, 'leftImg8bit_trainvaltest.zip')

            if self.mode == 'gtFine':
                target_dir_zip = os.path.join(self.root, '{}_trainvaltest.zip'.format(self.mode))
            else:
                target_dir_zip = os.path.join(self.root, '{}.zip'.format(self.mode))

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                                   ' specified "split" and "mode" are inside the "root" directory')

        # The suffix depends only on (mode, target type), so compute it once
        # instead of once per image file.
        target_suffixes = [self._get_target_suffix(self.mode, t) for t in self.target_type]

        # NOTE: os.listdir returns entries in arbitrary order, so the sample
        # order follows whatever the file system reports.
        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                # Target files share the image's prefix up to "_leftImg8bit".
                prefix = file_name.split('_leftImg8bit')[0]
                target_paths = [os.path.join(target_dir, '{}_{}'.format(prefix, suffix))
                                for suffix in target_suffixes]

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_paths)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """

        image = Image.open(self.images[index]).convert('RGB')

        targets: Any = []
        # self.targets[index] holds one path per entry of self.target_type, in the same order.
        for t, path in zip(self.target_type, self.targets[index]):
            if t == 'polygon':
                target = self._load_json(path)
            else:
                target = Image.open(path)

            targets.append(target)

        # A single requested target type is returned bare rather than as a 1-tuple.
        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.images)

    def extra_repr(self) -> str:
        """Return the split, mode and target type lines, filled in from the instance attributes."""
        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
        return '\n'.join(lines).format(**self.__dict__)

    def _load_json(self, path: str) -> Dict[str, Any]:
        """Read the JSON file at ``path`` and return its parsed content."""
        # Decode explicitly as UTF-8 (the JSON interchange encoding) rather than
        # relying on the platform's locale default.
        with open(path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode: str, target_type: str) -> str:
        """Return the file-name suffix of a target file.

        Args:
            mode (str): Prefix of the suffix; ``__init__`` passes ``gtFine`` or ``gtCoarse``.
            target_type (str): ``instance``, ``semantic`` or ``color``; any other
                value selects the polygon JSON suffix.

        Returns:
            str: ``"<mode>_instanceIds.png"``, ``"<mode>_labelIds.png"``,
            ``"<mode>_color.png"`` or ``"<mode>_polygons.json"``.
        """
        png_suffixes = {
            'instance': 'instanceIds.png',
            'semantic': 'labelIds.png',
            'color': 'color.png',
        }
        return '{}_{}'.format(mode, png_suffixes.get(target_type, 'polygons.json'))