Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
0acbf663
"docs/vscode:/vscode.git/clone" did not exist on "4db08045baa250148b1e176e9ac1d5797affcd75"
Unverified
Commit
0acbf663
authored
Aug 03, 2020
by
Philip Meier
Committed by
GitHub
Aug 03, 2020
Browse files
add typehints for torchvision.datasets.usps (#2538)
parent
3f70e3c4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
5 deletions
+12
-5
torchvision/datasets/usps.py
torchvision/datasets/usps.py
+12
-5
No files found.
torchvision/datasets/usps.py
View file @
0acbf663
from
PIL
import
Image
from
PIL
import
Image
import
os
import
os
import
numpy
as
np
import
numpy
as
np
from
typing
import
Any
,
Callable
,
cast
,
Optional
,
Tuple
from
.utils
import
download_url
from
.utils
import
download_url
from
.vision
import
VisionDataset
from
.vision
import
VisionDataset
...
@@ -36,8 +37,14 @@ class USPS(VisionDataset):
...
@@ -36,8 +37,14 @@ class USPS(VisionDataset):
],
],
}
}
def
__init__
(
self
,
root
,
train
=
True
,
transform
=
None
,
target_transform
=
None
,
def
__init__
(
download
=
False
):
self
,
root
:
str
,
train
:
bool
=
True
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
USPS
,
self
).
__init__
(
root
,
transform
=
transform
,
super
(
USPS
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
target_transform
=
target_transform
)
split
=
'train'
if
train
else
'test'
split
=
'train'
if
train
else
'test'
...
@@ -52,13 +59,13 @@ class USPS(VisionDataset):
...
@@ -52,13 +59,13 @@ class USPS(VisionDataset):
raw_data
=
[
line
.
decode
().
split
()
for
line
in
fp
.
readlines
()]
raw_data
=
[
line
.
decode
().
split
()
for
line
in
fp
.
readlines
()]
imgs
=
[[
x
.
split
(
':'
)[
-
1
]
for
x
in
data
[
1
:]]
for
data
in
raw_data
]
imgs
=
[[
x
.
split
(
':'
)[
-
1
]
for
x
in
data
[
1
:]]
for
data
in
raw_data
]
imgs
=
np
.
asarray
(
imgs
,
dtype
=
np
.
float32
).
reshape
((
-
1
,
16
,
16
))
imgs
=
np
.
asarray
(
imgs
,
dtype
=
np
.
float32
).
reshape
((
-
1
,
16
,
16
))
imgs
=
((
imgs
+
1
)
/
2
*
255
).
astype
(
dtype
=
np
.
uint8
)
imgs
=
((
cast
(
np
.
ndarray
,
imgs
)
+
1
)
/
2
*
255
).
astype
(
dtype
=
np
.
uint8
)
targets
=
[
int
(
d
[
0
])
-
1
for
d
in
raw_data
]
targets
=
[
int
(
d
[
0
])
-
1
for
d
in
raw_data
]
self
.
data
=
imgs
self
.
data
=
imgs
self
.
targets
=
targets
self
.
targets
=
targets
def
__getitem__
(
self
,
index
)
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
[
Any
,
Any
]
:
"""
"""
Args:
Args:
index (int): Index
index (int): Index
...
@@ -80,5 +87,5 @@ class USPS(VisionDataset):
...
@@ -80,5 +87,5 @@ class USPS(VisionDataset):
return
img
,
target
return
img
,
target
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
data
)
return
len
(
self
.
data
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment