Commit 885e3c20 authored by Michael Kösel, committed by Francisco Massa
Browse files

Support for returning multiple targets (#700)

parent 8ce00704
...@@ -7,18 +7,45 @@ from PIL import Image ...@@ -7,18 +7,45 @@ from PIL import Image
class Cityscapes(data.Dataset): class Cityscapes(data.Dataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset. """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where directory ``leftImg8bit`` root (string): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located. and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine" split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
otherwise ``train``, ``train_extra`` or ``val`` otherwise ``train``, ``train_extra`` or ``val``
mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse`` mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
target_type (string, optional): Type of target to use, ``instance``, ``semantic``, ``polygon`` target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
or ``color`` or ``color``. Can also be a list to output a tuple with all specified target types.
transform (callable, optional): A function/transform that takes in a PIL image transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop`` and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
Examples:
Get semantic segmentation target
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
target_type='semantic')
img, smnt = dataset[0]
Get multiple targets
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
target_type=['instance', 'color', 'polygon'])
img, (inst, col, poly) = dataset[0]
Validate on the "gtCoarse" set
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='val', mode='gtCoarse',
target_type='semantic')
img, smnt = dataset[0]
""" """
def __init__(self, root, split='train', mode='gtFine', target_type='instance', def __init__(self, root, split='train', mode='gtFine', target_type='instance',
...@@ -44,9 +71,12 @@ class Cityscapes(data.Dataset): ...@@ -44,9 +71,12 @@ class Cityscapes(data.Dataset):
raise ValueError('Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"' raise ValueError('Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"'
' or split="val"') ' or split="val"')
if target_type not in ['instance', 'semantic', 'polygon', 'color']: if not isinstance(target_type, list):
raise ValueError('Invalid value for "target_type"! Please use target_type="instance",' self.target_type = [target_type]
' target_type="semantic", target_type="polygon" or target_type="color"')
if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type):
raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
' or "color"')
if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
...@@ -56,27 +86,36 @@ class Cityscapes(data.Dataset): ...@@ -56,27 +86,36 @@ class Cityscapes(data.Dataset):
img_dir = os.path.join(self.images_dir, city) img_dir = os.path.join(self.images_dir, city)
target_dir = os.path.join(self.targets_dir, city) target_dir = os.path.join(self.targets_dir, city)
for file_name in os.listdir(img_dir): for file_name in os.listdir(img_dir):
target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], target_types = []
self._get_target_suffix(self.mode, self.target_type)) for t in self.target_type:
target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
self._get_target_suffix(self.mode, t))
target_types.append(os.path.join(target_dir, target_name))
self.images.append(os.path.join(img_dir, file_name)) self.images.append(os.path.join(img_dir, file_name))
self.targets.append(os.path.join(target_dir, target_name)) self.targets.append(target_types)
def __getitem__(self, index): def __getitem__(self, index):
""" """
Args: Args:
index (int): Index index (int): Index
Returns: Returns:
tuple: (image, target) where target is a json object if target_type="polygon", tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
otherwise the image segmentation. than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
""" """
image = Image.open(self.images[index]).convert('RGB') image = Image.open(self.images[index]).convert('RGB')
if self.target_type == 'polygon': targets = []
target = self._load_json(self.targets[index]) for i, t in enumerate(self.target_type):
else: if t == 'polygon':
target = Image.open(self.targets[index]) target = self._load_json(self.targets[index][i])
else:
target = Image.open(self.targets[index][i])
targets.append(target)
target = tuple(targets) if len(targets) > 1 else targets[0]
if self.transform: if self.transform:
image = self.transform(image) image = self.transform(image)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment