Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
47f80acc
Unverified
Commit
47f80acc
authored
Jul 31, 2020
by
Philip Meier
Committed by
GitHub
Jul 31, 2020
Browse files
cityscapes (#2525)
parent
15bd87f2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
8 deletions
+17
-8
torchvision/datasets/cityscapes.py
torchvision/datasets/cityscapes.py
+17
-8
No files found.
torchvision/datasets/cityscapes.py
View file @
47f80acc
...
...
@@ -2,6 +2,7 @@ import json
import
os
from
collections
import
namedtuple
import
zipfile
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
,
Tuple
from
.utils
import
extract_archive
,
verify_str_arg
,
iterable_to_str
from
.vision
import
VisionDataset
...
...
@@ -98,8 +99,16 @@ class Cityscapes(VisionDataset):
CityscapesClass
(
'license plate'
,
-
1
,
-
1
,
'vehicle'
,
7
,
False
,
True
,
(
0
,
0
,
142
)),
]
def
__init__
(
self
,
root
,
split
=
'train'
,
mode
=
'fine'
,
target_type
=
'instance'
,
transform
=
None
,
target_transform
=
None
,
transforms
=
None
):
def
__init__
(
self
,
root
:
str
,
split
:
str
=
"train"
,
mode
:
str
=
"fine"
,
target_type
:
Union
[
List
[
str
],
str
]
=
"instance"
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
transforms
:
Optional
[
Callable
]
=
None
,
)
->
None
:
super
(
Cityscapes
,
self
).
__init__
(
root
,
transforms
,
transform
,
target_transform
)
self
.
mode
=
'gtFine'
if
mode
==
'fine'
else
'gtCoarse'
self
.
images_dir
=
os
.
path
.
join
(
self
.
root
,
'leftImg8bit'
,
split
)
...
...
@@ -157,7 +166,7 @@ class Cityscapes(VisionDataset):
self
.
images
.
append
(
os
.
path
.
join
(
img_dir
,
file_name
))
self
.
targets
.
append
(
target_types
)
def
__getitem__
(
self
,
index
)
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
[
Any
,
Any
]
:
"""
Args:
index (int): Index
...
...
@@ -168,7 +177,7 @@ class Cityscapes(VisionDataset):
image
=
Image
.
open
(
self
.
images
[
index
]).
convert
(
'RGB'
)
targets
=
[]
targets
:
Any
=
[]
for
i
,
t
in
enumerate
(
self
.
target_type
):
if
t
==
'polygon'
:
target
=
self
.
_load_json
(
self
.
targets
[
index
][
i
])
...
...
@@ -184,19 +193,19 @@ class Cityscapes(VisionDataset):
return
image
,
target
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
images
)
def
extra_repr
(
self
):
def
extra_repr
(
self
)
->
str
:
lines
=
[
"Split: {split}"
,
"Mode: {mode}"
,
"Type: {target_type}"
]
return
'
\n
'
.
join
(
lines
).
format
(
**
self
.
__dict__
)
def
_load_json
(
self
,
path
)
:
def
_load_json
(
self
,
path
:
str
)
->
Dict
[
str
,
Any
]
:
with
open
(
path
,
'r'
)
as
file
:
data
=
json
.
load
(
file
)
return
data
def
_get_target_suffix
(
self
,
mode
,
target_type
)
:
def
_get_target_suffix
(
self
,
mode
:
str
,
target_type
:
str
)
->
str
:
if
target_type
==
'instance'
:
return
'{}_instanceIds.png'
.
format
(
mode
)
elif
target_type
==
'semantic'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment