Unverified Commit 47f80acc authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

cityscapes (#2525)

parent 15bd87f2
...@@ -2,6 +2,7 @@ import json ...@@ -2,6 +2,7 @@ import json
import os import os
from collections import namedtuple from collections import namedtuple
import zipfile import zipfile
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from .utils import extract_archive, verify_str_arg, iterable_to_str from .utils import extract_archive, verify_str_arg, iterable_to_str
from .vision import VisionDataset from .vision import VisionDataset
...@@ -98,8 +99,16 @@ class Cityscapes(VisionDataset): ...@@ -98,8 +99,16 @@ class Cityscapes(VisionDataset):
CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
] ]
def __init__(self, root, split='train', mode='fine', target_type='instance', def __init__(
transform=None, target_transform=None, transforms=None): self,
root: str,
split: str = "train",
mode: str = "fine",
target_type: Union[List[str], str] = "instance",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super(Cityscapes, self).__init__(root, transforms, transform, target_transform) super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
self.images_dir = os.path.join(self.root, 'leftImg8bit', split) self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
...@@ -157,7 +166,7 @@ class Cityscapes(VisionDataset): ...@@ -157,7 +166,7 @@ class Cityscapes(VisionDataset):
self.images.append(os.path.join(img_dir, file_name)) self.images.append(os.path.join(img_dir, file_name))
self.targets.append(target_types) self.targets.append(target_types)
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -168,7 +177,7 @@ class Cityscapes(VisionDataset): ...@@ -168,7 +177,7 @@ class Cityscapes(VisionDataset):
image = Image.open(self.images[index]).convert('RGB') image = Image.open(self.images[index]).convert('RGB')
targets = [] targets: Any = []
for i, t in enumerate(self.target_type): for i, t in enumerate(self.target_type):
if t == 'polygon': if t == 'polygon':
target = self._load_json(self.targets[index][i]) target = self._load_json(self.targets[index][i])
...@@ -184,19 +193,19 @@ class Cityscapes(VisionDataset): ...@@ -184,19 +193,19 @@ class Cityscapes(VisionDataset):
return image, target return image, target
def __len__(self): def __len__(self) -> int:
return len(self.images) return len(self.images)
def extra_repr(self): def extra_repr(self) -> str:
lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"] lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
return '\n'.join(lines).format(**self.__dict__) return '\n'.join(lines).format(**self.__dict__)
def _load_json(self, path): def _load_json(self, path: str) -> Dict[str, Any]:
with open(path, 'r') as file: with open(path, 'r') as file:
data = json.load(file) data = json.load(file)
return data return data
def _get_target_suffix(self, mode, target_type): def _get_target_suffix(self, mode: str, target_type: str) -> str:
if target_type == 'instance': if target_type == 'instance':
return '{}_instanceIds.png'.format(mode) return '{}_instanceIds.png'.format(mode)
elif target_type == 'semantic': elif target_type == 'semantic':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment