OpenDAS / vision · Commits

Enable custom default config for dataset tests (#3578)

Commit 09600b90 (unverified) · authored Mar 22, 2021 by Philip Meier, committed via GitHub on Mar 22, 2021
Parent: 661d89fd
Showing 2 changed files with 145 additions and 49 deletions:

    test/datasets_utils.py    +130 −36
    test/test_datasets.py     +15 −13
test/datasets_utils.py
```diff
@@ -117,19 +117,45 @@ def requires_lazy_imports(*modules):
 def test_all_configs(test):
     """Decorator to run test against all configurations.
 
-    Add this as decorator to an arbitrary test to run it against all configurations. The current configuration is
-    provided as the first parameter:
+    Add this as decorator to an arbitrary test to run it against all configurations. This includes
+    :attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.
+
+    The current configuration is provided as the first parameter for the test:
 
     .. code-block::
 
-        @test_all_configs
+        @test_all_configs()
         def test_foo(self, config):
             pass
+
+    .. note::
+
+        This will try to remove duplicate configurations. During this process it will not preserve a potential
+        ordering of the configurations or an inner ordering of a configuration.
     """
 
+    def maybe_remove_duplicates(configs):
+        try:
+            return [dict(config_) for config_ in set(tuple(sorted(config.items())) for config in configs)]
+        except TypeError:
+            # A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case
+            # duplicate removal would be a lot more elaborate and we simply bail out.
+            return configs
+
     @functools.wraps(test)
     def wrapper(self):
-        for config in self.CONFIGS or (self._DEFAULT_CONFIG,):
+        configs = []
+        if self.DEFAULT_CONFIG is not None:
+            configs.append(self.DEFAULT_CONFIG)
+        if self.ADDITIONAL_CONFIGS is not None:
+            configs.extend(self.ADDITIONAL_CONFIGS)
+
+        if not configs:
+            configs = [self._KWARG_DEFAULTS.copy()]
+        else:
+            configs = maybe_remove_duplicates(configs)
+
+        for config in configs:
             with self.subTest(**config):
                 test(self, config)
```
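The deduplication in `maybe_remove_duplicates` relies on a standard trick: a dict is not hashable, but a sorted tuple of its items is, so configs can be pushed through a `set` and rebuilt afterwards. A standalone sketch of the same idea (names and example values are illustrative, not part of the diff):

```python
# Standalone sketch of the dedup trick used in maybe_remove_duplicates.
def dedup_configs(configs):
    try:
        return [dict(items) for items in {tuple(sorted(c.items())) for c in configs}]
    except TypeError:
        # Unhashable values (e.g. lists) make the item tuple unhashable; bail out.
        return configs

print(dedup_configs([{"split": "train"}, {"split": "train"}, {"split": "val"}]))
# -> two unique configs, order not preserved
print(dedup_configs([{"target_type": ["category", "annotation"]}]))
# -> returned unchanged because the list value is unhashable
```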
```diff
@@ -166,9 +192,13 @@ class DatasetTestCase(unittest.TestCase):
     Optionally, you can overwrite the following class attributes:
 
-        - CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictonary can contain an
-            arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
-            ``transforms``, or ``download``. The first element will be used as default configuration.
+        - DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
+            keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
+            ``download``. Overwrite this if you want to use a default value for a parameter for which the dataset
+            does not provide one.
+        - ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary
+            can contain an arbitrary combination of dataset parameters that are **not** ``transform``,
+            ``target_transform``, ``transforms``, or ``download``.
         - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
             available, the tests are skipped.
```
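For orientation, here is a minimal sketch of how a test case might opt into the two new attributes. `FooDataset` and its `split` parameter are hypothetical; the pattern simply mirrors the docstring above:

```python
# Hypothetical usage sketch (FooDataset and its parameters are made up).
# DEFAULT_CONFIG pins a default for a parameter the dataset leaves open;
# ADDITIONAL_CONFIGS lists extra combinations that should also be tested.
class FooDatasetTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.FooDataset

    DEFAULT_CONFIG = dict(split="train")
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))

    def inject_fake_data(self, tmpdir, config):
        ...  # create fake files for config["split"] and return info
```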
```diff
@@ -218,22 +248,31 @@ class DatasetTestCase(unittest.TestCase):
     DATASET_CLASS = None
     FEATURE_TYPES = None
 
-    CONFIGS = None
+    DEFAULT_CONFIG = None
+    ADDITIONAL_CONFIGS = None
     REQUIRED_PACKAGES = None
 
-    _DEFAULT_CONFIG = None
-
+    # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
     _TRANSFORM_KWARGS = {
         "transform",
         "target_transform",
         "transforms",
     }
+    # These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or ADDITIONAL_CONFIGS.
     _SPECIAL_KWARGS = {
         *_TRANSFORM_KWARGS,
         "download",
     }
+
+    # These fields are populated during setupClass() within _populate_private_class_attributes()
+    # This will be a dictionary containing all keyword arguments with their respective default values extracted from
+    # the dataset constructor.
+    _KWARG_DEFAULTS = None
+    # This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
     _HAS_SPECIAL_KWARG = None
 
+    # These functions are disabled during dataset creation in create_dataset().
     _CHECK_FUNCTIONS = {
         "check_md5",
         "check_integrity",
```
```diff
@@ -256,7 +295,8 @@ class DatasetTestCase(unittest.TestCase):
         Args:
             tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                 to be created and in turn also for the fake data injected here.
-            config (Dict[str, Any]): Configuration that will be used to create the dataset.
+            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at
+                least fields for all dataset parameters with default values.
 
         Returns:
             (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
```
```diff
@@ -273,7 +313,8 @@ class DatasetTestCase(unittest.TestCase):
         Args:
             tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                 to be created and in turn also for the fake data injected here.
-            config (Dict[str, Any]): Configuration that will be used to create the dataset.
+            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at
+                least fields for all dataset parameters with default values.
 
         Needs to return one of the following:
```
```diff
@@ -293,9 +334,16 @@ class DatasetTestCase(unittest.TestCase):
     ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
         r"""Create the dataset in a temporary directory.
 
+        The configuration passed to the dataset is populated to contain at least all parameters with default values.
+        For this the following order of precedence is used:
+
+        1. Parameters in :attr:`kwargs`.
+        2. Configuration in :attr:`config`.
+        3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
+        4. Default parameters of the dataset.
+
         Args:
-            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset. If omitted, the
-                default configuration is used.
+            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
             inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
                 creating the dataset.
             patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
```
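The precedence list maps directly onto `dict.update` order: start from the lowest-priority source and layer higher-priority ones on top, so later updates win. A pure-Python sketch with made-up parameter names:

```python
# Sketch of the precedence described above; "split" and "mode" are
# illustrative parameter names, not from the diff.
kwarg_defaults = {"split": "train", "mode": "fast"}   # 4. dataset constructor defaults
default_config = {"split": "val"}                     # 3. DEFAULT_CONFIG
config = {"mode": "slow"}                             # 2. config argument
kwargs = {"mode": "exact"}                            # 1. extra keyword arguments

complete_config = kwarg_defaults.copy()
for override in (default_config, config, kwargs):
    complete_config.update(override)  # later sources take precedence

print(complete_config)  # {'split': 'val', 'mode': 'exact'}
```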
```diff
@@ -308,30 +356,33 @@ class DatasetTestCase(unittest.TestCase):
             info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
                 for details.
         """
-        default_config = self._DEFAULT_CONFIG.copy()
-        if config is not None:
-            default_config.update(config)
-        config = default_config
-
         if patch_checks is None:
             patch_checks = inject_fake_data
 
         special_kwargs, other_kwargs = self._split_kwargs(kwargs)
+
+        complete_config = self._KWARG_DEFAULTS.copy()
+        if self.DEFAULT_CONFIG:
+            complete_config.update(self.DEFAULT_CONFIG)
+        if config:
+            complete_config.update(config)
+        if other_kwargs:
+            complete_config.update(other_kwargs)
+
         if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
             # override download param to False param if its default is truthy
             special_kwargs["download"] = False
-        config.update(other_kwargs)
 
         patchers = self._patch_download_extract()
         if patch_checks:
             patchers.update(self._patch_checks())
 
         with get_tmp_dir() as tmpdir:
-            args = self.dataset_args(tmpdir, config)
-            info = self._inject_fake_data(tmpdir, config) if inject_fake_data else None
+            args = self.dataset_args(tmpdir, complete_config)
+            info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
 
             with self._maybe_apply_patches(patchers), disable_console_output():
-                dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)
+                dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
 
             yield dataset, info
```
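Because the complete config is assembled internally, a test can now override a single parameter through keyword arguments without supplying a full config. A hypothetical call site (the `split` parameter and the `num_examples` info key are made up for illustration):

```python
# Hypothetical test method inside a DatasetTestCase subclass; everything
# not passed here is filled in from config, DEFAULT_CONFIG, and the
# constructor defaults, in that order of precedence.
def test_split(self):
    with self.create_dataset(split="train") as (dataset, info):
        self.assertEqual(len(dataset), info["num_examples"])
```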
```diff
@@ -357,26 +408,69 @@ class DatasetTestCase(unittest.TestCase):
     @classmethod
     def _populate_private_class_attributes(cls):
-        argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
+        defaults = []
+        for cls_ in cls.DATASET_CLASS.__mro__:
+            if cls_ is torchvision.datasets.VisionDataset:
+                break
 
-        cls._DEFAULT_CONFIG = {
-            kwarg: default
-            for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
-            if kwarg not in cls._SPECIAL_KWARGS
-        }
+            argspec = inspect.getfullargspec(cls_.__init__)
 
-        cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
+            if not argspec.defaults:
+                continue
+
+            defaults.append(
+                {kwarg: default for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)}
+            )
+
+            if not argspec.varkw:
+                break
+
+        kwarg_defaults = dict()
+        for config in reversed(defaults):
+            kwarg_defaults.update(config)
+
+        has_special_kwargs = set()
+        for name in cls._SPECIAL_KWARGS:
+            if name not in kwarg_defaults:
+                continue
+
+            del kwarg_defaults[name]
+            has_special_kwargs.add(name)
+
+        cls._KWARG_DEFAULTS = kwarg_defaults
+        cls._HAS_SPECIAL_KWARG = has_special_kwargs
 
     @classmethod
     def _process_optional_public_class_attributes(cls):
-        if cls.REQUIRED_PACKAGES is not None:
+        def check_config(config, name):
+            special_kwargs = tuple(f"'{name}'" for name in cls._SPECIAL_KWARGS if name in config)
+            if special_kwargs:
+                raise UsageError(
+                    f"{name} contains a value for the parameter(s) {', '.join(special_kwargs)}. "
+                    f"These are handled separately by the test case and should not be set here. "
+                    f"If you need to test some custom behavior regarding these parameters, "
+                    f"you need to write a custom test (*not* test case), e.g. test_custom_transform()."
+                )
+
+        if cls.DEFAULT_CONFIG is not None:
+            check_config(cls.DEFAULT_CONFIG, "DEFAULT_CONFIG")
+        if cls.ADDITIONAL_CONFIGS is not None:
+            for idx, config in enumerate(cls.ADDITIONAL_CONFIGS):
+                check_config(config, f"CONFIGS[{idx}]")
+
+        if cls.REQUIRED_PACKAGES:
+            missing_pkgs = []
             for pkg in cls.REQUIRED_PACKAGES:
                 try:
                     importlib.import_module(pkg)
-                except ImportError as error:
-                    raise unittest.SkipTest(
-                        f"The package '{error.name}' is required to load the dataset "
-                        f"'{cls.DATASET_CLASS.__name__}' but is not installed."
-                    )
+                except ImportError:
+                    missing_pkgs.append(f"'{pkg}'")
+
+            if missing_pkgs:
+                raise unittest.SkipTest(
+                    f"The package(s) {', '.join(missing_pkgs)} are required to load the dataset "
+                    f"'{cls.DATASET_CLASS.__name__}', but are not installed."
+                )
 
     def _split_kwargs(self, kwargs):
```
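The MRO walk is the core of the new defaults extraction: constructor defaults are collected from the dataset class downwards until `VisionDataset` is reached (or a constructor without `**kwargs`, since nothing would be forwarded further up), then merged so that subclass defaults win. A self-contained sketch with a toy hierarchy, stopping at `object` instead of `VisionDataset`:

```python
import inspect

class Base:
    def __init__(self, root, split="train", download=False, **kwargs):
        pass

class Child(Base):
    def __init__(self, root, split="test", mode="fast", **kwargs):
        pass

defaults = []
for cls_ in Child.__mro__:
    if cls_ is object:  # the real code stops at torchvision.datasets.VisionDataset
        break
    argspec = inspect.getfullargspec(cls_.__init__)
    if not argspec.defaults:
        continue
    defaults.append(dict(zip(argspec.args[-len(argspec.defaults):], argspec.defaults)))
    if not argspec.varkw:
        break  # no **kwargs: nothing is forwarded to the base class

kwarg_defaults = {}
for config in reversed(defaults):
    kwarg_defaults.update(config)  # subclass defaults override base-class defaults

print(kwarg_defaults)  # {'split': 'test', 'download': False, 'mode': 'fast'}
```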
test/test_datasets.py
```diff
@@ -369,7 +369,9 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Caltech101
     FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, tuple))
 
-    CONFIGS = datasets_utils.combinations_grid(target_type=("category", "annotation", ["category", "annotation"]))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        target_type=("category", "annotation", ["category", "annotation"])
+    )
     REQUIRED_PACKAGES = ("scipy",)
 
     def inject_fake_data(self, tmpdir, config):
```
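`combinations_grid` is used throughout these test cases but its implementation is not part of this diff. Presumably it expands each keyword's sequence of values into the Cartesian product, producing one config dict per combination; a plausible sketch:

```python
import itertools

# Plausible sketch of combinations_grid (the real helper lives in
# test/datasets_utils.py and is not shown in this diff).
def combinations_grid(**kwargs):
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]

print(combinations_grid(fold=(1, 2), train=(True, False)))
# [{'fold': 1, 'train': True}, {'fold': 1, 'train': False},
#  {'fold': 2, 'train': True}, {'fold': 2, 'train': False}]
```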
```diff
@@ -466,7 +468,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
 class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.WIDERFace
     FEATURE_TYPES = (PIL.Image.Image, (dict, type(None)))  # test split returns None as target
-    CONFIGS = datasets_utils.combinations_grid(split=('train', 'val', 'test'))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val', 'test'))
 
     def inject_fake_data(self, tmpdir, config):
         widerface_dir = pathlib.Path(tmpdir) / 'widerface'
```
```diff
@@ -521,7 +523,7 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
 class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.ImageNet
     REQUIRED_PACKAGES = ('scipy',)
-    CONFIGS = datasets_utils.combinations_grid(split=('train', 'val'))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val'))
 
     def inject_fake_data(self, tmpdir, config):
         tmpdir = pathlib.Path(tmpdir)
```
```diff
@@ -551,7 +553,7 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
 class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CIFAR10
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     _VERSION_CONFIG = dict(
         base_folder="cifar-10-batches-py",
```
```diff
@@ -623,7 +625,7 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CelebA
     FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))
 
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         split=("train", "valid", "test", "all"),
         target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
     )
```
```diff
@@ -740,7 +742,7 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation
     FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
 
-    CONFIGS = (
+    ADDITIONAL_CONFIGS = (
         *datasets_utils.combinations_grid(
             year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
         ),
```
```diff
@@ -929,7 +931,7 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
 class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.UCF101
 
-    CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
 
     _VIDEO_FOLDER = "videos"
     _ANNOTATIONS_FOLDER = "annotations"
```
```diff
@@ -990,7 +992,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.LSUN
 
     REQUIRED_PACKAGES = ("lmdb",)
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"])
     )
```
```diff
@@ -1097,7 +1099,7 @@ class Kinetics400TestCase(datasets_utils.VideoDatasetTestCase):
 class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.HMDB51
 
-    CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
 
     _VIDEO_FOLDER = "videos"
     _SPLITS_FOLDER = "splits"
```
```diff
@@ -1157,7 +1159,7 @@ class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
 class OmniglotTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Omniglot
 
-    CONFIGS = datasets_utils.combinations_grid(background=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(background=(True, False))
 
     def inject_fake_data(self, tmpdir, config):
         target_folder = (
```
```diff
@@ -1237,7 +1239,7 @@ class SEMEIONTestCase(datasets_utils.ImageDatasetTestCase):
 class USPSTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.USPS
 
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     def inject_fake_data(self, tmpdir, config):
         num_images = 2 if config["train"] else 1
```
```diff
@@ -1259,7 +1261,7 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
     REQUIRED_PACKAGES = ("scipy.io", "scipy.sparse")
 
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         image_set=("train", "val", "train_noval"), mode=("boundaries", "segmentation")
     )
```
```diff
@@ -1345,7 +1347,7 @@ class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
     _TRAIN_FEATURE_TYPES = (torch.Tensor,)
     _TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor)
 
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    datasets_utils.combinations_grid(train=(True, False))
 
     _NAME = "liberty"
```