Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
885e3c20
Commit
885e3c20
authored
Dec 25, 2018
by
Michael Kösel
Committed by
Francisco Massa
Dec 25, 2018
Browse files
Support for returning multiple targets (#700)
parent
8ce00704
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
14 deletions
+53
-14
torchvision/datasets/cityscapes.py
torchvision/datasets/cityscapes.py
+53
-14
No files found.
torchvision/datasets/cityscapes.py
View file @
885e3c20
...
...
@@ -7,18 +7,45 @@ from PIL import Image
class
Cityscapes
(
data
.
Dataset
):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
otherwise ``train``, ``train_extra`` or ``val``
mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
target_type (string, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
or ``color``
target_type (string
or list
, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
or ``color``
. Can also be a list to output a tuple with all specified target types.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
Examples:
Get semantic segmentation target
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
target_type='semantic')
img, smnt = dataset[0]
Get multiple targets
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
target_type=['instance', 'color', 'polygon'])
img, (inst, col, poly) = dataset[0]
Validate on the "gtCoarse" set
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='val', mode='gtCoarse',
target_type='semantic')
img, smnt = dataset[0]
"""
def
__init__
(
self
,
root
,
split
=
'train'
,
mode
=
'gtFine'
,
target_type
=
'instance'
,
...
...
@@ -44,9 +71,12 @@ class Cityscapes(data.Dataset):
raise
ValueError
(
'Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"'
' or split="val"'
)
if
target_type
not
in
[
'instance'
,
'semantic'
,
'polygon'
,
'color'
]:
raise
ValueError
(
'Invalid value for "target_type"! Please use target_type="instance",'
' target_type="semantic", target_type="polygon" or target_type="color"'
)
if
not
isinstance
(
target_type
,
list
):
self
.
target_type
=
[
target_type
]
if
not
all
(
t
in
[
'instance'
,
'semantic'
,
'polygon'
,
'color'
]
for
t
in
self
.
target_type
):
raise
ValueError
(
'Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
' or "color"'
)
if
not
os
.
path
.
isdir
(
self
.
images_dir
)
or
not
os
.
path
.
isdir
(
self
.
targets_dir
):
raise
RuntimeError
(
'Dataset not found or incomplete. Please make sure all required folders for the'
...
...
@@ -56,27 +86,36 @@ class Cityscapes(data.Dataset):
img_dir
=
os
.
path
.
join
(
self
.
images_dir
,
city
)
target_dir
=
os
.
path
.
join
(
self
.
targets_dir
,
city
)
for
file_name
in
os
.
listdir
(
img_dir
):
target_types
=
[]
for
t
in
self
.
target_type
:
target_name
=
'{}_{}'
.
format
(
file_name
.
split
(
'_leftImg8bit'
)[
0
],
self
.
_get_target_suffix
(
self
.
mode
,
self
.
target_type
))
self
.
_get_target_suffix
(
self
.
mode
,
t
))
target_types
.
append
(
os
.
path
.
join
(
target_dir
,
target_name
))
self
.
images
.
append
(
os
.
path
.
join
(
img_dir
,
file_name
))
self
.
targets
.
append
(
os
.
path
.
join
(
target_dir
,
target_name
)
)
self
.
targets
.
append
(
target_types
)
def
__getitem__
(
self
,
index
):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a
json object if target_type="polygon",
otherwi
se the image segmentation.
tuple: (image, target) where target is a
tuple of all target types if target_type is a list with more
than one item. Otherwise target is a json object if target_type="polygon", el
se the image segmentation.
"""
image
=
Image
.
open
(
self
.
images
[
index
]).
convert
(
'RGB'
)
if
self
.
target_type
==
'polygon'
:
target
=
self
.
_load_json
(
self
.
targets
[
index
])
targets
=
[]
for
i
,
t
in
enumerate
(
self
.
target_type
):
if
t
==
'polygon'
:
target
=
self
.
_load_json
(
self
.
targets
[
index
][
i
])
else
:
target
=
Image
.
open
(
self
.
targets
[
index
])
target
=
Image
.
open
(
self
.
targets
[
index
][
i
])
targets
.
append
(
target
)
target
=
tuple
(
targets
)
if
len
(
targets
)
>
1
else
targets
[
0
]
if
self
.
transform
:
image
=
self
.
transform
(
image
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment