cityscapes.py 10.1 KB
Newer Older
Michael Kösel's avatar
Michael Kösel committed
1
2
import json
import os
3
from collections import namedtuple
limm's avatar
limm committed
4
5
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
Michael Kösel's avatar
Michael Kösel committed
6
7
8

from PIL import Image

limm's avatar
limm committed
9
10
11
from .utils import extract_archive, iterable_to_str, verify_str_arg
from .vision import VisionDataset

Michael Kösel's avatar
Michael Kösel committed
12

13
class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list (or tuple) to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Examples:

        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')

            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])

            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')

            img, smnt = dataset[0]
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple(
        "CityscapesClass",
        ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
    )

    # One entry per label; the positional values follow the field order of ``CityscapesClass`` above.
    classes = [
        CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
        CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
        CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
        CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
        CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
        CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
        CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
        CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
        CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
        CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
        CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
        CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
        CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
        CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
        CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
        CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
        CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
        CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
        CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
        CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
        CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
        CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
        CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
        CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
        CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
        CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
        CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
        CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
        CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
        CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
    ]

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        mode: str = "fine",
        target_type: Union[List[str], str] = "instance",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        """Validate the arguments, locate (or extract) the data and index every sample.

        The image and target directories are ``<root>/leftImg8bit/<split>`` and
        ``<root>/<gtFine|gtCoarse>/<split>``. If either directory is missing, the
        matching zip archives are looked up inside ``root`` and extracted there.

        Raises:
            RuntimeError: If the directories are missing and the image and target
                archives are not both present in ``root``.

        Invalid ``mode``, ``split`` or ``target_type`` values are rejected by
        ``verify_str_arg``.
        """
        super().__init__(root, transforms, transform, target_transform)
        self.mode = "gtFine" if mode == "fine" else "gtCoarse"
        self.images_dir = os.path.join(self.root, "leftImg8bit", split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.split = split
        self.images = []
        self.targets = []

        verify_str_arg(mode, "mode", ("fine", "coarse"))
        if mode == "fine":
            valid_modes = ("train", "test", "val")
        else:
            valid_modes = ("train", "train_extra", "val")
        msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
        msg = msg.format(split, mode, iterable_to_str(valid_modes))
        verify_str_arg(split, "split", valid_modes, msg)

        # Normalise to a private list: a tuple is accepted like a list, and copying
        # keeps later mutation of the caller's list from desynchronising
        # ``self.targets`` (built below, one path per target type).
        if isinstance(target_type, (list, tuple)):
            self.target_type = list(target_type)
        else:
            self.target_type = [target_type]
        for value in self.target_type:
            verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            if split == "train_extra":
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
            else:
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")

            # ``self.mode`` is always "gtFine" or "gtCoarse" (assigned above).
            if self.mode == "gtFine":
                target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
            else:
                target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError(
                    "Dataset not found or incomplete. Please make sure all required folders for the"
                    ' specified "split" and "mode" are inside the "root" directory'
                )

        # ``os.listdir`` returns entries in arbitrary order; sort both levels so
        # that a given index refers to the same sample on every machine.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in sorted(os.listdir(img_dir)):
                # Target files share the image name up to the "_leftImg8bit" marker.
                stem = file_name.split("_leftImg8bit")[0]
                target_paths = [
                    os.path.join(target_dir, f"{stem}_{self._get_target_suffix(self.mode, t)}")
                    for t in self.target_type
                ]

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_paths)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert("RGB")

        # ``self.targets[index]`` holds one path per entry of ``self.target_type``,
        # in the same order.
        targets: Any = []
        for i, t in enumerate(self.target_type):
            if t == "polygon":
                target = self._load_json(self.targets[index][i])
            else:
                target = Image.open(self.targets[index][i])  # type: ignore[assignment]

            targets.append(target)

        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.images)

    def extra_repr(self) -> str:
        """Return the split, mode and target type as a newline-separated string."""
        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
        return "\n".join(lines).format(**self.__dict__)

    def _load_json(self, path: str) -> Dict[str, Any]:
        """Parse the JSON file at ``path`` and return its content."""
        # JSON text is UTF-8; do not depend on the platform's locale default.
        with open(path, encoding="utf-8") as file:
            return json.load(file)

    def _get_target_suffix(self, mode: str, target_type: str) -> str:
        """Return the file-name suffix of the target file for ``target_type``.

        Any value other than ``instance``, ``semantic`` or ``color`` maps to the
        polygon JSON suffix.
        """
        if target_type == "instance":
            return f"{mode}_instanceIds.png"
        if target_type == "semantic":
            return f"{mode}_labelIds.png"
        if target_type == "color":
            return f"{mode}_color.png"
        return f"{mode}_polygons.json"