Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
1a6148d4
"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "a637724052d3e74965af61e79d005433848cfea2"
Unverified
Commit
1a6148d4
authored
Aug 03, 2020
by
Philip Meier
Committed by
GitHub
Aug 03, 2020
Browse files
add typehints for torchvision.datasets.svhn (#2539)
parent
7c1ed419
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
7 deletions
+14
-7
torchvision/datasets/svhn.py
torchvision/datasets/svhn.py
+14
-7
No files found.
torchvision/datasets/svhn.py
View file @
1a6148d4
...
@@ -3,6 +3,7 @@ from PIL import Image
...
@@ -3,6 +3,7 @@ from PIL import Image
import
os
import
os
import
os.path
import
os.path
import
numpy
as
np
import
numpy
as
np
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
from
.utils
import
download_url
,
check_integrity
,
verify_str_arg
from
.utils
import
download_url
,
check_integrity
,
verify_str_arg
...
@@ -39,8 +40,14 @@ class SVHN(VisionDataset):
...
@@ -39,8 +40,14 @@ class SVHN(VisionDataset):
'extra'
:
[
"http://ufldl.stanford.edu/housenumbers/extra_32x32.mat"
,
'extra'
:
[
"http://ufldl.stanford.edu/housenumbers/extra_32x32.mat"
,
"extra_32x32.mat"
,
"a93ce644f1a588dc4d68dda5feec44a7"
]}
"extra_32x32.mat"
,
"a93ce644f1a588dc4d68dda5feec44a7"
]}
def
__init__
(
self
,
root
,
split
=
'train'
,
transform
=
None
,
target_transform
=
None
,
def
__init__
(
download
=
False
):
self
,
root
:
str
,
split
:
str
=
"train"
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
SVHN
,
self
).
__init__
(
root
,
transform
=
transform
,
super
(
SVHN
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
target_transform
=
target_transform
)
self
.
split
=
verify_str_arg
(
split
,
"split"
,
tuple
(
self
.
split_list
.
keys
()))
self
.
split
=
verify_str_arg
(
split
,
"split"
,
tuple
(
self
.
split_list
.
keys
()))
...
@@ -75,7 +82,7 @@ class SVHN(VisionDataset):
...
@@ -75,7 +82,7 @@ class SVHN(VisionDataset):
np
.
place
(
self
.
labels
,
self
.
labels
==
10
,
0
)
np
.
place
(
self
.
labels
,
self
.
labels
==
10
,
0
)
self
.
data
=
np
.
transpose
(
self
.
data
,
(
3
,
2
,
0
,
1
))
self
.
data
=
np
.
transpose
(
self
.
data
,
(
3
,
2
,
0
,
1
))
def
__getitem__
(
self
,
index
)
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
[
Any
,
Any
]
:
"""
"""
Args:
Args:
index (int): Index
index (int): Index
...
@@ -97,18 +104,18 @@ class SVHN(VisionDataset):
...
@@ -97,18 +104,18 @@ class SVHN(VisionDataset):
return
img
,
target
return
img
,
target
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
data
)
return
len
(
self
.
data
)
def
_check_integrity
(
self
):
def
_check_integrity
(
self
)
->
bool
:
root
=
self
.
root
root
=
self
.
root
md5
=
self
.
split_list
[
self
.
split
][
2
]
md5
=
self
.
split_list
[
self
.
split
][
2
]
fpath
=
os
.
path
.
join
(
root
,
self
.
filename
)
fpath
=
os
.
path
.
join
(
root
,
self
.
filename
)
return
check_integrity
(
fpath
,
md5
)
return
check_integrity
(
fpath
,
md5
)
def
download
(
self
):
def
download
(
self
)
->
None
:
md5
=
self
.
split_list
[
self
.
split
][
2
]
md5
=
self
.
split_list
[
self
.
split
][
2
]
download_url
(
self
.
url
,
self
.
root
,
self
.
filename
,
md5
)
download_url
(
self
.
url
,
self
.
root
,
self
.
filename
,
md5
)
def
extra_repr
(
self
):
def
extra_repr
(
self
)
->
str
:
return
"Split: {split}"
.
format
(
**
self
.
__dict__
)
return
"Split: {split}"
.
format
(
**
self
.
__dict__
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment