"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4bae76e4539c30f68fa4e39c4e492a2155cf81d0"
Commit 4ffc28c9 authored by Michael Kösel's avatar Michael Kösel Committed by Francisco Massa
Browse files

Small API cleanup to Cityscapes (#725)

parent ef5b3dad
...@@ -26,7 +26,7 @@ class Cityscapes(data.Dataset): ...@@ -26,7 +26,7 @@ class Cityscapes(data.Dataset):
Get semantic segmentation target Get semantic segmentation target
.. code-block:: python .. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine', dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
target_type='semantic') target_type='semantic')
img, smnt = dataset[0] img, smnt = dataset[0]
...@@ -34,41 +34,41 @@ class Cityscapes(data.Dataset): ...@@ -34,41 +34,41 @@ class Cityscapes(data.Dataset):
Get multiple targets Get multiple targets
.. code-block:: python .. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine', dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
target_type=['instance', 'color', 'polygon']) target_type=['instance', 'color', 'polygon'])
img, (inst, col, poly) = dataset[0] img, (inst, col, poly) = dataset[0]
Validate on the "gtCoarse" set Validate on the "coarse" set
.. code-block:: python .. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='val', mode='gtCoarse', dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
target_type='semantic') target_type='semantic')
img, smnt = dataset[0] img, smnt = dataset[0]
""" """
def __init__(self, root, split='train', mode='gtFine', target_type='instance', def __init__(self, root, split='train', mode='fine', target_type='instance',
transform=None, target_transform=None): transform=None, target_transform=None):
self.root = os.path.expanduser(root) self.root = os.path.expanduser(root)
self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
self.images_dir = os.path.join(self.root, 'leftImg8bit', split) self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
self.targets_dir = os.path.join(self.root, mode, split) self.targets_dir = os.path.join(self.root, self.mode, split)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.target_type = target_type self.target_type = target_type
self.split = split self.split = split
self.mode = mode
self.images = [] self.images = []
self.targets = [] self.targets = []
if mode not in ['gtFine', 'gtCoarse']: if mode not in ['fine', 'coarse']:
raise ValueError('Invalid mode! Please use mode="gtFine" or mode="gtCoarse"') raise ValueError('Invalid mode! Please use mode="fine" or mode="coarse"')
if mode == 'gtFine' and split not in ['train', 'test', 'val']: if mode == 'fine' and split not in ['train', 'test', 'val']:
raise ValueError('Invalid split for mode "gtFine"! Please use split="train", split="test"' raise ValueError('Invalid split for mode "fine"! Please use split="train", split="test"'
' or split="val"') ' or split="val"')
elif mode == 'gtCoarse' and split not in ['train', 'train_extra', 'val']: elif mode == 'coarse' and split not in ['train', 'train_extra', 'val']:
raise ValueError('Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"' raise ValueError('Invalid split for mode "coarse"! Please use split="train", split="train_extra"'
' or split="val"') ' or split="val"')
if not isinstance(target_type, list): if not isinstance(target_type, list):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment