Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
885e3c20
Commit
885e3c20
authored
Dec 25, 2018
by
Michael Kösel
Committed by
Francisco Massa
Dec 25, 2018
Browse files
Support for returning multiple targets (#700)
parent
8ce00704
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
14 deletions
+53
-14
torchvision/datasets/cityscapes.py
torchvision/datasets/cityscapes.py
+53
-14
No files found.
torchvision/datasets/cityscapes.py
View file @
885e3c20
...
@@ -7,18 +7,45 @@ from PIL import Image
...
@@ -7,18 +7,45 @@ from PIL import Image
class
Cityscapes
(
data
.
Dataset
):
class
Cityscapes
(
data
.
Dataset
):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args:
Args:
root (string): Root directory of dataset where directory ``leftImg8bit``
root (string): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located.
and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
otherwise ``train``, ``train_extra`` or ``val``
otherwise ``train``, ``train_extra`` or ``val``
mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
target_type (string, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
target_type (string
or list
, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
or ``color``
or ``color``
. Can also be a list to output a tuple with all specified target types.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
target and transforms it.
Examples:
Get semantic segmentation target
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
target_type='semantic')
img, smnt = dataset[0]
Get multiple targets
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
target_type=['instance', 'color', 'polygon'])
img, (inst, col, poly) = dataset[0]
Validate on the "gtCoarse" set
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='val', mode='gtCoarse',
target_type='semantic')
img, smnt = dataset[0]
"""
"""
def
__init__
(
self
,
root
,
split
=
'train'
,
mode
=
'gtFine'
,
target_type
=
'instance'
,
def
__init__
(
self
,
root
,
split
=
'train'
,
mode
=
'gtFine'
,
target_type
=
'instance'
,
...
@@ -44,9 +71,12 @@ class Cityscapes(data.Dataset):
...
@@ -44,9 +71,12 @@ class Cityscapes(data.Dataset):
raise
ValueError
(
'Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"'
raise
ValueError
(
'Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"'
' or split="val"'
)
' or split="val"'
)
if
target_type
not
in
[
'instance'
,
'semantic'
,
'polygon'
,
'color'
]:
if
not
isinstance
(
target_type
,
list
):
raise
ValueError
(
'Invalid value for "target_type"! Please use target_type="instance",'
self
.
target_type
=
[
target_type
]
' target_type="semantic", target_type="polygon" or target_type="color"'
)
if
not
all
(
t
in
[
'instance'
,
'semantic'
,
'polygon'
,
'color'
]
for
t
in
self
.
target_type
):
raise
ValueError
(
'Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
' or "color"'
)
if
not
os
.
path
.
isdir
(
self
.
images_dir
)
or
not
os
.
path
.
isdir
(
self
.
targets_dir
):
if
not
os
.
path
.
isdir
(
self
.
images_dir
)
or
not
os
.
path
.
isdir
(
self
.
targets_dir
):
raise
RuntimeError
(
'Dataset not found or incomplete. Please make sure all required folders for the'
raise
RuntimeError
(
'Dataset not found or incomplete. Please make sure all required folders for the'
...
@@ -56,27 +86,36 @@ class Cityscapes(data.Dataset):
...
@@ -56,27 +86,36 @@ class Cityscapes(data.Dataset):
img_dir
=
os
.
path
.
join
(
self
.
images_dir
,
city
)
img_dir
=
os
.
path
.
join
(
self
.
images_dir
,
city
)
target_dir
=
os
.
path
.
join
(
self
.
targets_dir
,
city
)
target_dir
=
os
.
path
.
join
(
self
.
targets_dir
,
city
)
for
file_name
in
os
.
listdir
(
img_dir
):
for
file_name
in
os
.
listdir
(
img_dir
):
target_name
=
'{}_{}'
.
format
(
file_name
.
split
(
'_leftImg8bit'
)[
0
],
target_types
=
[]
self
.
_get_target_suffix
(
self
.
mode
,
self
.
target_type
))
for
t
in
self
.
target_type
:
target_name
=
'{}_{}'
.
format
(
file_name
.
split
(
'_leftImg8bit'
)[
0
],
self
.
_get_target_suffix
(
self
.
mode
,
t
))
target_types
.
append
(
os
.
path
.
join
(
target_dir
,
target_name
))
self
.
images
.
append
(
os
.
path
.
join
(
img_dir
,
file_name
))
self
.
images
.
append
(
os
.
path
.
join
(
img_dir
,
file_name
))
self
.
targets
.
append
(
os
.
path
.
join
(
target_dir
,
target_name
)
)
self
.
targets
.
append
(
target_types
)
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
"""
"""
Args:
Args:
index (int): Index
index (int): Index
Returns:
Returns:
tuple: (image, target) where target is a
json object if target_type="polygon",
tuple: (image, target) where target is a
tuple of all target types if target_type is a list with more
otherwi
se the image segmentation.
than one item. Otherwise target is a json object if target_type="polygon", el
se the image segmentation.
"""
"""
image
=
Image
.
open
(
self
.
images
[
index
]).
convert
(
'RGB'
)
image
=
Image
.
open
(
self
.
images
[
index
]).
convert
(
'RGB'
)
if
self
.
target_type
==
'polygon'
:
targets
=
[]
target
=
self
.
_load_json
(
self
.
targets
[
index
])
for
i
,
t
in
enumerate
(
self
.
target_type
):
else
:
if
t
==
'polygon'
:
target
=
Image
.
open
(
self
.
targets
[
index
])
target
=
self
.
_load_json
(
self
.
targets
[
index
][
i
])
else
:
target
=
Image
.
open
(
self
.
targets
[
index
][
i
])
targets
.
append
(
target
)
target
=
tuple
(
targets
)
if
len
(
targets
)
>
1
else
targets
[
0
]
if
self
.
transform
:
if
self
.
transform
:
image
=
self
.
transform
(
image
)
image
=
self
.
transform
(
image
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment