Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
262d6177
"git@developer.sourcefind.cn:OpenDAS/dcu_env_check.git" did not exist on "b46b6eb2aa0b3c57ef9248868195662240cdcaa4"
Unverified
Commit
262d6177
authored
Aug 03, 2020
by
Philip Meier
Committed by
GitHub
Aug 03, 2020
Browse files
add typehints for torchvision.datasets.omniglot (#2533)
parent
ec9c7a54
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
10 deletions
+17
-10
torchvision/datasets/omniglot.py
torchvision/datasets/omniglot.py
+17
-10
No files found.
torchvision/datasets/omniglot.py
View file @
262d6177
from
PIL
import
Image
from
os.path
import
join
import
os
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
from
.vision
import
VisionDataset
from
.utils
import
download_and_extract_archive
,
check_integrity
,
list_dir
,
list_files
...
...
@@ -27,8 +28,14 @@ class Omniglot(VisionDataset):
'images_evaluation'
:
'6b91aef0f799c5bb55b94e3f2daec811'
}
def
__init__
(
self
,
root
,
background
=
True
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
def
__init__
(
self
,
root
:
str
,
background
:
bool
=
True
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
Omniglot
,
self
).
__init__
(
join
(
root
,
self
.
folder
),
transform
=
transform
,
target_transform
=
target_transform
)
self
.
background
=
background
...
...
@@ -42,16 +49,16 @@ class Omniglot(VisionDataset):
self
.
target_folder
=
join
(
self
.
root
,
self
.
_get_target_folder
())
self
.
_alphabets
=
list_dir
(
self
.
target_folder
)
self
.
_characters
=
sum
([[
join
(
a
,
c
)
for
c
in
list_dir
(
join
(
self
.
target_folder
,
a
))]
for
a
in
self
.
_alphabets
],
[])
self
.
_characters
:
List
[
str
]
=
sum
([[
join
(
a
,
c
)
for
c
in
list_dir
(
join
(
self
.
target_folder
,
a
))]
for
a
in
self
.
_alphabets
],
[])
self
.
_character_images
=
[[(
image
,
idx
)
for
image
in
list_files
(
join
(
self
.
target_folder
,
character
),
'.png'
)]
for
idx
,
character
in
enumerate
(
self
.
_characters
)]
self
.
_flat_character_images
=
sum
(
self
.
_character_images
,
[])
self
.
_flat_character_images
:
List
[
Tuple
[
str
,
int
]]
=
sum
(
self
.
_character_images
,
[])
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_flat_character_images
)
def
__getitem__
(
self
,
index
)
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
[
Any
,
Any
]
:
"""
Args:
index (int): Index
...
...
@@ -71,13 +78,13 @@ class Omniglot(VisionDataset):
return
image
,
character_class
def
_check_integrity
(
self
):
def
_check_integrity
(
self
)
->
bool
:
zip_filename
=
self
.
_get_target_folder
()
if
not
check_integrity
(
join
(
self
.
root
,
zip_filename
+
'.zip'
),
self
.
zips_md5
[
zip_filename
]):
return
False
return
True
def
download
(
self
):
def
download
(
self
)
->
None
:
if
self
.
_check_integrity
():
print
(
'Files already downloaded and verified'
)
return
...
...
@@ -87,5 +94,5 @@ class Omniglot(VisionDataset):
url
=
self
.
download_url_prefix
+
'/'
+
zip_filename
download_and_extract_archive
(
url
,
self
.
root
,
filename
=
zip_filename
,
md5
=
self
.
zips_md5
[
filename
])
def
_get_target_folder
(
self
):
def
_get_target_folder
(
self
)
->
str
:
return
'images_background'
if
self
.
background
else
'images_evaluation'
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment