Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
8a64dbcd
"src/vscode:/vscode.git/clone" did not exist on "4e59bcc680ba4f68bc8d45249db5fe7a078413db"
Unverified
Commit
8a64dbcd
authored
Jun 07, 2019
by
Francisco Massa
Committed by
GitHub
Jun 07, 2019
Browse files
Mock MNIST download for less flaky tests (#1004)
parent
de387e8c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
7 deletions
+51
-7
test/test_datasets.py
test/test_datasets.py
+51
-7
No files found.
test/test_datasets.py
View file @
8a64dbcd
...
...
@@ -5,6 +5,7 @@ import tempfile
import
unittest
import
mock
import
PIL
import
torch
import
torchvision
FAKEDATA_DIR
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
...
...
@@ -23,6 +24,41 @@ def tmp_dir(src=None, **kwargs):
shutil
.
rmtree
(
tmp_dir
)
@
contextlib
.
contextmanager
def
get_mnist_data
(
num_images
,
cls_name
,
**
kwargs
):
def
_encode
(
v
):
return
torch
.
tensor
(
v
,
dtype
=
torch
.
int32
).
numpy
().
tobytes
()[::
-
1
]
def
_make_image_file
(
filename
,
num_images
):
img
=
torch
.
randint
(
0
,
255
,
size
=
(
28
*
28
*
num_images
,),
dtype
=
torch
.
uint8
)
with
open
(
filename
,
"wb"
)
as
f
:
f
.
write
(
_encode
(
2051
))
# magic header
f
.
write
(
_encode
(
num_images
))
f
.
write
(
_encode
(
28
))
f
.
write
(
_encode
(
28
))
f
.
write
(
img
.
numpy
().
tobytes
())
def
_make_label_file
(
filename
,
num_images
):
labels
=
torch
.
randint
(
0
,
10
,
size
=
(
num_images
,),
dtype
=
torch
.
uint8
)
with
open
(
filename
,
"wb"
)
as
f
:
f
.
write
(
_encode
(
2049
))
# magic header
f
.
write
(
_encode
(
num_images
))
f
.
write
(
labels
.
numpy
().
tobytes
())
tmp_dir
=
tempfile
.
mkdtemp
(
**
kwargs
)
raw_dir
=
os
.
path
.
join
(
tmp_dir
,
cls_name
,
"raw"
)
os
.
makedirs
(
raw_dir
)
_make_image_file
(
os
.
path
.
join
(
raw_dir
,
"train-images-idx3-ubyte"
),
num_images
)
_make_label_file
(
os
.
path
.
join
(
raw_dir
,
"train-labels-idx1-ubyte"
),
num_images
)
_make_image_file
(
os
.
path
.
join
(
raw_dir
,
"t10k-images-idx3-ubyte"
),
num_images
)
_make_label_file
(
os
.
path
.
join
(
raw_dir
,
"t10k-labels-idx1-ubyte"
),
num_images
)
try
:
yield
tmp_dir
finally
:
shutil
.
rmtree
(
tmp_dir
)
class
Tester
(
unittest
.
TestCase
):
def
test_imagefolder
(
self
):
...
...
@@ -70,25 +106,33 @@ class Tester(unittest.TestCase):
outputs
=
sorted
([
dataset
[
i
]
for
i
in
range
(
len
(
dataset
))])
self
.
assertEqual
(
imgs
,
outputs
)
def
test_mnist
(
self
):
with
tmp_dir
()
as
root
:
@
mock
.
patch
(
'torchvision.datasets.mnist.download_and_extract_archive'
)
def
test_mnist
(
self
,
mock_download_extract
):
num_examples
=
30
with
get_mnist_data
(
num_examples
,
"MNIST"
)
as
root
:
dataset
=
torchvision
.
datasets
.
MNIST
(
root
,
download
=
True
)
self
.
assertEqual
(
len
(
dataset
),
60000
)
self
.
assertEqual
(
len
(
dataset
),
num_examples
)
img
,
target
=
dataset
[
0
]
self
.
assertTrue
(
isinstance
(
img
,
PIL
.
Image
.
Image
))
self
.
assertTrue
(
isinstance
(
target
,
int
))
def
test_kmnist
(
self
):
with
tmp_dir
()
as
root
:
@
mock
.
patch
(
'torchvision.datasets.mnist.download_and_extract_archive'
)
def
test_kmnist
(
self
,
mock_download_extract
):
num_examples
=
30
with
get_mnist_data
(
num_examples
,
"KMNIST"
)
as
root
:
dataset
=
torchvision
.
datasets
.
KMNIST
(
root
,
download
=
True
)
img
,
target
=
dataset
[
0
]
self
.
assertEqual
(
len
(
dataset
),
num_examples
)
self
.
assertTrue
(
isinstance
(
img
,
PIL
.
Image
.
Image
))
self
.
assertTrue
(
isinstance
(
target
,
int
))
def
test_fashionmnist
(
self
):
with
tmp_dir
()
as
root
:
@
mock
.
patch
(
'torchvision.datasets.mnist.download_and_extract_archive'
)
def
test_fashionmnist
(
self
,
mock_download_extract
):
num_examples
=
30
with
get_mnist_data
(
num_examples
,
"FashionMNIST"
)
as
root
:
dataset
=
torchvision
.
datasets
.
FashionMNIST
(
root
,
download
=
True
)
img
,
target
=
dataset
[
0
]
self
.
assertEqual
(
len
(
dataset
),
num_examples
)
self
.
assertTrue
(
isinstance
(
img
,
PIL
.
Image
.
Image
))
self
.
assertTrue
(
isinstance
(
target
,
int
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment