Unverified Commit 47f80acc authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

cityscapes (#2525)

parent 15bd87f2
......@@ -2,6 +2,7 @@ import json
import os
from collections import namedtuple
import zipfile
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from .utils import extract_archive, verify_str_arg, iterable_to_str
from .vision import VisionDataset
......@@ -98,8 +99,16 @@ class Cityscapes(VisionDataset):
CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
]
def __init__(self, root, split='train', mode='fine', target_type='instance',
transform=None, target_transform=None, transforms=None):
def __init__(
self,
root: str,
split: str = "train",
mode: str = "fine",
target_type: Union[List[str], str] = "instance",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
......@@ -157,7 +166,7 @@ class Cityscapes(VisionDataset):
self.images.append(os.path.join(img_dir, file_name))
self.targets.append(target_types)
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -168,7 +177,7 @@ class Cityscapes(VisionDataset):
image = Image.open(self.images[index]).convert('RGB')
targets = []
targets: Any = []
for i, t in enumerate(self.target_type):
if t == 'polygon':
target = self._load_json(self.targets[index][i])
......@@ -184,19 +193,19 @@ class Cityscapes(VisionDataset):
return image, target
def __len__(self):
def __len__(self) -> int:
return len(self.images)
def extra_repr(self):
def extra_repr(self) -> str:
lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
return '\n'.join(lines).format(**self.__dict__)
def _load_json(self, path):
def _load_json(self, path: str) -> Dict[str, Any]:
with open(path, 'r') as file:
data = json.load(file)
return data
def _get_target_suffix(self, mode, target_type):
def _get_target_suffix(self, mode: str, target_type: str) -> str:
if target_type == 'instance':
return '{}_instanceIds.png'.format(mode)
elif target_type == 'semantic':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment