cityscapes.py 10 KB
Newer Older
Michael Kösel's avatar
Michael Kösel committed
1
2
import json
import os
3
from collections import namedtuple
Philip Meier's avatar
Philip Meier committed
4
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
Michael Kösel's avatar
Michael Kösel committed
5

6
7
from PIL import Image

8
from .utils import extract_archive, verify_str_arg, iterable_to_str
9
from .vision import VisionDataset
Michael Kösel's avatar
Michael Kösel committed
10
11


12
class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    On construction the dataset walks ``<root>/leftImg8bit/<split>/<city>/`` and, for
    every image file found there, records the path of the matching annotation file(s)
    under ``<root>/gtFine/<split>/<city>/`` or ``<root>/gtCoarse/<split>/<city>/``.
    If either of the two split directories is missing, the constructor looks for the
    corresponding zip archives directly inside ``root`` and extracts them there.

    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Raises:
        RuntimeError: If the image or target directory for the requested ``split`` and
            ``mode`` does not exist and the zip archives needed to create them are not
            both present in ``root``.

    Examples:

        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')

            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])

            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')

            img, smnt = dataset[0]
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple(
        "CityscapesClass",
        ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
    )

    # One entry per label, in the field order declared by ``CityscapesClass`` above.
    classes = [
        CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
        CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
        CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
        CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
        CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
        CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
        CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
        CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
        CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
        CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
        CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
        CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
        CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
        CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
        CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
        CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
        CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
        CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
        CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
        CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
        CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
        CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
        CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
        CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
        CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
        CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
        CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
        CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
        CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
        CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
    ]

    def __init__(
        self,
        root: str,
        split: str = "train",
        mode: str = "fine",
        target_type: Union[List[str], str] = "instance",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)

        # Validate the arguments before deriving any state from them.
        verify_str_arg(mode, "mode", ("fine", "coarse"))
        if mode == "fine":
            valid_splits = ("train", "test", "val")
        else:
            valid_splits = ("train", "train_extra", "val")
        msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
        msg = msg.format(split, mode, iterable_to_str(valid_splits))
        verify_str_arg(split, "split", valid_splits, msg)

        # Internally the target types are always kept as a list, even for a single type.
        self.target_type = target_type if isinstance(target_type, list) else [target_type]
        for value in self.target_type:
            verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))

        # ``self.mode`` holds the annotation folder name, not the user-facing mode string.
        self.mode = "gtFine" if mode == "fine" else "gtCoarse"
        self.images_dir = os.path.join(self.root, "leftImg8bit", split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.split = split
        self.images = []
        self.targets = []

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            # The extracted folders are missing: fall back to the zip archives in ``root``.
            if split == "train_extra":
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
            else:
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")

            # ``self.mode`` is either "gtFine" or "gtCoarse" at this point.
            if self.mode == "gtFine":
                target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
            else:
                target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError(
                    "Dataset not found or incomplete. Please make sure all required folders for the"
                    ' specified "split" and "mode" are inside the "root" directory'
                )

        # The suffix depends only on the mode and the target type, so compute it once
        # instead of once per image file.
        target_suffixes = [self._get_target_suffix(self.mode, t) for t in self.target_type]

        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                # Everything before "_leftImg8bit" is the stem shared with the target files.
                stem = file_name.split("_leftImg8bit")[0]
                target_paths = [
                    os.path.join(target_dir, "{}_{}".format(stem, suffix)) for suffix in target_suffixes
                ]

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_paths)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """

        image = Image.open(self.images[index]).convert("RGB")

        # ``self.targets[index]`` was built in ``__init__`` with one path per entry of
        # ``self.target_type``, in the same order.
        targets: Any = []
        for kind, path in zip(self.target_type, self.targets[index]):
            if kind == "polygon":
                targets.append(self._load_json(path))
            else:
                targets.append(Image.open(path))

        # A single target type is returned bare rather than as a 1-tuple.
        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of image files collected for the selected split."""
        return len(self.images)

    def extra_repr(self) -> str:
        """Return the split, the annotation folder name and the target types, one per line."""
        return f"Split: {self.split}\nMode: {self.mode}\nType: {self.target_type}"

    def _load_json(self, path: str) -> Dict[str, Any]:
        """Parse the JSON file at ``path`` and return the decoded object."""
        # Decode as UTF-8 explicitly so the result does not depend on the platform's
        # locale encoding.
        with open(path, encoding="utf-8") as file:
            return json.load(file)

    def _get_target_suffix(self, mode: str, target_type: str) -> str:
        """Return the file-name suffix of the target file for ``target_type``.

        Args:
            mode (str): Prefix of the suffix; ``__init__`` passes the annotation folder
                name (``"gtFine"`` or ``"gtCoarse"``).
            target_type (str): ``"instance"``, ``"semantic"`` or ``"color"``; any other
                value selects the polygon JSON file.

        Returns:
            str: ``"<mode>_<kind>.<ext>"``.
        """
        if target_type == "instance":
            return f"{mode}_instanceIds.png"
        if target_type == "semantic":
            return f"{mode}_labelIds.png"
        if target_type == "color":
            return f"{mode}_color.png"
        return f"{mode}_polygons.json"