Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
3f70e3c4
Unverified
Commit
3f70e3c4
authored
Aug 03, 2020
by
Philip Meier
Committed by
GitHub
Aug 03, 2020
Browse files
add typehints for torchvision.datasets.mnist (#2532)
parent
203a7841
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
34 deletions
+45
-34
torchvision/datasets/mnist.py
torchvision/datasets/mnist.py
+45
-34
No files found.
torchvision/datasets/mnist.py
View file @
3f70e3c4
...
...
@@ -7,6 +7,7 @@ import numpy as np
import
torch
import
codecs
import
string
from
typing
import
Any
,
Callable
,
Dict
,
IO
,
List
,
Optional
,
Tuple
,
Union
from
.utils
import
download_url
,
download_and_extract_archive
,
extract_archive
,
\
verify_str_arg
...
...
@@ -60,8 +61,14 @@ class MNIST(VisionDataset):
warnings
.
warn
(
"test_data has been renamed data"
)
return
self
.
data
def
__init__
(
self
,
root
,
train
=
True
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
def
__init__
(
self
,
root
:
str
,
train
:
bool
=
True
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
MNIST
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
train
=
train
# training set or test set
...
...
@@ -79,7 +86,7 @@ class MNIST(VisionDataset):
data_file
=
self
.
test_file
self
.
data
,
self
.
targets
=
torch
.
load
(
os
.
path
.
join
(
self
.
processed_folder
,
data_file
))
def
__getitem__
(
self
,
index
)
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
[
Any
,
Any
]
:
"""
Args:
index (int): Index
...
...
@@ -101,28 +108,28 @@ class MNIST(VisionDataset):
return
img
,
target
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
data
)
@
property
def
raw_folder
(
self
):
def
raw_folder
(
self
)
->
str
:
return
os
.
path
.
join
(
self
.
root
,
self
.
__class__
.
__name__
,
'raw'
)
@
property
def
processed_folder
(
self
):
def
processed_folder
(
self
)
->
str
:
return
os
.
path
.
join
(
self
.
root
,
self
.
__class__
.
__name__
,
'processed'
)
@
property
def
class_to_idx
(
self
):
def
class_to_idx
(
self
)
->
Dict
[
str
,
int
]
:
return
{
_class
:
i
for
i
,
_class
in
enumerate
(
self
.
classes
)}
def
_check_exists
(
self
):
def
_check_exists
(
self
)
->
bool
:
return
(
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
processed_folder
,
self
.
training_file
))
and
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
processed_folder
,
self
.
test_file
)))
def
download
(
self
):
def
download
(
self
)
->
None
:
"""Download the MNIST data if it doesn't exist in processed_folder already."""
if
self
.
_check_exists
():
...
...
@@ -154,7 +161,7 @@ class MNIST(VisionDataset):
print
(
'Done!'
)
def
extra_repr
(
self
):
def
extra_repr
(
self
)
->
str
:
return
"Split: {}"
.
format
(
"Train"
if
self
.
train
is
True
else
"Test"
)
...
...
@@ -251,7 +258,7 @@ class EMNIST(MNIST):
'mnist'
:
list
(
string
.
digits
),
}
def
__init__
(
self
,
root
,
split
,
**
kwargs
)
:
def
__init__
(
self
,
root
:
str
,
split
:
str
,
**
kwargs
:
Any
)
->
None
:
self
.
split
=
verify_str_arg
(
split
,
"split"
,
self
.
splits
)
self
.
training_file
=
self
.
_training_file
(
split
)
self
.
test_file
=
self
.
_test_file
(
split
)
...
...
@@ -259,14 +266,14 @@ class EMNIST(MNIST):
self
.
classes
=
self
.
classes_split_dict
[
self
.
split
]
@
staticmethod
def
_training_file
(
split
):
def
_training_file
(
split
)
->
str
:
return
'training_{}.pt'
.
format
(
split
)
@
staticmethod
def
_test_file
(
split
):
def
_test_file
(
split
)
->
str
:
return
'test_{}.pt'
.
format
(
split
)
def
download
(
self
):
def
download
(
self
)
->
None
:
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
import
shutil
...
...
@@ -343,7 +350,7 @@ class QMNIST(MNIST):
'test50k'
:
'test'
,
'nist'
:
'nist'
}
resources
=
{
# type: ignore[assignment]
resources
:
Dict
[
str
,
List
[
Tuple
[
str
,
str
]]]
=
{
# type: ignore[assignment]
'train'
:
[(
'https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz'
,
'ed72d4157d28c017586c42bc6afe6370'
),
(
'https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz'
,
...
...
@@ -360,7 +367,10 @@ class QMNIST(MNIST):
classes
=
[
'0 - zero'
,
'1 - one'
,
'2 - two'
,
'3 - three'
,
'4 - four'
,
'5 - five'
,
'6 - six'
,
'7 - seven'
,
'8 - eight'
,
'9 - nine'
]
def
__init__
(
self
,
root
,
what
=
None
,
compat
=
True
,
train
=
True
,
**
kwargs
):
def
__init__
(
self
,
root
:
str
,
what
:
Optional
[
str
]
=
None
,
compat
:
bool
=
True
,
train
:
bool
=
True
,
**
kwargs
:
Any
)
->
None
:
if
what
is
None
:
what
=
'train'
if
train
else
'test'
self
.
what
=
verify_str_arg
(
what
,
"what"
,
tuple
(
self
.
subsets
.
keys
()))
...
...
@@ -370,7 +380,7 @@ class QMNIST(MNIST):
self
.
test_file
=
self
.
data_file
super
(
QMNIST
,
self
).
__init__
(
root
,
train
,
**
kwargs
)
def
download
(
self
):
def
download
(
self
)
->
None
:
"""Download the QMNIST data if it doesn't exist in processed_folder already.
Note that we only download what has been asked for (argument 'what').
"""
...
...
@@ -405,7 +415,7 @@ class QMNIST(MNIST):
with
open
(
os
.
path
.
join
(
self
.
processed_folder
,
self
.
data_file
),
'wb'
)
as
f
:
torch
.
save
((
data
,
targets
),
f
)
def
__getitem__
(
self
,
index
)
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
[
Any
,
Any
]
:
# redefined to handle the compat flag
img
,
target
=
self
.
data
[
index
],
self
.
targets
[
index
]
img
=
Image
.
fromarray
(
img
.
numpy
(),
mode
=
'L'
)
...
...
@@ -417,15 +427,15 @@ class QMNIST(MNIST):
target
=
self
.
target_transform
(
target
)
return
img
,
target
def
extra_repr
(
self
):
def
extra_repr
(
self
)
->
str
:
return
"Split: {}"
.
format
(
self
.
what
)
def
get_int
(
b
)
:
def
get_int
(
b
:
bytes
)
->
int
:
return
int
(
codecs
.
encode
(
b
,
'hex'
),
16
)
def
open_maybe_compressed_file
(
path
)
:
def
open_maybe_compressed_file
(
path
:
Union
[
str
,
IO
])
->
IO
:
"""Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
"""
...
...
@@ -440,19 +450,20 @@ def open_maybe_compressed_file(path):
return
open
(
path
,
'rb'
)
def
read_sn3_pascalvincent_tensor
(
path
,
strict
=
True
):
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# typemap
if
not
hasattr
(
read_sn3_pascalvincent_tensor
,
'typemap'
):
read_sn3_pascalvincent_tensor
.
typemap
=
{
SN3_PASCALVINCENT_TYPEMAP
=
{
8
:
(
torch
.
uint8
,
np
.
uint8
,
np
.
uint8
),
9
:
(
torch
.
int8
,
np
.
int8
,
np
.
int8
),
11
:
(
torch
.
int16
,
np
.
dtype
(
'>i2'
),
'i2'
),
12
:
(
torch
.
int32
,
np
.
dtype
(
'>i4'
),
'i4'
),
13
:
(
torch
.
float32
,
np
.
dtype
(
'>f4'
),
'f4'
),
14
:
(
torch
.
float64
,
np
.
dtype
(
'>f8'
),
'f8'
)}
14
:
(
torch
.
float64
,
np
.
dtype
(
'>f8'
),
'f8'
)
}
def
read_sn3_pascalvincent_tensor
(
path
:
Union
[
str
,
IO
],
strict
:
bool
=
True
)
->
torch
.
Tensor
:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# read
with
open_maybe_compressed_file
(
path
)
as
f
:
data
=
f
.
read
()
...
...
@@ -462,14 +473,14 @@ def read_sn3_pascalvincent_tensor(path, strict=True):
ty
=
magic
//
256
assert
nd
>=
1
and
nd
<=
3
assert
ty
>=
8
and
ty
<=
14
m
=
read_sn3_pascalvincent_tensor
.
typemap
[
ty
]
m
=
SN3_PASCALVINCENT_TYPEMAP
[
ty
]
s
=
[
get_int
(
data
[
4
*
(
i
+
1
):
4
*
(
i
+
2
)])
for
i
in
range
(
nd
)]
parsed
=
np
.
frombuffer
(
data
,
dtype
=
m
[
1
],
offset
=
(
4
*
(
nd
+
1
)))
assert
parsed
.
shape
[
0
]
==
np
.
prod
(
s
)
or
not
strict
return
torch
.
from_numpy
(
parsed
.
astype
(
m
[
2
],
copy
=
False
)).
view
(
*
s
)
def
read_label_file
(
path
)
:
def
read_label_file
(
path
:
str
)
->
torch
.
Tensor
:
with
open
(
path
,
'rb'
)
as
f
:
x
=
read_sn3_pascalvincent_tensor
(
f
,
strict
=
False
)
assert
(
x
.
dtype
==
torch
.
uint8
)
...
...
@@ -477,7 +488,7 @@ def read_label_file(path):
return
x
.
long
()
def
read_image_file
(
path
)
:
def
read_image_file
(
path
:
str
)
->
torch
.
Tensor
:
with
open
(
path
,
'rb'
)
as
f
:
x
=
read_sn3_pascalvincent_tensor
(
f
,
strict
=
False
)
assert
(
x
.
dtype
==
torch
.
uint8
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment