Commit 885e3c20 authored by Michael Kösel, committed by Francisco Massa
Browse files

Support for returning multiple targets (#700)

parent 8ce00704
......@@ -7,18 +7,45 @@ from PIL import Image
class Cityscapes(data.Dataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
otherwise ``train``, ``train_extra`` or ``val``
mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
target_type (string, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
or ``color``
target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
or ``color``. Can also be a list to output a tuple with all specified target types.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
Examples:
Get semantic segmentation target
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
target_type='semantic')
img, smnt = dataset[0]
Get multiple targets
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
target_type=['instance', 'color', 'polygon'])
img, (inst, col, poly) = dataset[0]
Validate on the "gtCoarse" set
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='val', mode='gtCoarse',
target_type='semantic')
img, smnt = dataset[0]
"""
def __init__(self, root, split='train', mode='gtFine', target_type='instance',
......@@ -44,9 +71,12 @@ class Cityscapes(data.Dataset):
raise ValueError('Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"'
' or split="val"')
if target_type not in ['instance', 'semantic', 'polygon', 'color']:
raise ValueError('Invalid value for "target_type"! Please use target_type="instance",'
' target_type="semantic", target_type="polygon" or target_type="color"')
if not isinstance(target_type, list):
self.target_type = [target_type]
if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type):
raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
' or "color"')
if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
......@@ -56,27 +86,36 @@ class Cityscapes(data.Dataset):
img_dir = os.path.join(self.images_dir, city)
target_dir = os.path.join(self.targets_dir, city)
for file_name in os.listdir(img_dir):
target_types = []
for t in self.target_type:
target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
self._get_target_suffix(self.mode, self.target_type))
self._get_target_suffix(self.mode, t))
target_types.append(os.path.join(target_dir, target_name))
self.images.append(os.path.join(img_dir, file_name))
self.targets.append(os.path.join(target_dir, target_name))
self.targets.append(target_types)
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a json object if target_type="polygon",
otherwise the image segmentation.
tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
"""
image = Image.open(self.images[index]).convert('RGB')
if self.target_type == 'polygon':
target = self._load_json(self.targets[index])
targets = []
for i, t in enumerate(self.target_type):
if t == 'polygon':
target = self._load_json(self.targets[index][i])
else:
target = Image.open(self.targets[index])
target = Image.open(self.targets[index][i])
targets.append(target)
target = tuple(targets) if len(targets) > 1 else targets[0]
if self.transform:
image = self.transform(image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment