Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
09600b90
"git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "6c7d99928251e03249ac2c65006c7452f5676bb7"
Unverified
Commit
09600b90
authored
Mar 22, 2021
by
Philip Meier
Committed by
GitHub
Mar 22, 2021
Browse files
Enable custom default config for dataset tests (#3578)
parent
661d89fd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
145 additions
and
49 deletions
+145
-49
test/datasets_utils.py
test/datasets_utils.py
+130
-36
test/test_datasets.py
test/test_datasets.py
+15
-13
No files found.
test/datasets_utils.py
View file @
09600b90
...
...
@@ -117,19 +117,45 @@ def requires_lazy_imports(*modules):
def
test_all_configs
(
test
):
"""Decorator to run test against all configurations.
Add this as decorator to an arbitrary test to run it against all configurations. The current configuration is
provided as the first parameter:
Add this as decorator to an arbitrary test to run it against all configurations. This includes
:attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.
The current configuration is provided as the first parameter for the test:
.. code-block::
@test_all_configs
@test_all_configs
()
def test_foo(self, config):
pass
.. note::
This will try to remove duplicate configurations. During this process it will not not preserve a potential
ordering of the configurations or an inner ordering of a configuration.
"""
def
maybe_remove_duplicates
(
configs
):
try
:
return
[
dict
(
config_
)
for
config_
in
set
(
tuple
(
sorted
(
config
.
items
()))
for
config
in
configs
)]
except
TypeError
:
# A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case duplicate
# removal would be a lot more elaborate and we simply bail out.
return
configs
@
functools
.
wraps
(
test
)
def
wrapper
(
self
):
for
config
in
self
.
CONFIGS
or
(
self
.
_DEFAULT_CONFIG
,):
configs
=
[]
if
self
.
DEFAULT_CONFIG
is
not
None
:
configs
.
append
(
self
.
DEFAULT_CONFIG
)
if
self
.
ADDITIONAL_CONFIGS
is
not
None
:
configs
.
extend
(
self
.
ADDITIONAL_CONFIGS
)
if
not
configs
:
configs
=
[
self
.
_KWARG_DEFAULTS
.
copy
()]
else
:
configs
=
maybe_remove_duplicates
(
configs
)
for
config
in
configs
:
with
self
.
subTest
(
**
config
):
test
(
self
,
config
)
...
...
@@ -166,9 +192,13 @@ class DatasetTestCase(unittest.TestCase):
Optionally, you can overwrite the following class attributes:
- CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictonary can contain an
arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
``transforms``, or ``download``. The first element will be used as default configuration.
- DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
``download``. Overwrite this if you want to use a default value for a parameter for which the dataset does
not provide one.
- ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can
contain an arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
``transforms``, or ``download``.
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
available, the tests are skipped.
...
...
@@ -218,22 +248,31 @@ class DatasetTestCase(unittest.TestCase):
DATASET_CLASS
=
None
FEATURE_TYPES
=
None
CONFIGS
=
None
DEFAULT_CONFIG
=
None
ADDITIONAL_CONFIGS
=
None
REQUIRED_PACKAGES
=
None
_DEFAULT_CONFIG
=
None
# These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
_TRANSFORM_KWARGS
=
{
"transform"
,
"target_transform"
,
"transforms"
,
}
# These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or ADDITIONAL_CONFIGS.
_SPECIAL_KWARGS
=
{
*
_TRANSFORM_KWARGS
,
"download"
,
}
# These fields are populated during setupClass() within _populate_private_class_attributes()
# This will be a dictionary containing all keyword arguments with their respective default values extracted from
# the dataset constructor.
_KWARG_DEFAULTS
=
None
# This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
_HAS_SPECIAL_KWARG
=
None
# These functions are disabled during dataset creation in create_dataset().
_CHECK_FUNCTIONS
=
{
"check_md5"
,
"check_integrity"
,
...
...
@@ -256,7 +295,8 @@ class DatasetTestCase(unittest.TestCase):
Args:
tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
to be created and in turn also for the fake data injected here.
config (Dict[str, Any]): Configuration that will be used to create the dataset.
config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
fields for all dataset parameters with default values.
Returns:
(Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
...
...
@@ -273,7 +313,8 @@ class DatasetTestCase(unittest.TestCase):
Args:
tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
to be created and in turn also for the fake data injected here.
config (Dict[str, Any]): Configuration that will be used to create the dataset.
config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
fields for all dataset parameters with default values.
Needs to return one of the following:
...
...
@@ -293,9 +334,16 @@ class DatasetTestCase(unittest.TestCase):
)
->
Iterator
[
Tuple
[
torchvision
.
datasets
.
VisionDataset
,
Dict
[
str
,
Any
]]]:
r
"""Create the dataset in a temporary directory.
The configuration passed to the dataset is populated to contain at least all parameters with default values.
For this the following order of precedence is used:
1. Parameters in :attr:`kwargs`.
2. Configuration in :attr:`config`.
3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
4. Default parameters of the dataset.
Args:
config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset. If omitted, the
default configuration is used.
config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
creating the dataset.
patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
...
...
@@ -308,30 +356,33 @@ class DatasetTestCase(unittest.TestCase):
info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
for details.
"""
default_config
=
self
.
_DEFAULT_CONFIG
.
copy
()
if
config
is
not
None
:
default_config
.
update
(
config
)
config
=
default_config
if
patch_checks
is
None
:
patch_checks
=
inject_fake_data
special_kwargs
,
other_kwargs
=
self
.
_split_kwargs
(
kwargs
)
complete_config
=
self
.
_KWARG_DEFAULTS
.
copy
()
if
self
.
DEFAULT_CONFIG
:
complete_config
.
update
(
self
.
DEFAULT_CONFIG
)
if
config
:
complete_config
.
update
(
config
)
if
other_kwargs
:
complete_config
.
update
(
other_kwargs
)
if
"download"
in
self
.
_HAS_SPECIAL_KWARG
and
special_kwargs
.
get
(
"download"
,
False
):
# override download param to False param if its default is truthy
special_kwargs
[
"download"
]
=
False
config
.
update
(
other_kwargs
)
patchers
=
self
.
_patch_download_extract
()
if
patch_checks
:
patchers
.
update
(
self
.
_patch_checks
())
with
get_tmp_dir
()
as
tmpdir
:
args
=
self
.
dataset_args
(
tmpdir
,
config
)
info
=
self
.
_inject_fake_data
(
tmpdir
,
config
)
if
inject_fake_data
else
None
args
=
self
.
dataset_args
(
tmpdir
,
complete_
config
)
info
=
self
.
_inject_fake_data
(
tmpdir
,
complete_
config
)
if
inject_fake_data
else
None
with
self
.
_maybe_apply_patches
(
patchers
),
disable_console_output
():
dataset
=
self
.
DATASET_CLASS
(
*
args
,
**
config
,
**
special_kwargs
)
dataset
=
self
.
DATASET_CLASS
(
*
args
,
**
complete_
config
,
**
special_kwargs
)
yield
dataset
,
info
...
...
@@ -357,26 +408,69 @@ class DatasetTestCase(unittest.TestCase):
@
classmethod
def
_populate_private_class_attributes
(
cls
):
argspec
=
inspect
.
getfullargspec
(
cls
.
DATASET_CLASS
.
__init__
)
defaults
=
[]
for
cls_
in
cls
.
DATASET_CLASS
.
__mro__
:
if
cls_
is
torchvision
.
datasets
.
VisionDataset
:
break
cls
.
_DEFAULT_CONFIG
=
{
kwarg
:
default
for
kwarg
,
default
in
zip
(
argspec
.
args
[
-
len
(
argspec
.
defaults
):],
argspec
.
defaults
)
if
kwarg
not
in
cls
.
_SPECIAL_KWARGS
}
argspec
=
inspect
.
getfullargspec
(
cls_
.
__init__
)
cls
.
_HAS_SPECIAL_KWARG
=
{
name
for
name
in
cls
.
_SPECIAL_KWARGS
if
name
in
argspec
.
args
}
if
not
argspec
.
defaults
:
continue
defaults
.
append
(
{
kwarg
:
default
for
kwarg
,
default
in
zip
(
argspec
.
args
[
-
len
(
argspec
.
defaults
):],
argspec
.
defaults
)}
)
if
not
argspec
.
varkw
:
break
kwarg_defaults
=
dict
()
for
config
in
reversed
(
defaults
):
kwarg_defaults
.
update
(
config
)
has_special_kwargs
=
set
()
for
name
in
cls
.
_SPECIAL_KWARGS
:
if
name
not
in
kwarg_defaults
:
continue
del
kwarg_defaults
[
name
]
has_special_kwargs
.
add
(
name
)
cls
.
_KWARG_DEFAULTS
=
kwarg_defaults
cls
.
_HAS_SPECIAL_KWARG
=
has_special_kwargs
@
classmethod
def
_process_optional_public_class_attributes
(
cls
):
if
cls
.
REQUIRED_PACKAGES
is
not
None
:
try
:
for
pkg
in
cls
.
REQUIRED_PACKAGES
:
def
check_config
(
config
,
name
):
special_kwargs
=
tuple
(
f
"'
{
name
}
'"
for
name
in
cls
.
_SPECIAL_KWARGS
if
name
in
config
)
if
special_kwargs
:
raise
UsageError
(
f
"
{
name
}
contains a value for the parameter(s)
{
', '
.
join
(
special_kwargs
)
}
. "
f
"These are handled separately by the test case and should not be set here. "
f
"If you need to test some custom behavior regarding these parameters, "
f
"you need to write a custom test (*not* test case), e.g. test_custom_transform()."
)
if
cls
.
DEFAULT_CONFIG
is
not
None
:
check_config
(
cls
.
DEFAULT_CONFIG
,
"DEFAULT_CONFIG"
)
if
cls
.
ADDITIONAL_CONFIGS
is
not
None
:
for
idx
,
config
in
enumerate
(
cls
.
ADDITIONAL_CONFIGS
):
check_config
(
config
,
f
"CONFIGS[
{
idx
}
]"
)
if
cls
.
REQUIRED_PACKAGES
:
missing_pkgs
=
[]
for
pkg
in
cls
.
REQUIRED_PACKAGES
:
try
:
importlib
.
import_module
(
pkg
)
except
ImportError
as
error
:
except
ImportError
:
missing_pkgs
.
append
(
f
"'
{
pkg
}
'"
)
if
missing_pkgs
:
raise
unittest
.
SkipTest
(
f
"The package
'
{
error
.
name
}
' is required to load the dataset '
{
cls
.
DATASET_CLASS
.
__name__
}
' but is
"
f
"not installed."
f
"The package
(s)
{
', '
.
join
(
missing_pkgs
)
}
are required to load the dataset
"
f
"
'
{
cls
.
DATASET_CLASS
.
__name__
}
', but are
not installed."
)
def
_split_kwargs
(
self
,
kwargs
):
...
...
test/test_datasets.py
View file @
09600b90
...
...
@@ -369,7 +369,9 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS
=
datasets
.
Caltech101
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
(
int
,
np
.
ndarray
,
tuple
))
CONFIGS
=
datasets_utils
.
combinations_grid
(
target_type
=
(
"category"
,
"annotation"
,
[
"category"
,
"annotation"
]))
ADDITIONAL_CONFIGS
=
datasets_utils
.
combinations_grid
(
target_type
=
(
"category"
,
"annotation"
,
[
"category"
,
"annotation"
])
)
REQUIRED_PACKAGES
=
(
"scipy"
,)
def
inject_fake_data
(
self
,
tmpdir
,
config
):
...
...
@@ -466,7 +468,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
class
WIDERFaceTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
WIDERFace
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
(
dict
,
type
(
None
)))
# test split returns None as target
CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
'train'
,
'val'
,
'test'
))
ADDITIONAL_
CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
'train'
,
'val'
,
'test'
))
def
inject_fake_data
(
self
,
tmpdir
,
config
):
widerface_dir
=
pathlib
.
Path
(
tmpdir
)
/
'widerface'
...
...
@@ -521,7 +523,7 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
class
ImageNetTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
ImageNet
REQUIRED_PACKAGES
=
(
'scipy'
,)
CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
'train'
,
'val'
))
ADDITIONAL_
CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
'train'
,
'val'
))
def
inject_fake_data
(
self
,
tmpdir
,
config
):
tmpdir
=
pathlib
.
Path
(
tmpdir
)
...
...
@@ -551,7 +553,7 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
class
CIFAR10TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
CIFAR10
CONFIGS
=
datasets_utils
.
combinations_grid
(
train
=
(
True
,
False
))
ADDITIONAL_
CONFIGS
=
datasets_utils
.
combinations_grid
(
train
=
(
True
,
False
))
_VERSION_CONFIG
=
dict
(
base_folder
=
"cifar-10-batches-py"
,
...
...
@@ -623,7 +625,7 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS
=
datasets
.
CelebA
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
(
torch
.
Tensor
,
int
,
tuple
,
type
(
None
)))
CONFIGS
=
datasets_utils
.
combinations_grid
(
ADDITIONAL_
CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
"train"
,
"valid"
,
"test"
,
"all"
),
target_type
=
(
"attr"
,
"identity"
,
"bbox"
,
"landmarks"
,
[
"attr"
,
"identity"
]),
)
...
...
@@ -740,7 +742,7 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS
=
datasets
.
VOCSegmentation
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
)
CONFIGS
=
(
ADDITIONAL_
CONFIGS
=
(
*
datasets_utils
.
combinations_grid
(
year
=
[
f
"20
{
year
:
02
d
}
"
for
year
in
range
(
7
,
13
)],
image_set
=
(
"train"
,
"val"
,
"trainval"
)
),
...
...
@@ -929,7 +931,7 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
class
UCF101TestCase
(
datasets_utils
.
VideoDatasetTestCase
):
DATASET_CLASS
=
datasets
.
UCF101
CONFIGS
=
datasets_utils
.
combinations_grid
(
fold
=
(
1
,
2
,
3
),
train
=
(
True
,
False
))
ADDITIONAL_
CONFIGS
=
datasets_utils
.
combinations_grid
(
fold
=
(
1
,
2
,
3
),
train
=
(
True
,
False
))
_VIDEO_FOLDER
=
"videos"
_ANNOTATIONS_FOLDER
=
"annotations"
...
...
@@ -990,7 +992,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS
=
datasets
.
LSUN
REQUIRED_PACKAGES
=
(
"lmdb"
,)
CONFIGS
=
datasets_utils
.
combinations_grid
(
ADDITIONAL_
CONFIGS
=
datasets_utils
.
combinations_grid
(
classes
=
(
"train"
,
"test"
,
"val"
,
[
"bedroom_train"
,
"church_outdoor_train"
])
)
...
...
@@ -1097,7 +1099,7 @@ class Kinetics400TestCase(datasets_utils.VideoDatasetTestCase):
class
HMDB51TestCase
(
datasets_utils
.
VideoDatasetTestCase
):
DATASET_CLASS
=
datasets
.
HMDB51
CONFIGS
=
datasets_utils
.
combinations_grid
(
fold
=
(
1
,
2
,
3
),
train
=
(
True
,
False
))
ADDITIONAL_
CONFIGS
=
datasets_utils
.
combinations_grid
(
fold
=
(
1
,
2
,
3
),
train
=
(
True
,
False
))
_VIDEO_FOLDER
=
"videos"
_SPLITS_FOLDER
=
"splits"
...
...
@@ -1157,7 +1159,7 @@ class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
class
OmniglotTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
Omniglot
CONFIGS
=
datasets_utils
.
combinations_grid
(
background
=
(
True
,
False
))
ADDITIONAL_
CONFIGS
=
datasets_utils
.
combinations_grid
(
background
=
(
True
,
False
))
def
inject_fake_data
(
self
,
tmpdir
,
config
):
target_folder
=
(
...
...
@@ -1237,7 +1239,7 @@ class SEMEIONTestCase(datasets_utils.ImageDatasetTestCase):
class
USPSTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
USPS
CONFIGS
=
datasets_utils
.
combinations_grid
(
train
=
(
True
,
False
))
ADDITIONAL_
CONFIGS
=
datasets_utils
.
combinations_grid
(
train
=
(
True
,
False
))
def
inject_fake_data
(
self
,
tmpdir
,
config
):
num_images
=
2
if
config
[
"train"
]
else
1
...
...
@@ -1259,7 +1261,7 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
REQUIRED_PACKAGES
=
(
"scipy.io"
,
"scipy.sparse"
)
CONFIGS
=
datasets_utils
.
combinations_grid
(
ADDITIONAL_
CONFIGS
=
datasets_utils
.
combinations_grid
(
image_set
=
(
"train"
,
"val"
,
"train_noval"
),
mode
=
(
"boundaries"
,
"segmentation"
)
)
...
...
@@ -1345,7 +1347,7 @@ class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
_TRAIN_FEATURE_TYPES
=
(
torch
.
Tensor
,)
_TEST_FEATURE_TYPES
=
(
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
)
CONFIGS
=
datasets_utils
.
combinations_grid
(
train
=
(
True
,
False
))
datasets_utils
.
combinations_grid
(
train
=
(
True
,
False
))
_NAME
=
"liberty"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment