Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
885e3c20
"vscode:/vscode.git/clone" did not exist on "1bdda8cb6e111903aa29d75dc2f33498f5df533a"
Commit
885e3c20
authored
Dec 25, 2018
by
Michael Kösel
Committed by
Francisco Massa
Dec 25, 2018
Browse files
Support for returning multiple targets (#700)
parent
8ce00704
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
14 deletions
+53
-14
torchvision/datasets/cityscapes.py
torchvision/datasets/cityscapes.py
+53
-14
No files found.
torchvision/datasets/cityscapes.py
View file @
885e3c20
...
@@ -7,18 +7,45 @@ from PIL import Image
...
@@ -7,18 +7,45 @@ from PIL import Image
class
Cityscapes
(
data
.
Dataset
):
class
Cityscapes
(
data
.
Dataset
):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args:
Args:
root (string): Root directory of dataset where directory ``leftImg8bit``
root (string): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located.
and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
otherwise ``train``, ``train_extra`` or ``val``
otherwise ``train``, ``train_extra`` or ``val``
mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
target_type (string, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
target_type (string
or list
, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
or ``color``
or ``color``
. Can also be a list to output a tuple with all specified target types.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
target and transforms it.
Examples:
Get semantic segmentation target
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
target_type='semantic')
img, smnt = dataset[0]
Get multiple targets
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
target_type=['instance', 'color', 'polygon'])
img, (inst, col, poly) = dataset[0]
Validate on the "gtCoarse" set
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='val', mode='gtCoarse',
target_type='semantic')
img, smnt = dataset[0]
"""
"""
def
__init__
(
self
,
root
,
split
=
'train'
,
mode
=
'gtFine'
,
target_type
=
'instance'
,
def
__init__
(
self
,
root
,
split
=
'train'
,
mode
=
'gtFine'
,
target_type
=
'instance'
,
...
@@ -44,9 +71,12 @@ class Cityscapes(data.Dataset):
...
@@ -44,9 +71,12 @@ class Cityscapes(data.Dataset):
raise
ValueError
(
'Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"'
raise
ValueError
(
'Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"'
' or split="val"'
)
' or split="val"'
)
if
target_type
not
in
[
'instance'
,
'semantic'
,
'polygon'
,
'color'
]:
if
not
isinstance
(
target_type
,
list
):
raise
ValueError
(
'Invalid value for "target_type"! Please use target_type="instance",'
self
.
target_type
=
[
target_type
]
' target_type="semantic", target_type="polygon" or target_type="color"'
)
if
not
all
(
t
in
[
'instance'
,
'semantic'
,
'polygon'
,
'color'
]
for
t
in
self
.
target_type
):
raise
ValueError
(
'Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
' or "color"'
)
if
not
os
.
path
.
isdir
(
self
.
images_dir
)
or
not
os
.
path
.
isdir
(
self
.
targets_dir
):
if
not
os
.
path
.
isdir
(
self
.
images_dir
)
or
not
os
.
path
.
isdir
(
self
.
targets_dir
):
raise
RuntimeError
(
'Dataset not found or incomplete. Please make sure all required folders for the'
raise
RuntimeError
(
'Dataset not found or incomplete. Please make sure all required folders for the'
...
@@ -56,27 +86,36 @@ class Cityscapes(data.Dataset):
...
@@ -56,27 +86,36 @@ class Cityscapes(data.Dataset):
img_dir
=
os
.
path
.
join
(
self
.
images_dir
,
city
)
img_dir
=
os
.
path
.
join
(
self
.
images_dir
,
city
)
target_dir
=
os
.
path
.
join
(
self
.
targets_dir
,
city
)
target_dir
=
os
.
path
.
join
(
self
.
targets_dir
,
city
)
for
file_name
in
os
.
listdir
(
img_dir
):
for
file_name
in
os
.
listdir
(
img_dir
):
target_types
=
[]
for
t
in
self
.
target_type
:
target_name
=
'{}_{}'
.
format
(
file_name
.
split
(
'_leftImg8bit'
)[
0
],
target_name
=
'{}_{}'
.
format
(
file_name
.
split
(
'_leftImg8bit'
)[
0
],
self
.
_get_target_suffix
(
self
.
mode
,
self
.
target_type
))
self
.
_get_target_suffix
(
self
.
mode
,
t
))
target_types
.
append
(
os
.
path
.
join
(
target_dir
,
target_name
))
self
.
images
.
append
(
os
.
path
.
join
(
img_dir
,
file_name
))
self
.
images
.
append
(
os
.
path
.
join
(
img_dir
,
file_name
))
self
.
targets
.
append
(
os
.
path
.
join
(
target_dir
,
target_name
)
)
self
.
targets
.
append
(
target_types
)
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
"""
"""
Args:
Args:
index (int): Index
index (int): Index
Returns:
Returns:
tuple: (image, target) where target is a
json object if target_type="polygon",
tuple: (image, target) where target is a
tuple of all target types if target_type is a list with more
otherwi
se the image segmentation.
than one item. Otherwise target is a json object if target_type="polygon", el
se the image segmentation.
"""
"""
image
=
Image
.
open
(
self
.
images
[
index
]).
convert
(
'RGB'
)
image
=
Image
.
open
(
self
.
images
[
index
]).
convert
(
'RGB'
)
if
self
.
target_type
==
'polygon'
:
targets
=
[]
target
=
self
.
_load_json
(
self
.
targets
[
index
])
for
i
,
t
in
enumerate
(
self
.
target_type
):
if
t
==
'polygon'
:
target
=
self
.
_load_json
(
self
.
targets
[
index
][
i
])
else
:
else
:
target
=
Image
.
open
(
self
.
targets
[
index
])
target
=
Image
.
open
(
self
.
targets
[
index
][
i
])
targets
.
append
(
target
)
target
=
tuple
(
targets
)
if
len
(
targets
)
>
1
else
targets
[
0
]
if
self
.
transform
:
if
self
.
transform
:
image
=
self
.
transform
(
image
)
image
=
self
.
transform
(
image
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment