cityscapes.py 10.2 KB
Newer Older
Michael Kösel's avatar
Michael Kösel committed
1
2
import json
import os
3
from collections import namedtuple
Philip Meier's avatar
Philip Meier committed
4
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
Michael Kösel's avatar
Michael Kösel committed
5

6
from .utils import extract_archive, verify_str_arg, iterable_to_str
7
from .vision import VisionDataset
Michael Kösel's avatar
Michael Kösel committed
8
9
10
from PIL import Image


11
class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Examples:

        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')

            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])

            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')

            img, smnt = dataset[0]
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    # One record per label: name, raw id, training id, category name/id,
    # instance flag, evaluation-ignore flag and RGB color.
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])

    classes = [
        CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    def __init__(
            self,
            root: str,
            split: str = "train",
            mode: str = "fine",
            target_type: Union[List[str], str] = "instance",
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ) -> None:
        """Validate the arguments, locate (or extract) the data and index all samples.

        After construction ``self.images[i]`` is the path of the i-th image and
        ``self.targets[i]`` is the list of target paths for that image, one
        per entry of ``self.target_type`` and in the same order.

        Raises:
            RuntimeError: If the image or target directory for the requested
                ``split``/``mode`` is missing and the matching zip archives
                are not both present in ``root``.
        """
        super().__init__(root, transforms, transform, target_transform)
        # The on-disk annotation folder is named 'gtFine' or 'gtCoarse'.
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.target_type = target_type
        self.split = split
        self.images: List[str] = []
        self.targets: List[List[str]] = []

        verify_str_arg(mode, "mode", ("fine", "coarse"))
        # The set of available splits depends on the annotation quality.
        if mode == "fine":
            valid_splits = ("train", "test", "val")
        else:
            valid_splits = ("train", "train_extra", "val")
        msg = ("Unknown value '{}' for argument split if mode is '{}'. "
               "Valid values are {{{}}}.")
        msg = msg.format(split, mode, iterable_to_str(valid_splits))
        verify_str_arg(split, "split", valid_splits, msg)

        # Normalize a single target type to a one-element list.
        if not isinstance(target_type, list):
            self.target_type = [target_type]
        for value in self.target_type:
            verify_str_arg(value, "target_type",
                           ("instance", "semantic", "polygon", "color"))

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            # Fall back to the zip archives expected directly inside root.
            if split == 'train_extra':
                image_dir_zip = os.path.join(self.root, 'leftImg8bit_trainextra.zip')
            else:
                image_dir_zip = os.path.join(self.root, 'leftImg8bit_trainvaltest.zip')

            if self.mode == 'gtFine':
                target_dir_zip = os.path.join(self.root, '{}_trainvaltest.zip'.format(self.mode))
            else:
                # self.mode can only be 'gtCoarse' here.
                target_dir_zip = os.path.join(self.root, '{}.zip'.format(self.mode))

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                                   ' specified "split" and "mode" are inside the "root" directory')

        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> sample mapping is identical on every machine and run.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in sorted(os.listdir(img_dir)):
                # Target files share the image's name up to '_leftImg8bit'.
                prefix = file_name.split('_leftImg8bit')[0]
                target_paths = [
                    os.path.join(
                        target_dir,
                        '{}_{}'.format(prefix, self._get_target_suffix(self.mode, t)),
                    )
                    for t in self.target_type
                ]

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_paths)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """

        image = Image.open(self.images[index]).convert('RGB')

        targets: Any = []
        for path, t in zip(self.targets[index], self.target_type):
            # Polygon targets are JSON files; every other type is an image.
            if t == 'polygon':
                target = self._load_json(path)
            else:
                target = Image.open(path)

            targets.append(target)

        # A single requested target type is returned bare, not as a 1-tuple.
        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.images)

    def extra_repr(self) -> str:
        """Return a three-line summary of the split, mode and target type(s)."""
        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
        # The placeholders are filled from the instance attributes of the same name.
        return '\n'.join(lines).format(**self.__dict__)

    def _load_json(self, path: str) -> Dict[str, Any]:
        """Parse the JSON file at ``path`` and return its content."""
        # Read as UTF-8 explicitly instead of relying on the locale default.
        with open(path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode: str, target_type: str) -> str:
        """Return the file-name suffix of a target file.

        Args:
            mode (str): Annotation folder name that prefixes the suffix
                (``self.mode`` when called from ``__init__``).
            target_type (str): ``instance``, ``semantic`` or ``color``; any
                other value yields the polygon (JSON) suffix.
        """
        if target_type == 'instance':
            return '{}_instanceIds.png'.format(mode)
        elif target_type == 'semantic':
            return '{}_labelIds.png'.format(mode)
        elif target_type == 'color':
            return '{}_color.png'.format(mode)
        else:
            return '{}_polygons.json'.format(mode)