Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
240792d4
Unverified
Commit
240792d4
authored
Mar 11, 2021
by
Nicolas Hug
Committed by
GitHub
Mar 11, 2021
Browse files
New tests for ImageNet dataset (#3543)
parent
814c4f08
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
82 deletions
+34
-82
test/datasets_utils.py
test/datasets_utils.py
+2
-1
test/fakedata_generation.py
test/fakedata_generation.py
+0
-70
test/test_datasets.py
test/test_datasets.py
+32
-11
No files found.
test/datasets_utils.py
View file @
240792d4
...
@@ -312,7 +312,8 @@ class DatasetTestCase(unittest.TestCase):
...
@@ -312,7 +312,8 @@ class DatasetTestCase(unittest.TestCase):
patch_checks
=
inject_fake_data
patch_checks
=
inject_fake_data
special_kwargs
,
other_kwargs
=
self
.
_split_kwargs
(
kwargs
)
special_kwargs
,
other_kwargs
=
self
.
_split_kwargs
(
kwargs
)
if
"download"
in
self
.
_HAS_SPECIAL_KWARG
:
if
"download"
in
self
.
_HAS_SPECIAL_KWARG
and
special_kwargs
.
get
(
"download"
,
False
):
# override download param to False param if its default is truthy
special_kwargs
[
"download"
]
=
False
special_kwargs
[
"download"
]
=
False
config
.
update
(
other_kwargs
)
config
.
update
(
other_kwargs
)
...
...
test/fakedata_generation.py
View file @
240792d4
...
@@ -143,76 +143,6 @@ def cifar_root(version):
...
@@ -143,76 +143,6 @@ def cifar_root(version):
yield
root
yield
root
@
contextlib
.
contextmanager
def
imagenet_root
():
import
scipy.io
as
sio
WNID
=
'n01234567'
CLS
=
'fakedata'
def
_make_image
(
file
):
PIL
.
Image
.
fromarray
(
np
.
zeros
((
32
,
32
,
3
),
dtype
=
np
.
uint8
)).
save
(
file
)
def
_make_tar
(
archive
,
content
,
arcname
=
None
,
compress
=
False
):
mode
=
'w:gz'
if
compress
else
'w'
if
arcname
is
None
:
arcname
=
os
.
path
.
basename
(
content
)
with
tarfile
.
open
(
archive
,
mode
)
as
fh
:
fh
.
add
(
content
,
arcname
=
arcname
)
def
_make_train_archive
(
root
):
with
get_tmp_dir
()
as
tmp
:
wnid_dir
=
os
.
path
.
join
(
tmp
,
WNID
)
os
.
mkdir
(
wnid_dir
)
_make_image
(
os
.
path
.
join
(
wnid_dir
,
WNID
+
'_1.JPEG'
))
wnid_archive
=
wnid_dir
+
'.tar'
_make_tar
(
wnid_archive
,
wnid_dir
)
train_archive
=
os
.
path
.
join
(
root
,
'ILSVRC2012_img_train.tar'
)
_make_tar
(
train_archive
,
wnid_archive
)
def
_make_val_archive
(
root
):
with
get_tmp_dir
()
as
tmp
:
val_image
=
os
.
path
.
join
(
tmp
,
'ILSVRC2012_val_00000001.JPEG'
)
_make_image
(
val_image
)
val_archive
=
os
.
path
.
join
(
root
,
'ILSVRC2012_img_val.tar'
)
_make_tar
(
val_archive
,
val_image
)
def
_make_devkit_archive
(
root
):
with
get_tmp_dir
()
as
tmp
:
data_dir
=
os
.
path
.
join
(
tmp
,
'data'
)
os
.
mkdir
(
data_dir
)
meta_file
=
os
.
path
.
join
(
data_dir
,
'meta.mat'
)
synsets
=
np
.
core
.
records
.
fromarrays
([
(
0.0
,
1.0
),
(
WNID
,
''
),
(
CLS
,
''
),
(
'fakedata for the torchvision testsuite'
,
''
),
(
0.0
,
1.0
),
],
names
=
[
'ILSVRC2012_ID'
,
'WNID'
,
'words'
,
'gloss'
,
'num_children'
])
sio
.
savemat
(
meta_file
,
{
'synsets'
:
synsets
})
groundtruth_file
=
os
.
path
.
join
(
data_dir
,
'ILSVRC2012_validation_ground_truth.txt'
)
with
open
(
groundtruth_file
,
'w'
)
as
fh
:
fh
.
write
(
'0
\n
'
)
devkit_name
=
'ILSVRC2012_devkit_t12'
devkit_archive
=
os
.
path
.
join
(
root
,
devkit_name
+
'.tar.gz'
)
_make_tar
(
devkit_archive
,
tmp
,
arcname
=
devkit_name
,
compress
=
True
)
with
get_tmp_dir
()
as
root
:
_make_train_archive
(
root
)
_make_val_archive
(
root
)
_make_devkit_archive
(
root
)
yield
root
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
widerface_root
():
def
widerface_root
():
"""
"""
...
...
test/test_datasets.py
View file @
240792d4
...
@@ -10,7 +10,7 @@ from torch._utils_internal import get_file_path_2
...
@@ -10,7 +10,7 @@ from torch._utils_internal import get_file_path_2
import
torchvision
import
torchvision
from
torchvision.datasets
import
utils
from
torchvision.datasets
import
utils
from
common_utils
import
get_tmp_dir
from
common_utils
import
get_tmp_dir
from
fakedata_generation
import
mnist_root
,
imagenet_root
,
\
from
fakedata_generation
import
mnist_root
,
\
cityscapes_root
,
svhn_root
,
places365_root
,
widerface_root
,
stl10_root
cityscapes_root
,
svhn_root
,
places365_root
,
widerface_root
,
stl10_root
import
xml.etree.ElementTree
as
ET
import
xml.etree.ElementTree
as
ET
from
urllib.request
import
Request
,
urlopen
from
urllib.request
import
Request
,
urlopen
...
@@ -146,16 +146,6 @@ class Tester(DatasetTestcase):
...
@@ -146,16 +146,6 @@ class Tester(DatasetTestcase):
img
,
target
=
dataset
[
0
]
img
,
target
=
dataset
[
0
]
self
.
assertEqual
(
dataset
.
class_to_idx
[
dataset
.
classes
[
0
]],
target
)
self
.
assertEqual
(
dataset
.
class_to_idx
[
dataset
.
classes
[
0
]],
target
)
@
mock
.
patch
(
'torchvision.datasets.imagenet._verify_archive'
)
@
unittest
.
skipIf
(
not
HAS_SCIPY
,
"scipy unavailable"
)
def
test_imagenet
(
self
,
mock_verify
):
with
imagenet_root
()
as
root
:
dataset
=
torchvision
.
datasets
.
ImageNet
(
root
,
split
=
'train'
)
self
.
generic_classification_dataset_test
(
dataset
)
dataset
=
torchvision
.
datasets
.
ImageNet
(
root
,
split
=
'val'
)
self
.
generic_classification_dataset_test
(
dataset
)
@
mock
.
patch
(
'torchvision.datasets.WIDERFace._check_integrity'
)
@
mock
.
patch
(
'torchvision.datasets.WIDERFace._check_integrity'
)
@
unittest
.
skipIf
(
'win'
in
sys
.
platform
,
'temporarily disabled on Windows'
)
@
unittest
.
skipIf
(
'win'
in
sys
.
platform
,
'temporarily disabled on Windows'
)
def
test_widerface
(
self
,
mock_check_integrity
):
def
test_widerface
(
self
,
mock_check_integrity
):
...
@@ -490,6 +480,37 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -490,6 +480,37 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
return
num_images_per_category
*
len
(
categories
)
return
num_images_per_category
*
len
(
categories
)
class
ImageNetTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
ImageNet
REQUIRED_PACKAGES
=
(
'scipy'
,)
CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
'train'
,
'val'
))
def
inject_fake_data
(
self
,
tmpdir
,
config
):
tmpdir
=
pathlib
.
Path
(
tmpdir
)
wnid
=
'n01234567'
if
config
[
'split'
]
==
'train'
:
num_examples
=
3
datasets_utils
.
create_image_folder
(
root
=
tmpdir
,
name
=
tmpdir
/
'train'
/
wnid
/
wnid
,
file_name_fn
=
lambda
image_idx
:
f
"
{
wnid
}
_
{
image_idx
}
.JPEG"
,
num_examples
=
num_examples
,
)
else
:
num_examples
=
1
datasets_utils
.
create_image_folder
(
root
=
tmpdir
,
name
=
tmpdir
/
'val'
/
wnid
,
file_name_fn
=
lambda
image_ifx
:
"ILSVRC2012_val_0000000{image_idx}.JPEG"
,
num_examples
=
num_examples
,
)
wnid_to_classes
=
{
wnid
:
[
1
]}
torch
.
save
((
wnid_to_classes
,
None
),
tmpdir
/
'meta.bin'
)
return
num_examples
class
CIFAR10TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
CIFAR10TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
CIFAR10
DATASET_CLASS
=
datasets
.
CIFAR10
CONFIGS
=
datasets_utils
.
combinations_grid
(
train
=
(
True
,
False
))
CONFIGS
=
datasets_utils
.
combinations_grid
(
train
=
(
True
,
False
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment