Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
dcuai
dlexamples
Commits
0fd8347d
Commit
0fd8347d
authored
Jan 08, 2023
by
unknown
Browse files
添加mmclassification-0.24.1代码,删除mmclassification-speed-benchmark
parent
cc567e9e
Changes
839
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1887 additions
and
14 deletions
+1887
-14
openmmlab_test/mmclassification-0.24.1/setup.cfg
openmmlab_test/mmclassification-0.24.1/setup.cfg
+23
-0
openmmlab_test/mmclassification-0.24.1/setup.py
openmmlab_test/mmclassification-0.24.1/setup.py
+194
-0
openmmlab_test/mmclassification-0.24.1/sing_test.sh
openmmlab_test/mmclassification-0.24.1/sing_test.sh
+1
-1
openmmlab_test/mmclassification-0.24.1/single_process.sh
openmmlab_test/mmclassification-0.24.1/single_process.sh
+28
-0
openmmlab_test/mmclassification-0.24.1/tests/data/color.jpg
openmmlab_test/mmclassification-0.24.1/tests/data/color.jpg
+0
-0
openmmlab_test/mmclassification-0.24.1/tests/data/dataset/a/1.JPG
...b_test/mmclassification-0.24.1/tests/data/dataset/a/1.JPG
+0
-0
openmmlab_test/mmclassification-0.24.1/tests/data/dataset/ann.txt
...b_test/mmclassification-0.24.1/tests/data/dataset/ann.txt
+3
-0
openmmlab_test/mmclassification-0.24.1/tests/data/dataset/b/2.jpeg
..._test/mmclassification-0.24.1/tests/data/dataset/b/2.jpeg
+0
-0
openmmlab_test/mmclassification-0.24.1/tests/data/dataset/b/subb/3.jpg
...t/mmclassification-0.24.1/tests/data/dataset/b/subb/3.jpg
+0
-0
openmmlab_test/mmclassification-0.24.1/tests/data/dataset/classes.txt
...st/mmclassification-0.24.1/tests/data/dataset/classes.txt
+2
-0
openmmlab_test/mmclassification-0.24.1/tests/data/gray.jpg
openmmlab_test/mmclassification-0.24.1/tests/data/gray.jpg
+0
-0
openmmlab_test/mmclassification-0.24.1/tests/data/retinanet.py
...mlab_test/mmclassification-0.24.1/tests/data/retinanet.py
+83
-0
openmmlab_test/mmclassification-0.24.1/tests/data/test.logjson
...mlab_test/mmclassification-0.24.1/tests/data/test.logjson
+10
-0
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_builder.py
...t/mmclassification-0.24.1/tests/test_data/test_builder.py
+272
-0
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_datasets/test_common.py
...ation-0.24.1/tests/test_data/test_datasets/test_common.py
+911
-0
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_datasets/test_dataset_utils.py
....24.1/tests/test_data/test_datasets/test_dataset_utils.py
+22
-0
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_datasets/test_dataset_wrapper.py
...4.1/tests/test_data/test_datasets/test_dataset_wrapper.py
+192
-0
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_datasets/test_sampler.py
...tion-0.24.1/tests/test_data/test_datasets/test_sampler.py
+53
-0
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_pipelines/test_auto_augment.py
....24.1/tests/test_data/test_pipelines/test_auto_augment.py
+91
-12
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_pipelines/test_loading.py
...ion-0.24.1/tests/test_data/test_pipelines/test_loading.py
+2
-1
No files found.
Too many changes to show.
To preserve performance only
839 of 839+
files are displayed.
Plain diff
Email patch
openmmlab_test/mmclassification-
speed-benchmark
/setup.cfg
→
openmmlab_test/mmclassification-
0.24.1
/setup.cfg
View file @
0fd8347d
...
...
@@ -12,8 +12,12 @@ split_before_expression_after_opening_paren = true
[isort]
line_length = 79
multi_line_output = 0
known
_standard_library = pkg_resources,setuptools
extra
_standard_library = pkg_resources,setuptools
known_first_party = mmcls
known_third_party = PIL,cv2,matplotlib,mmcv,numpy,onnxruntime,pytest,torch,torchvision,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
[codespell]
skip = *.ipynb
quiet-level = 3
ignore-words-list = patten,confectionary,nd,ty,formating,dows
openmmlab_test/mmclassification-0.24.1/setup.py
0 → 100644
View file @
0fd8347d
import
os
import
os.path
as
osp
import
shutil
import
sys
import
warnings
from
setuptools
import
find_packages
,
setup
def
readme
():
with
open
(
'README.md'
,
encoding
=
'utf-8'
)
as
f
:
content
=
f
.
read
()
return
content
def
get_version
():
version_file
=
'mmcls/version.py'
with
open
(
version_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
exec
(
compile
(
f
.
read
(),
version_file
,
'exec'
))
return
locals
()[
'__version__'
]
def
parse_requirements
(
fname
=
'requirements.txt'
,
with_version
=
True
):
"""Parse the package dependencies listed in a requirements file but strips
specific versioning information.
Args:
fname (str): path to requirements file
with_version (bool, default=False): if True include version specs
Returns:
List[str]: list of requirements items
CommandLine:
python -c "import setup; print(setup.parse_requirements())"
"""
import
re
import
sys
from
os.path
import
exists
require_fpath
=
fname
def
parse_line
(
line
):
"""Parse information from a line in a requirements text file."""
if
line
.
startswith
(
'-r '
):
# Allow specifying requirements in other files
target
=
line
.
split
(
' '
)[
1
]
for
info
in
parse_require_file
(
target
):
yield
info
else
:
info
=
{
'line'
:
line
}
if
line
.
startswith
(
'-e '
):
info
[
'package'
]
=
line
.
split
(
'#egg='
)[
1
]
else
:
# Remove versioning from the package
pat
=
'('
+
'|'
.
join
([
'>='
,
'=='
,
'>'
])
+
')'
parts
=
re
.
split
(
pat
,
line
,
maxsplit
=
1
)
parts
=
[
p
.
strip
()
for
p
in
parts
]
info
[
'package'
]
=
parts
[
0
]
if
len
(
parts
)
>
1
:
op
,
rest
=
parts
[
1
:]
if
';'
in
rest
:
# Handle platform specific dependencies
# http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
version
,
platform_deps
=
map
(
str
.
strip
,
rest
.
split
(
';'
))
info
[
'platform_deps'
]
=
platform_deps
else
:
version
=
rest
# NOQA
if
'--'
in
version
:
# the `extras_require` doesn't accept options.
version
=
version
.
split
(
'--'
)[
0
].
strip
()
info
[
'version'
]
=
(
op
,
version
)
yield
info
def
parse_require_file
(
fpath
):
with
open
(
fpath
,
'r'
)
as
f
:
for
line
in
f
.
readlines
():
line
=
line
.
strip
()
if
line
and
not
line
.
startswith
(
'#'
):
for
info
in
parse_line
(
line
):
yield
info
def
gen_packages_items
():
if
exists
(
require_fpath
):
for
info
in
parse_require_file
(
require_fpath
):
parts
=
[
info
[
'package'
]]
if
with_version
and
'version'
in
info
:
parts
.
extend
(
info
[
'version'
])
if
not
sys
.
version
.
startswith
(
'3.4'
):
# apparently package_deps are broken in 3.4
platform_deps
=
info
.
get
(
'platform_deps'
)
if
platform_deps
is
not
None
:
parts
.
append
(
';'
+
platform_deps
)
item
=
''
.
join
(
parts
)
yield
item
packages
=
list
(
gen_packages_items
())
return
packages
def
add_mim_extension
():
"""Add extra files that are required to support MIM into the package.
These files will be added by creating a symlink to the originals if the
package is installed in `editable` mode (e.g. pip install -e .), or by
copying from the originals otherwise.
"""
# parse installment mode
if
'develop'
in
sys
.
argv
:
# installed by `pip install -e .`
mode
=
'symlink'
elif
'sdist'
in
sys
.
argv
or
'bdist_wheel'
in
sys
.
argv
:
# installed by `pip install .`
# or create source distribution by `python setup.py sdist`
mode
=
'copy'
else
:
return
filenames
=
[
'tools'
,
'configs'
,
'model-index.yml'
]
repo_path
=
osp
.
dirname
(
__file__
)
mim_path
=
osp
.
join
(
repo_path
,
'mmcls'
,
'.mim'
)
os
.
makedirs
(
mim_path
,
exist_ok
=
True
)
for
filename
in
filenames
:
if
osp
.
exists
(
filename
):
src_path
=
osp
.
join
(
repo_path
,
filename
)
tar_path
=
osp
.
join
(
mim_path
,
filename
)
if
osp
.
isfile
(
tar_path
)
or
osp
.
islink
(
tar_path
):
os
.
remove
(
tar_path
)
elif
osp
.
isdir
(
tar_path
):
shutil
.
rmtree
(
tar_path
)
if
mode
==
'symlink'
:
src_relpath
=
osp
.
relpath
(
src_path
,
osp
.
dirname
(
tar_path
))
try
:
os
.
symlink
(
src_relpath
,
tar_path
)
except
OSError
:
# Creating a symbolic link on windows may raise an
# `OSError: [WinError 1314]` due to privilege. If
# the error happens, the src file will be copied
mode
=
'copy'
warnings
.
warn
(
f
'Failed to create a symbolic link for
{
src_relpath
}
, '
f
'and it will be copied to
{
tar_path
}
'
)
else
:
continue
if
mode
==
'copy'
:
if
osp
.
isfile
(
src_path
):
shutil
.
copyfile
(
src_path
,
tar_path
)
elif
osp
.
isdir
(
src_path
):
shutil
.
copytree
(
src_path
,
tar_path
)
else
:
warnings
.
warn
(
f
'Cannot copy file
{
src_path
}
.'
)
else
:
raise
ValueError
(
f
'Invalid mode
{
mode
}
'
)
if
__name__
==
'__main__'
:
add_mim_extension
()
setup
(
name
=
'mmcls'
,
version
=
get_version
(),
description
=
'OpenMMLab Image Classification Toolbox and Benchmark'
,
long_description
=
readme
(),
long_description_content_type
=
'text/markdown'
,
keywords
=
'computer vision, image classification'
,
packages
=
find_packages
(
exclude
=
(
'configs'
,
'tools'
,
'demo'
)),
include_package_data
=
True
,
classifiers
=
[
'Development Status :: 4 - Beta'
,
'License :: OSI Approved :: Apache Software License'
,
'Operating System :: OS Independent'
,
'Programming Language :: Python :: 3'
,
'Programming Language :: Python :: 3.6'
,
'Programming Language :: Python :: 3.7'
,
'Programming Language :: Python :: 3.8'
,
'Programming Language :: Python :: 3.9'
,
'Topic :: Scientific/Engineering :: Artificial Intelligence'
,
],
url
=
'https://github.com/open-mmlab/mmclassification'
,
author
=
'MMClassification Contributors'
,
author_email
=
'openmmlab@gmail.com'
,
license
=
'Apache License 2.0'
,
install_requires
=
parse_requirements
(
'requirements/runtime.txt'
),
extras_require
=
{
'all'
:
parse_requirements
(
'requirements.txt'
),
'tests'
:
parse_requirements
(
'requirements/tests.txt'
),
'optional'
:
parse_requirements
(
'requirements/optional.txt'
),
'mim'
:
parse_requirements
(
'requirements/mminstall.txt'
),
},
zip_safe
=
False
)
openmmlab_test/mmclassification-
speed-benchmark
/sing_test.sh
→
openmmlab_test/mmclassification-
0.24.1
/sing_test.sh
View file @
0fd8347d
#!/bin/bash
export
HIP_VISIBLE_DEVICES
=
3
export
MIOPEN_FIND_MODE
=
3
export
MIOPEN_FIND_MODE
=
1
my_config
=
$1
python3 tools/train.py
$my_config
openmmlab_test/mmclassification-0.24.1/single_process.sh
0 → 100644
View file @
0fd8347d
#!/bin/bash
lrank
=
$OMPI_COMM_WORLD_LOCAL_RANK
comm_rank
=
$OMPI_COMM_WORLD_RANK
comm_size
=
$OMPI_COMM_WORLD_SIZE
export
MASTER_ADDR
=
${
1
}
APP
=
"python3 tools/train.py configs/resnet/resnet18_b32x8_imagenet.py --launcher mpi"
case
${
lrank
}
in
[
0]
)
numactl
--cpunodebind
=
0
--membind
=
0
${
APP
}
;;
[
1]
)
numactl
--cpunodebind
=
1
--membind
=
1
${
APP
}
;;
[
2]
)
numactl
--cpunodebind
=
2
--membind
=
2
${
APP
}
;;
[
3]
)
numactl
--cpunodebind
=
3
--membind
=
3
${
APP
}
;;
[
4]
)
numactl
--cpunodebind
=
4
--membind
=
4
${
APP
}
;;
esac
openmmlab_test/mmclassification-
speed-benchmark
/tests/data/color.jpg
→
openmmlab_test/mmclassification-
0.24.1
/tests/data/color.jpg
View file @
0fd8347d
File moved
openmmlab_test/mmclassification-0.24.1/tests/data/dataset/a/1.JPG
0 → 100644
View file @
0fd8347d
openmmlab_test/mmclassification-0.24.1/tests/data/dataset/ann.txt
0 → 100644
View file @
0fd8347d
a/1.JPG 0
b/2.jpeg 1
b/subb/2.jpeg 1
openmmlab_test/mmclassification-0.24.1/tests/data/dataset/b/2.jpeg
0 → 100644
View file @
0fd8347d
openmmlab_test/mmclassification-0.24.1/tests/data/dataset/b/subb/3.jpg
0 → 100644
View file @
0fd8347d
openmmlab_test/mmclassification-0.24.1/tests/data/dataset/classes.txt
0 → 100644
View file @
0fd8347d
bus
car
openmmlab_test/mmclassification-
speed-benchmark
/tests/data/gray.jpg
→
openmmlab_test/mmclassification-
0.24.1
/tests/data/gray.jpg
View file @
0fd8347d
File moved
openmmlab_test/mmclassification-0.24.1/tests/data/retinanet.py
0 → 100644
View file @
0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
# small RetinaNet
num_classes
=
3
# model settings
model
=
dict
(
type
=
'RetinaNet'
,
backbone
=
dict
(
type
=
'ResNet'
,
depth
=
50
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
frozen_stages
=
1
,
norm_cfg
=
dict
(
type
=
'BN'
,
requires_grad
=
True
),
norm_eval
=
True
,
style
=
'pytorch'
,
init_cfg
=
dict
(
type
=
'Pretrained'
,
checkpoint
=
'torchvision://resnet50'
)),
neck
=
dict
(
type
=
'FPN'
,
in_channels
=
[
256
,
512
,
1024
,
2048
],
out_channels
=
256
,
start_level
=
1
,
add_extra_convs
=
'on_input'
,
num_outs
=
5
),
bbox_head
=
dict
(
type
=
'RetinaHead'
,
num_classes
=
num_classes
,
in_channels
=
256
,
stacked_convs
=
1
,
feat_channels
=
256
,
anchor_generator
=
dict
(
type
=
'AnchorGenerator'
,
octave_base_scale
=
4
,
scales_per_octave
=
3
,
ratios
=
[
0.5
,
1.0
,
2.0
],
strides
=
[
8
,
16
,
32
,
64
,
128
]),
bbox_coder
=
dict
(
type
=
'DeltaXYWHBBoxCoder'
,
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
]),
loss_cls
=
dict
(
type
=
'FocalLoss'
,
use_sigmoid
=
True
,
gamma
=
2.0
,
alpha
=
0.25
,
loss_weight
=
1.0
),
loss_bbox
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
1.0
)),
# model training and testing settings
train_cfg
=
dict
(
assigner
=
dict
(
type
=
'MaxIoUAssigner'
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.4
,
min_pos_iou
=
0
,
ignore_iof_thr
=-
1
),
allowed_border
=-
1
,
pos_weight
=-
1
,
debug
=
False
),
test_cfg
=
dict
(
nms_pre
=
1000
,
min_bbox_size
=
0
,
score_thr
=
0.05
,
nms
=
dict
(
type
=
'nms'
,
iou_threshold
=
0.5
),
max_per_img
=
100
))
img_norm_cfg
=
dict
(
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.12
,
57.375
],
to_rgb
=
True
)
test_pipeline
=
[
dict
(
type
=
'LoadImageFromFile'
),
dict
(
type
=
'MultiScaleFlipAug'
,
img_scale
=
(
1333
,
800
),
flip
=
False
,
transforms
=
[
dict
(
type
=
'Resize'
,
keep_ratio
=
True
),
dict
(
type
=
'RandomFlip'
),
dict
(
type
=
'Normalize'
,
**
img_norm_cfg
),
dict
(
type
=
'Pad'
,
size_divisor
=
32
),
dict
(
type
=
'ImageToTensor'
,
keys
=
[
'img'
]),
dict
(
type
=
'Collect'
,
keys
=
[
'img'
]),
])
]
data
=
dict
(
test
=
dict
(
pipeline
=
test_pipeline
))
openmmlab_test/mmclassification-0.24.1/tests/data/test.logjson
0 → 100644
View file @
0fd8347d
{"a": "b"}
{"mode": "train", "epoch": 1, "iter": 10, "lr": 0.01309, "memory": 0, "data_time": 0.0072, "time": 0.00727}
{"mode": "train", "epoch": 1, "iter": 20, "lr": 0.02764, "memory": 0, "data_time": 0.00044, "time": 0.00046}
{"mode": "train", "epoch": 1, "iter": 30, "lr": 0.04218, "memory": 0, "data_time": 0.00028, "time": 0.0003}
{"mode": "train", "epoch": 1, "iter": 40, "lr": 0.05673, "memory": 0, "data_time": 0.00027, "time": 0.00029}
{"mode": "train", "epoch": 2, "iter": 10, "lr": 0.17309, "memory": 0, "data_time": 0.00048, "time": 0.0005}
{"mode": "train", "epoch": 2, "iter": 20, "lr": 0.18763, "memory": 0, "data_time": 0.00038, "time": 0.0004}
{"mode": "train", "epoch": 2, "iter": 30, "lr": 0.20218, "memory": 0, "data_time": 0.00037, "time": 0.00039}
{"mode": "train", "epoch": 3, "iter": 10, "lr": 0.33305, "memory": 0, "data_time": 0.00045, "time": 0.00046}
{"mode": "train", "epoch": 3, "iter": 20, "lr": 0.34759, "memory": 0, "data_time": 0.0003, "time": 0.00032}
\ No newline at end of file
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_builder.py
0 → 100644
View file @
0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
from
copy
import
deepcopy
from
unittest.mock
import
patch
import
torch
from
mmcv.utils
import
digit_version
from
mmcls.datasets
import
ImageNet
,
build_dataloader
,
build_dataset
from
mmcls.datasets.dataset_wrappers
import
(
ClassBalancedDataset
,
ConcatDataset
,
KFoldDataset
,
RepeatDataset
)
class
TestDataloaderBuilder
():
@
classmethod
def
setup_class
(
cls
):
cls
.
data
=
list
(
range
(
20
))
cls
.
samples_per_gpu
=
5
cls
.
workers_per_gpu
=
1
@
patch
(
'mmcls.datasets.builder.get_dist_info'
,
return_value
=
(
0
,
1
))
def
test_single_gpu
(
self
,
_
):
common_cfg
=
dict
(
dataset
=
self
.
data
,
samples_per_gpu
=
self
.
samples_per_gpu
,
workers_per_gpu
=
self
.
workers_per_gpu
,
dist
=
False
)
# Test default config
dataloader
=
build_dataloader
(
**
common_cfg
)
if
digit_version
(
torch
.
__version__
)
>=
digit_version
(
'1.8.0'
):
assert
dataloader
.
persistent_workers
elif
hasattr
(
dataloader
,
'persistent_workers'
):
assert
not
dataloader
.
persistent_workers
assert
dataloader
.
batch_size
==
self
.
samples_per_gpu
assert
dataloader
.
num_workers
==
self
.
workers_per_gpu
assert
not
all
(
torch
.
cat
(
list
(
iter
(
dataloader
)))
==
torch
.
tensor
(
self
.
data
))
# Test without shuffle
dataloader
=
build_dataloader
(
**
common_cfg
,
shuffle
=
False
)
assert
all
(
torch
.
cat
(
list
(
iter
(
dataloader
)))
==
torch
.
tensor
(
self
.
data
))
# Test with custom sampler_cfg
dataloader
=
build_dataloader
(
**
common_cfg
,
sampler_cfg
=
dict
(
type
=
'RepeatAugSampler'
,
selected_round
=
0
),
shuffle
=
False
)
expect
=
[
0
,
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
3
,
4
,
4
,
4
,
5
,
5
,
5
,
6
,
6
]
assert
all
(
torch
.
cat
(
list
(
iter
(
dataloader
)))
==
torch
.
tensor
(
expect
))
@
patch
(
'mmcls.datasets.builder.get_dist_info'
,
return_value
=
(
0
,
1
))
def
test_multi_gpu
(
self
,
_
):
common_cfg
=
dict
(
dataset
=
self
.
data
,
samples_per_gpu
=
self
.
samples_per_gpu
,
workers_per_gpu
=
self
.
workers_per_gpu
,
num_gpus
=
2
,
dist
=
False
)
# Test default config
dataloader
=
build_dataloader
(
**
common_cfg
)
if
digit_version
(
torch
.
__version__
)
>=
digit_version
(
'1.8.0'
):
assert
dataloader
.
persistent_workers
elif
hasattr
(
dataloader
,
'persistent_workers'
):
assert
not
dataloader
.
persistent_workers
assert
dataloader
.
batch_size
==
self
.
samples_per_gpu
*
2
assert
dataloader
.
num_workers
==
self
.
workers_per_gpu
*
2
assert
not
all
(
torch
.
cat
(
list
(
iter
(
dataloader
)))
==
torch
.
tensor
(
self
.
data
))
# Test without shuffle
dataloader
=
build_dataloader
(
**
common_cfg
,
shuffle
=
False
)
assert
all
(
torch
.
cat
(
list
(
iter
(
dataloader
)))
==
torch
.
tensor
(
self
.
data
))
# Test with custom sampler_cfg
dataloader
=
build_dataloader
(
**
common_cfg
,
sampler_cfg
=
dict
(
type
=
'RepeatAugSampler'
,
selected_round
=
0
),
shuffle
=
False
)
expect
=
torch
.
tensor
(
[
0
,
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
3
,
4
,
4
,
4
,
5
,
5
,
5
,
6
,
6
])
assert
all
(
torch
.
cat
(
list
(
iter
(
dataloader
)))
==
expect
)
@
patch
(
'mmcls.datasets.builder.get_dist_info'
,
return_value
=
(
1
,
2
))
def
test_distributed
(
self
,
_
):
common_cfg
=
dict
(
dataset
=
self
.
data
,
samples_per_gpu
=
self
.
samples_per_gpu
,
workers_per_gpu
=
self
.
workers_per_gpu
,
num_gpus
=
2
,
# num_gpus will be ignored in distributed environment.
dist
=
True
)
# Test default config
dataloader
=
build_dataloader
(
**
common_cfg
)
if
digit_version
(
torch
.
__version__
)
>=
digit_version
(
'1.8.0'
):
assert
dataloader
.
persistent_workers
elif
hasattr
(
dataloader
,
'persistent_workers'
):
assert
not
dataloader
.
persistent_workers
assert
dataloader
.
batch_size
==
self
.
samples_per_gpu
assert
dataloader
.
num_workers
==
self
.
workers_per_gpu
non_expect
=
torch
.
tensor
(
self
.
data
[
1
::
2
])
assert
not
all
(
torch
.
cat
(
list
(
iter
(
dataloader
)))
==
non_expect
)
# Test without shuffle
dataloader
=
build_dataloader
(
**
common_cfg
,
shuffle
=
False
)
expect
=
torch
.
tensor
(
self
.
data
[
1
::
2
])
assert
all
(
torch
.
cat
(
list
(
iter
(
dataloader
)))
==
expect
)
# Test with custom sampler_cfg
dataloader
=
build_dataloader
(
**
common_cfg
,
sampler_cfg
=
dict
(
type
=
'RepeatAugSampler'
,
selected_round
=
0
),
shuffle
=
False
)
expect
=
torch
.
tensor
(
[
0
,
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
3
,
4
,
4
,
4
,
5
,
5
,
5
,
6
,
6
][
1
::
2
])
assert
all
(
torch
.
cat
(
list
(
iter
(
dataloader
)))
==
expect
)
class
TestDatasetBuilder
():
@
classmethod
def
setup_class
(
cls
):
data_prefix
=
osp
.
join
(
osp
.
dirname
(
__file__
),
'../data/dataset'
)
cls
.
dataset_cfg
=
dict
(
type
=
'ImageNet'
,
data_prefix
=
data_prefix
,
ann_file
=
osp
.
join
(
data_prefix
,
'ann.txt'
),
pipeline
=
[],
test_mode
=
False
,
)
def
test_normal_dataset
(
self
):
# Test build
dataset
=
build_dataset
(
self
.
dataset_cfg
)
assert
isinstance
(
dataset
,
ImageNet
)
assert
dataset
.
test_mode
==
self
.
dataset_cfg
[
'test_mode'
]
# Test default_args
dataset
=
build_dataset
(
self
.
dataset_cfg
,
{
'test_mode'
:
True
})
assert
dataset
.
test_mode
==
self
.
dataset_cfg
[
'test_mode'
]
cp_cfg
=
deepcopy
(
self
.
dataset_cfg
)
cp_cfg
.
pop
(
'test_mode'
)
dataset
=
build_dataset
(
cp_cfg
,
{
'test_mode'
:
True
})
assert
dataset
.
test_mode
def
test_concat_dataset
(
self
):
# Test build
dataset
=
build_dataset
([
self
.
dataset_cfg
,
self
.
dataset_cfg
])
assert
isinstance
(
dataset
,
ConcatDataset
)
assert
dataset
.
datasets
[
0
].
test_mode
==
self
.
dataset_cfg
[
'test_mode'
]
# Test default_args
dataset
=
build_dataset
([
self
.
dataset_cfg
,
self
.
dataset_cfg
],
{
'test_mode'
:
True
})
assert
dataset
.
datasets
[
0
].
test_mode
==
self
.
dataset_cfg
[
'test_mode'
]
cp_cfg
=
deepcopy
(
self
.
dataset_cfg
)
cp_cfg
.
pop
(
'test_mode'
)
dataset
=
build_dataset
([
cp_cfg
,
cp_cfg
],
{
'test_mode'
:
True
})
assert
dataset
.
datasets
[
0
].
test_mode
def
test_repeat_dataset
(
self
):
# Test build
dataset
=
build_dataset
(
dict
(
type
=
'RepeatDataset'
,
dataset
=
self
.
dataset_cfg
,
times
=
3
))
assert
isinstance
(
dataset
,
RepeatDataset
)
assert
dataset
.
dataset
.
test_mode
==
self
.
dataset_cfg
[
'test_mode'
]
# Test default_args
dataset
=
build_dataset
(
dict
(
type
=
'RepeatDataset'
,
dataset
=
self
.
dataset_cfg
,
times
=
3
),
{
'test_mode'
:
True
})
assert
dataset
.
dataset
.
test_mode
==
self
.
dataset_cfg
[
'test_mode'
]
cp_cfg
=
deepcopy
(
self
.
dataset_cfg
)
cp_cfg
.
pop
(
'test_mode'
)
dataset
=
build_dataset
(
dict
(
type
=
'RepeatDataset'
,
dataset
=
cp_cfg
,
times
=
3
),
{
'test_mode'
:
True
})
assert
dataset
.
dataset
.
test_mode
def
test_class_balance_dataset
(
self
):
# Test build
dataset
=
build_dataset
(
dict
(
type
=
'ClassBalancedDataset'
,
dataset
=
self
.
dataset_cfg
,
oversample_thr
=
1.
,
))
assert
isinstance
(
dataset
,
ClassBalancedDataset
)
assert
dataset
.
dataset
.
test_mode
==
self
.
dataset_cfg
[
'test_mode'
]
# Test default_args
dataset
=
build_dataset
(
dict
(
type
=
'ClassBalancedDataset'
,
dataset
=
self
.
dataset_cfg
,
oversample_thr
=
1.
,
),
{
'test_mode'
:
True
})
assert
dataset
.
dataset
.
test_mode
==
self
.
dataset_cfg
[
'test_mode'
]
cp_cfg
=
deepcopy
(
self
.
dataset_cfg
)
cp_cfg
.
pop
(
'test_mode'
)
dataset
=
build_dataset
(
dict
(
type
=
'ClassBalancedDataset'
,
dataset
=
cp_cfg
,
oversample_thr
=
1.
,
),
{
'test_mode'
:
True
})
assert
dataset
.
dataset
.
test_mode
def
test_kfold_dataset
(
self
):
# Test build
dataset
=
build_dataset
(
dict
(
type
=
'KFoldDataset'
,
dataset
=
self
.
dataset_cfg
,
fold
=
0
,
num_splits
=
5
,
test_mode
=
False
,
))
assert
isinstance
(
dataset
,
KFoldDataset
)
assert
not
dataset
.
test_mode
assert
dataset
.
dataset
.
test_mode
==
self
.
dataset_cfg
[
'test_mode'
]
# Test default_args
dataset
=
build_dataset
(
dict
(
type
=
'KFoldDataset'
,
dataset
=
self
.
dataset_cfg
,
fold
=
0
,
num_splits
=
5
,
test_mode
=
False
,
),
default_args
=
{
'test_mode'
:
True
,
'classes'
:
[
1
,
2
,
3
]
})
assert
not
dataset
.
test_mode
assert
dataset
.
dataset
.
test_mode
==
self
.
dataset_cfg
[
'test_mode'
]
assert
dataset
.
dataset
.
CLASSES
==
[
1
,
2
,
3
]
cp_cfg
=
deepcopy
(
self
.
dataset_cfg
)
cp_cfg
.
pop
(
'test_mode'
)
dataset
=
build_dataset
(
dict
(
type
=
'KFoldDataset'
,
dataset
=
self
.
dataset_cfg
,
fold
=
0
,
num_splits
=
5
,
),
default_args
=
{
'test_mode'
:
True
,
'classes'
:
[
1
,
2
,
3
]
})
# The test_mode in default_args will be passed to KFoldDataset
assert
dataset
.
test_mode
assert
not
dataset
.
dataset
.
test_mode
# Other default_args will be passed to child dataset.
assert
dataset
.
dataset
.
CLASSES
==
[
1
,
2
,
3
]
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_datasets/test_common.py
0 → 100644
View file @
0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
os.path
as
osp
import
pickle
import
tempfile
from
unittest
import
TestCase
from
unittest.mock
import
patch
import
numpy
as
np
import
torch
from
mmcls.datasets
import
DATASETS
from
mmcls.datasets
import
BaseDataset
as
_BaseDataset
from
mmcls.datasets
import
MultiLabelDataset
as
_MultiLabelDataset
ASSETS_ROOT
=
osp
.
abspath
(
osp
.
join
(
osp
.
dirname
(
__file__
),
'../../data/dataset'
))
class
BaseDataset
(
_BaseDataset
):
def
load_annotations
(
self
):
pass
class
MultiLabelDataset
(
_MultiLabelDataset
):
def
load_annotations
(
self
):
pass
DATASETS
.
module_dict
[
'BaseDataset'
]
=
BaseDataset
DATASETS
.
module_dict
[
'MultiLabelDataset'
]
=
MultiLabelDataset
class
TestBaseDataset
(
TestCase
):
DATASET_TYPE
=
'BaseDataset'
DEFAULT_ARGS
=
dict
(
data_prefix
=
''
,
pipeline
=
[])
def
test_initialize
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
with
patch
.
object
(
dataset_class
,
'load_annotations'
):
# Test default behavior
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'classes'
:
None
,
'ann_file'
:
None
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
dataset
.
CLASSES
,
dataset_class
.
CLASSES
)
self
.
assertFalse
(
dataset
.
test_mode
)
self
.
assertIsNone
(
dataset
.
ann_file
)
# Test setting classes as a tuple
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'classes'
:
(
'bus'
,
'car'
)}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
dataset
.
CLASSES
,
(
'bus'
,
'car'
))
# Test setting classes as a tuple
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'classes'
:
[
'bus'
,
'car'
]}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
dataset
.
CLASSES
,
[
'bus'
,
'car'
])
# Test setting classes through a file
classes_file
=
osp
.
join
(
ASSETS_ROOT
,
'classes.txt'
)
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'classes'
:
classes_file
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
dataset
.
CLASSES
,
[
'bus'
,
'car'
])
self
.
assertEqual
(
dataset
.
class_to_idx
,
{
'bus'
:
0
,
'car'
:
1
})
# Test invalid classes
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'classes'
:
dict
(
classes
=
1
)}
with
self
.
assertRaisesRegex
(
ValueError
,
"type <class 'dict'>"
):
dataset_class
(
**
cfg
)
def
test_get_cat_ids
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
fake_ann
=
[
dict
(
img_prefix
=
''
,
img_info
=
dict
(),
gt_label
=
np
.
array
(
0
,
dtype
=
np
.
int64
))
]
with
patch
.
object
(
dataset_class
,
'load_annotations'
)
as
mock_load
:
mock_load
.
return_value
=
fake_ann
dataset
=
dataset_class
(
**
self
.
DEFAULT_ARGS
)
cat_ids
=
dataset
.
get_cat_ids
(
0
)
self
.
assertIsInstance
(
cat_ids
,
list
)
self
.
assertEqual
(
len
(
cat_ids
),
1
)
self
.
assertIsInstance
(
cat_ids
[
0
],
int
)
def
test_evaluate
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
fake_ann
=
[
dict
(
gt_label
=
np
.
array
(
0
,
dtype
=
np
.
int64
)),
dict
(
gt_label
=
np
.
array
(
0
,
dtype
=
np
.
int64
)),
dict
(
gt_label
=
np
.
array
(
1
,
dtype
=
np
.
int64
)),
dict
(
gt_label
=
np
.
array
(
2
,
dtype
=
np
.
int64
)),
dict
(
gt_label
=
np
.
array
(
1
,
dtype
=
np
.
int64
)),
dict
(
gt_label
=
np
.
array
(
0
,
dtype
=
np
.
int64
)),
]
with
patch
.
object
(
dataset_class
,
'load_annotations'
)
as
mock_load
:
mock_load
.
return_value
=
fake_ann
dataset
=
dataset_class
(
**
self
.
DEFAULT_ARGS
)
fake_results
=
np
.
array
([
[
0.7
,
0.0
,
0.3
],
[
0.5
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.1
],
[
0.0
,
0.0
,
1.0
],
[
0.0
,
0.0
,
1.0
],
[
0.0
,
0.0
,
1.0
],
])
eval_results
=
dataset
.
evaluate
(
fake_results
,
metric
=
[
'precision'
,
'recall'
,
'f1_score'
,
'support'
,
'accuracy'
],
metric_options
=
{
'topk'
:
1
})
# Test results
self
.
assertAlmostEqual
(
eval_results
[
'precision'
],
(
1
+
1
+
1
/
3
)
/
3
*
100.0
,
places
=
4
)
self
.
assertAlmostEqual
(
eval_results
[
'recall'
],
(
2
/
3
+
1
/
2
+
1
)
/
3
*
100.0
,
places
=
4
)
self
.
assertAlmostEqual
(
eval_results
[
'f1_score'
],
(
4
/
5
+
2
/
3
+
1
/
2
)
/
3
*
100.0
,
places
=
4
)
self
.
assertEqual
(
eval_results
[
'support'
],
6
)
self
.
assertAlmostEqual
(
eval_results
[
'accuracy'
],
4
/
6
*
100
,
places
=
4
)
# test indices
eval_results_
=
dataset
.
evaluate
(
fake_results
[:
5
],
metric
=
[
'precision'
,
'recall'
,
'f1_score'
,
'support'
,
'accuracy'
],
metric_options
=
{
'topk'
:
1
},
indices
=
range
(
5
))
self
.
assertAlmostEqual
(
eval_results_
[
'precision'
],
(
1
+
1
+
1
/
2
)
/
3
*
100.0
,
places
=
4
)
self
.
assertAlmostEqual
(
eval_results_
[
'recall'
],
(
1
+
1
/
2
+
1
)
/
3
*
100.0
,
places
=
4
)
self
.
assertAlmostEqual
(
eval_results_
[
'f1_score'
],
(
1
+
2
/
3
+
2
/
3
)
/
3
*
100.0
,
places
=
4
)
self
.
assertEqual
(
eval_results_
[
'support'
],
5
)
self
.
assertAlmostEqual
(
eval_results_
[
'accuracy'
],
4
/
5
*
100
,
places
=
4
)
# test input as tensor
fake_results_tensor
=
torch
.
from_numpy
(
fake_results
)
eval_results_
=
dataset
.
evaluate
(
fake_results_tensor
,
metric
=
[
'precision'
,
'recall'
,
'f1_score'
,
'support'
,
'accuracy'
],
metric_options
=
{
'topk'
:
1
})
assert
eval_results_
==
eval_results
# test thr
eval_results
=
dataset
.
evaluate
(
fake_results
,
metric
=
[
'precision'
,
'recall'
,
'f1_score'
,
'accuracy'
],
metric_options
=
{
'thrs'
:
0.6
,
'topk'
:
1
})
self
.
assertAlmostEqual
(
eval_results
[
'precision'
],
(
1
+
0
+
1
/
3
)
/
3
*
100.0
,
places
=
4
)
self
.
assertAlmostEqual
(
eval_results
[
'recall'
],
(
1
/
3
+
0
+
1
)
/
3
*
100.0
,
places
=
4
)
self
.
assertAlmostEqual
(
eval_results
[
'f1_score'
],
(
1
/
2
+
0
+
1
/
2
)
/
3
*
100.0
,
places
=
4
)
self
.
assertAlmostEqual
(
eval_results
[
'accuracy'
],
2
/
6
*
100
,
places
=
4
)
# thrs must be a number or tuple
with
self
.
assertRaises
(
TypeError
):
dataset
.
evaluate
(
fake_results
,
metric
=
[
'precision'
,
'recall'
,
'f1_score'
,
'accuracy'
],
metric_options
=
{
'thrs'
:
'thr'
,
'topk'
:
1
})
# test topk and thr as tuple
eval_results
=
dataset
.
evaluate
(
fake_results
,
metric
=
[
'precision'
,
'recall'
,
'f1_score'
,
'accuracy'
],
metric_options
=
{
'thrs'
:
(
0.5
,
0.6
),
'topk'
:
(
1
,
2
)
})
self
.
assertEqual
(
{
'precision_thr_0.50'
,
'precision_thr_0.60'
,
'recall_thr_0.50'
,
'recall_thr_0.60'
,
'f1_score_thr_0.50'
,
'f1_score_thr_0.60'
,
'accuracy_top-1_thr_0.50'
,
'accuracy_top-1_thr_0.60'
,
'accuracy_top-2_thr_0.50'
,
'accuracy_top-2_thr_0.60'
},
eval_results
.
keys
())
self
.
assertIsInstance
(
eval_results
[
'precision_thr_0.50'
],
float
)
self
.
assertIsInstance
(
eval_results
[
'recall_thr_0.50'
],
float
)
self
.
assertIsInstance
(
eval_results
[
'f1_score_thr_0.50'
],
float
)
self
.
assertIsInstance
(
eval_results
[
'accuracy_top-1_thr_0.50'
],
float
)
# test topk is tuple while thrs is number
eval_results
=
dataset
.
evaluate
(
fake_results
,
metric
=
'accuracy'
,
metric_options
=
{
'thrs'
:
0.5
,
'topk'
:
(
1
,
2
)
})
self
.
assertEqual
({
'accuracy_top-1'
,
'accuracy_top-2'
},
eval_results
.
keys
())
self
.
assertIsInstance
(
eval_results
[
'accuracy_top-1'
],
float
)
# test topk is number while thrs is tuple
eval_results
=
dataset
.
evaluate
(
fake_results
,
metric
=
'accuracy'
,
metric_options
=
{
'thrs'
:
(
0.5
,
0.6
),
'topk'
:
1
})
self
.
assertEqual
({
'accuracy_thr_0.50'
,
'accuracy_thr_0.60'
},
eval_results
.
keys
())
self
.
assertIsInstance
(
eval_results
[
'accuracy_thr_0.50'
],
float
)
# test evaluation results for classes
eval_results
=
dataset
.
evaluate
(
fake_results
,
metric
=
[
'precision'
,
'recall'
,
'f1_score'
,
'support'
],
metric_options
=
{
'average_mode'
:
'none'
})
self
.
assertEqual
(
eval_results
[
'precision'
].
shape
,
(
3
,
))
self
.
assertEqual
(
eval_results
[
'recall'
].
shape
,
(
3
,
))
self
.
assertEqual
(
eval_results
[
'f1_score'
].
shape
,
(
3
,
))
self
.
assertEqual
(
eval_results
[
'support'
].
shape
,
(
3
,
))
# the average_mode method must be valid
with
self
.
assertRaises
(
ValueError
):
dataset
.
evaluate
(
fake_results
,
metric
=
[
'precision'
,
'recall'
,
'f1_score'
,
'support'
],
metric_options
=
{
'average_mode'
:
'micro'
})
# the metric must be valid for the dataset
with
self
.
assertRaisesRegex
(
ValueError
,
"{'unknown'} is not supported"
):
dataset
.
evaluate
(
fake_results
,
metric
=
'unknown'
)
class
TestMultiLabelDataset
(
TestBaseDataset
):
DATASET_TYPE
=
'MultiLabelDataset'
def
test_get_cat_ids
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
fake_ann
=
[
dict
(
img_prefix
=
''
,
img_info
=
dict
(),
gt_label
=
np
.
array
([
0
,
1
,
1
,
0
],
dtype
=
np
.
uint8
))
]
with
patch
.
object
(
dataset_class
,
'load_annotations'
)
as
mock_load
:
mock_load
.
return_value
=
fake_ann
dataset
=
dataset_class
(
**
self
.
DEFAULT_ARGS
)
cat_ids
=
dataset
.
get_cat_ids
(
0
)
self
.
assertIsInstance
(
cat_ids
,
list
)
self
.
assertEqual
(
len
(
cat_ids
),
2
)
self
.
assertIsInstance
(
cat_ids
[
0
],
int
)
self
.
assertEqual
(
cat_ids
,
[
1
,
2
])
def
test_evaluate
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
fake_ann
=
[
dict
(
gt_label
=
np
.
array
([
1
,
1
,
0
,
-
1
],
dtype
=
np
.
int8
)),
dict
(
gt_label
=
np
.
array
([
1
,
1
,
0
,
-
1
],
dtype
=
np
.
int8
)),
dict
(
gt_label
=
np
.
array
([
0
,
-
1
,
1
,
-
1
],
dtype
=
np
.
int8
)),
dict
(
gt_label
=
np
.
array
([
0
,
1
,
0
,
-
1
],
dtype
=
np
.
int8
)),
dict
(
gt_label
=
np
.
array
([
0
,
1
,
0
,
-
1
],
dtype
=
np
.
int8
)),
]
with
patch
.
object
(
dataset_class
,
'load_annotations'
)
as
mock_load
:
mock_load
.
return_value
=
fake_ann
dataset
=
dataset_class
(
**
self
.
DEFAULT_ARGS
)
fake_results
=
np
.
array
([
[
0.9
,
0.8
,
0.3
,
0.2
],
[
0.1
,
0.2
,
0.2
,
0.1
],
[
0.7
,
0.5
,
0.9
,
0.3
],
[
0.8
,
0.1
,
0.1
,
0.2
],
[
0.8
,
0.1
,
0.1
,
0.2
],
])
# the metric must be valid for the dataset
with
self
.
assertRaisesRegex
(
ValueError
,
"{'unknown'} is not supported"
):
dataset
.
evaluate
(
fake_results
,
metric
=
'unknown'
)
# only one metric
eval_results
=
dataset
.
evaluate
(
fake_results
,
metric
=
'mAP'
)
self
.
assertEqual
(
eval_results
.
keys
(),
{
'mAP'
})
self
.
assertAlmostEqual
(
eval_results
[
'mAP'
],
67.5
,
places
=
4
)
# multiple metrics
eval_results
=
dataset
.
evaluate
(
fake_results
,
metric
=
[
'mAP'
,
'CR'
,
'OF1'
])
self
.
assertEqual
(
eval_results
.
keys
(),
{
'mAP'
,
'CR'
,
'OF1'
})
self
.
assertAlmostEqual
(
eval_results
[
'mAP'
],
67.50
,
places
=
2
)
self
.
assertAlmostEqual
(
eval_results
[
'CR'
],
43.75
,
places
=
2
)
self
.
assertAlmostEqual
(
eval_results
[
'OF1'
],
42.86
,
places
=
2
)
class
TestCustomDataset
(
TestBaseDataset
):
DATASET_TYPE
=
'CustomDataset'
def
test_load_annotations
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
# test load without ann_file
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'data_prefix'
:
ASSETS_ROOT
,
'ann_file'
:
None
,
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
len
(
dataset
),
3
)
self
.
assertEqual
(
dataset
.
CLASSES
,
[
'a'
,
'b'
])
# auto infer classes
self
.
assertEqual
(
dataset
.
data_infos
[
0
],
{
'img_prefix'
:
ASSETS_ROOT
,
'img_info'
:
{
'filename'
:
'a/1.JPG'
},
'gt_label'
:
np
.
array
(
0
)
})
self
.
assertEqual
(
dataset
.
data_infos
[
2
],
{
'img_prefix'
:
ASSETS_ROOT
,
'img_info'
:
{
'filename'
:
'b/subb/3.jpg'
},
'gt_label'
:
np
.
array
(
1
)
})
# test ann_file assertion
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'data_prefix'
:
ASSETS_ROOT
,
'ann_file'
:
[
'ann_file.txt'
],
}
with
self
.
assertRaisesRegex
(
TypeError
,
'must be a str'
):
dataset_class
(
**
cfg
)
# test load with ann_file
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'data_prefix'
:
ASSETS_ROOT
,
'ann_file'
:
osp
.
join
(
ASSETS_ROOT
,
'ann.txt'
),
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
len
(
dataset
),
3
)
# custom dataset won't infer CLASSES from ann_file
self
.
assertEqual
(
dataset
.
CLASSES
,
dataset_class
.
CLASSES
)
self
.
assertEqual
(
dataset
.
data_infos
[
0
],
{
'img_prefix'
:
ASSETS_ROOT
,
'img_info'
:
{
'filename'
:
'a/1.JPG'
},
'gt_label'
:
np
.
array
(
0
)
})
self
.
assertEqual
(
dataset
.
data_infos
[
2
],
{
'img_prefix'
:
ASSETS_ROOT
,
'img_info'
:
{
'filename'
:
'b/subb/2.jpeg'
},
'gt_label'
:
np
.
array
(
1
)
})
# test extensions filter
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'data_prefix'
:
ASSETS_ROOT
,
'ann_file'
:
None
,
'extensions'
:
(
'.txt'
,
)
}
with
self
.
assertRaisesRegex
(
RuntimeError
,
'Supported extensions are: .txt'
):
dataset_class
(
**
cfg
)
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'data_prefix'
:
ASSETS_ROOT
,
'ann_file'
:
None
,
'extensions'
:
(
'.jpeg'
,
)
}
with
self
.
assertWarnsRegex
(
UserWarning
,
'Supported extensions are: .jpeg'
):
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
len
(
dataset
),
1
)
self
.
assertEqual
(
dataset
.
data_infos
[
0
],
{
'img_prefix'
:
ASSETS_ROOT
,
'img_info'
:
{
'filename'
:
'b/2.jpeg'
},
'gt_label'
:
np
.
array
(
1
)
})
# test classes check
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'data_prefix'
:
ASSETS_ROOT
,
'classes'
:
[
'apple'
,
'banana'
],
'ann_file'
:
None
,
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
dataset
.
CLASSES
,
[
'apple'
,
'banana'
])
cfg
[
'classes'
]
=
[
'apple'
,
'banana'
,
'dog'
]
with
self
.
assertRaisesRegex
(
AssertionError
,
r
"\(2\) doesn't match .* classes \(3\)"
):
dataset_class
(
**
cfg
)
class
TestImageNet
(
TestBaseDataset
):
DATASET_TYPE
=
'ImageNet'
def
test_load_annotations
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
# test classes number
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'data_prefix'
:
ASSETS_ROOT
,
'ann_file'
:
None
,
}
with
self
.
assertRaisesRegex
(
AssertionError
,
r
"\(2\) doesn't match .* classes \(1000\)"
):
dataset_class
(
**
cfg
)
# test override classes
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'data_prefix'
:
ASSETS_ROOT
,
'classes'
:
[
'cat'
,
'dog'
],
'ann_file'
:
None
,
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
len
(
dataset
),
3
)
self
.
assertEqual
(
dataset
.
CLASSES
,
[
'cat'
,
'dog'
])
class
TestImageNet21k
(
TestBaseDataset
):
DATASET_TYPE
=
'ImageNet21k'
DEFAULT_ARGS
=
dict
(
data_prefix
=
ASSETS_ROOT
,
pipeline
=
[],
classes
=
[
'cat'
,
'dog'
],
ann_file
=
osp
.
join
(
ASSETS_ROOT
,
'ann.txt'
),
serialize_data
=
False
)
def
test_initialize
(
self
):
super
().
test_initialize
()
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
# The multi_label option is not implemented not.
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'multi_label'
:
True
}
with
self
.
assertRaisesRegex
(
NotImplementedError
,
'not supported'
):
dataset_class
(
**
cfg
)
# Warn about ann_file
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'ann_file'
:
None
}
with
self
.
assertWarnsRegex
(
UserWarning
,
'specify the `ann_file`'
):
dataset_class
(
**
cfg
)
# Warn about classes
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'classes'
:
None
}
with
self
.
assertWarnsRegex
(
UserWarning
,
'specify the `classes`'
):
dataset_class
(
**
cfg
)
def
test_load_annotations
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
# Test with serialize_data=False
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'serialize_data'
:
False
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
len
(
dataset
.
data_infos
),
3
)
self
.
assertEqual
(
len
(
dataset
),
3
)
self
.
assertEqual
(
dataset
[
0
],
{
'img_prefix'
:
ASSETS_ROOT
,
'img_info'
:
{
'filename'
:
'a/1.JPG'
},
'gt_label'
:
np
.
array
(
0
)
})
self
.
assertEqual
(
dataset
[
2
],
{
'img_prefix'
:
ASSETS_ROOT
,
'img_info'
:
{
'filename'
:
'b/subb/2.jpeg'
},
'gt_label'
:
np
.
array
(
1
)
})
# Test with serialize_data=True
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'serialize_data'
:
True
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
len
(
dataset
.
data_infos
),
0
)
# data_infos is clear.
self
.
assertEqual
(
len
(
dataset
),
3
)
self
.
assertEqual
(
dataset
[
0
],
{
'img_prefix'
:
ASSETS_ROOT
,
'img_info'
:
{
'filename'
:
'a/1.JPG'
},
'gt_label'
:
np
.
array
(
0
)
})
self
.
assertEqual
(
dataset
[
2
],
{
'img_prefix'
:
ASSETS_ROOT
,
'img_info'
:
{
'filename'
:
'b/subb/2.jpeg'
},
'gt_label'
:
np
.
array
(
1
)
})
class
TestMNIST
(
TestBaseDataset
):
DATASET_TYPE
=
'MNIST'
@
classmethod
def
setUpClass
(
cls
)
->
None
:
super
().
setUpClass
()
tmpdir
=
tempfile
.
TemporaryDirectory
()
cls
.
tmpdir
=
tmpdir
data_prefix
=
tmpdir
.
name
cls
.
DEFAULT_ARGS
=
dict
(
data_prefix
=
data_prefix
,
pipeline
=
[])
dataset_class
=
DATASETS
.
get
(
cls
.
DATASET_TYPE
)
def
rm_suffix
(
s
):
return
s
[:
s
.
rfind
(
'.'
)]
train_image_file
=
osp
.
join
(
data_prefix
,
rm_suffix
(
dataset_class
.
resources
[
'train_image_file'
][
0
]))
train_label_file
=
osp
.
join
(
data_prefix
,
rm_suffix
(
dataset_class
.
resources
[
'train_label_file'
][
0
]))
test_image_file
=
osp
.
join
(
data_prefix
,
rm_suffix
(
dataset_class
.
resources
[
'test_image_file'
][
0
]))
test_label_file
=
osp
.
join
(
data_prefix
,
rm_suffix
(
dataset_class
.
resources
[
'test_label_file'
][
0
]))
cls
.
fake_img
=
np
.
random
.
randint
(
0
,
255
,
size
=
(
28
,
28
),
dtype
=
np
.
uint8
)
cls
.
fake_label
=
np
.
random
.
randint
(
0
,
10
,
size
=
(
1
,
),
dtype
=
np
.
uint8
)
for
file
in
[
train_image_file
,
test_image_file
]:
magic
=
b
'
\x00\x00\x08\x03
'
# num_dims = 3, type = uint8
head
=
b
'
\x00\x00\x00\x01
'
+
b
'
\x00\x00\x00\x1c
'
*
2
# (1, 28, 28)
data
=
magic
+
head
+
cls
.
fake_img
.
flatten
().
tobytes
()
with
open
(
file
,
'wb'
)
as
f
:
f
.
write
(
data
)
for
file
in
[
train_label_file
,
test_label_file
]:
magic
=
b
'
\x00\x00\x08\x01
'
# num_dims = 3, type = uint8
head
=
b
'
\x00\x00\x00\x01
'
# (1, )
data
=
magic
+
head
+
cls
.
fake_label
.
tobytes
()
with
open
(
file
,
'wb'
)
as
f
:
f
.
write
(
data
)
def
test_load_annotations
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
with
patch
.
object
(
dataset_class
,
'download'
):
# Test default behavior
dataset
=
dataset_class
(
**
self
.
DEFAULT_ARGS
)
self
.
assertEqual
(
len
(
dataset
),
1
)
data_info
=
dataset
[
0
]
np
.
testing
.
assert_equal
(
data_info
[
'img'
],
self
.
fake_img
)
np
.
testing
.
assert_equal
(
data_info
[
'gt_label'
],
self
.
fake_label
)
@
classmethod
def
tearDownClass
(
cls
):
cls
.
tmpdir
.
cleanup
()
class
TestCIFAR10
(
TestBaseDataset
):
DATASET_TYPE
=
'CIFAR10'
@
classmethod
def
setUpClass
(
cls
)
->
None
:
super
().
setUpClass
()
tmpdir
=
tempfile
.
TemporaryDirectory
()
cls
.
tmpdir
=
tmpdir
data_prefix
=
tmpdir
.
name
cls
.
DEFAULT_ARGS
=
dict
(
data_prefix
=
data_prefix
,
pipeline
=
[])
dataset_class
=
DATASETS
.
get
(
cls
.
DATASET_TYPE
)
base_folder
=
osp
.
join
(
data_prefix
,
dataset_class
.
base_folder
)
os
.
mkdir
(
base_folder
)
cls
.
fake_imgs
=
np
.
random
.
randint
(
0
,
255
,
size
=
(
6
,
3
*
32
*
32
),
dtype
=
np
.
uint8
)
cls
.
fake_labels
=
np
.
random
.
randint
(
0
,
10
,
size
=
(
6
,
))
cls
.
fake_classes
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
]
batch1
=
dict
(
data
=
cls
.
fake_imgs
[:
2
],
labels
=
cls
.
fake_labels
[:
2
].
tolist
())
with
open
(
osp
.
join
(
base_folder
,
'data_batch_1'
),
'wb'
)
as
f
:
f
.
write
(
pickle
.
dumps
(
batch1
))
batch2
=
dict
(
data
=
cls
.
fake_imgs
[
2
:
4
],
labels
=
cls
.
fake_labels
[
2
:
4
].
tolist
())
with
open
(
osp
.
join
(
base_folder
,
'data_batch_2'
),
'wb'
)
as
f
:
f
.
write
(
pickle
.
dumps
(
batch2
))
test_batch
=
dict
(
data
=
cls
.
fake_imgs
[
4
:],
labels
=
cls
.
fake_labels
[
4
:].
tolist
())
with
open
(
osp
.
join
(
base_folder
,
'test_batch'
),
'wb'
)
as
f
:
f
.
write
(
pickle
.
dumps
(
test_batch
))
meta
=
{
dataset_class
.
meta
[
'key'
]:
cls
.
fake_classes
}
meta_filename
=
dataset_class
.
meta
[
'filename'
]
with
open
(
osp
.
join
(
base_folder
,
meta_filename
),
'wb'
)
as
f
:
f
.
write
(
pickle
.
dumps
(
meta
))
dataset_class
.
train_list
=
[[
'data_batch_1'
,
None
],
[
'data_batch_2'
,
None
]]
dataset_class
.
test_list
=
[[
'test_batch'
,
None
]]
dataset_class
.
meta
[
'md5'
]
=
None
def
test_load_annotations
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
# Test default behavior
dataset
=
dataset_class
(
**
self
.
DEFAULT_ARGS
)
self
.
assertEqual
(
len
(
dataset
),
4
)
self
.
assertEqual
(
dataset
.
CLASSES
,
self
.
fake_classes
)
data_info
=
dataset
[
0
]
fake_img
=
self
.
fake_imgs
[
0
].
reshape
(
3
,
32
,
32
).
transpose
(
1
,
2
,
0
)
np
.
testing
.
assert_equal
(
data_info
[
'img'
],
fake_img
)
np
.
testing
.
assert_equal
(
data_info
[
'gt_label'
],
self
.
fake_labels
[
0
])
# Test with test_mode=True
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'test_mode'
:
True
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
len
(
dataset
),
2
)
data_info
=
dataset
[
0
]
fake_img
=
self
.
fake_imgs
[
4
].
reshape
(
3
,
32
,
32
).
transpose
(
1
,
2
,
0
)
np
.
testing
.
assert_equal
(
data_info
[
'img'
],
fake_img
)
np
.
testing
.
assert_equal
(
data_info
[
'gt_label'
],
self
.
fake_labels
[
4
])
@
classmethod
def
tearDownClass
(
cls
):
cls
.
tmpdir
.
cleanup
()
class
TestCIFAR100
(
TestCIFAR10
):
DATASET_TYPE
=
'CIFAR100'
class
TestVOC
(
TestMultiLabelDataset
):
DATASET_TYPE
=
'VOC'
DEFAULT_ARGS
=
dict
(
data_prefix
=
'VOC2007'
,
pipeline
=
[])
class
TestCUB
(
TestBaseDataset
):
DATASET_TYPE
=
'CUB'
@
classmethod
def
setUpClass
(
cls
)
->
None
:
super
().
setUpClass
()
tmpdir
=
tempfile
.
TemporaryDirectory
()
cls
.
tmpdir
=
tmpdir
cls
.
data_prefix
=
tmpdir
.
name
cls
.
ann_file
=
osp
.
join
(
cls
.
data_prefix
,
'ann_file.txt'
)
cls
.
image_class_labels_file
=
osp
.
join
(
cls
.
data_prefix
,
'classes.txt'
)
cls
.
train_test_split_file
=
osp
.
join
(
cls
.
data_prefix
,
'split.txt'
)
cls
.
train_test_split_file2
=
osp
.
join
(
cls
.
data_prefix
,
'split2.txt'
)
cls
.
DEFAULT_ARGS
=
dict
(
data_prefix
=
cls
.
data_prefix
,
pipeline
=
[],
ann_file
=
cls
.
ann_file
,
image_class_labels_file
=
cls
.
image_class_labels_file
,
train_test_split_file
=
cls
.
train_test_split_file
)
with
open
(
cls
.
ann_file
,
'w'
)
as
f
:
f
.
write
(
'
\n
'
.
join
([
'1 1.txt'
,
'2 2.txt'
,
'3 3.txt'
,
]))
with
open
(
cls
.
image_class_labels_file
,
'w'
)
as
f
:
f
.
write
(
'
\n
'
.
join
([
'1 2'
,
'2 3'
,
'3 1'
,
]))
with
open
(
cls
.
train_test_split_file
,
'w'
)
as
f
:
f
.
write
(
'
\n
'
.
join
([
'1 0'
,
'2 1'
,
'3 1'
,
]))
with
open
(
cls
.
train_test_split_file2
,
'w'
)
as
f
:
f
.
write
(
'
\n
'
.
join
([
'1 0'
,
'2 1'
,
]))
def
test_load_annotations
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
# Test default behavior
dataset
=
dataset_class
(
**
self
.
DEFAULT_ARGS
)
self
.
assertEqual
(
len
(
dataset
),
2
)
self
.
assertEqual
(
dataset
.
CLASSES
,
dataset_class
.
CLASSES
)
data_info
=
dataset
[
0
]
np
.
testing
.
assert_equal
(
data_info
[
'img_prefix'
],
self
.
data_prefix
)
np
.
testing
.
assert_equal
(
data_info
[
'img_info'
],
{
'filename'
:
'2.txt'
})
np
.
testing
.
assert_equal
(
data_info
[
'gt_label'
],
3
-
1
)
# Test with test_mode=True
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'test_mode'
:
True
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
len
(
dataset
),
1
)
data_info
=
dataset
[
0
]
np
.
testing
.
assert_equal
(
data_info
[
'img_prefix'
],
self
.
data_prefix
)
np
.
testing
.
assert_equal
(
data_info
[
'img_info'
],
{
'filename'
:
'1.txt'
})
np
.
testing
.
assert_equal
(
data_info
[
'gt_label'
],
2
-
1
)
# Test if the numbers of line are not match
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'train_test_split_file'
:
self
.
train_test_split_file2
}
with
self
.
assertRaisesRegex
(
AssertionError
,
'should have same length'
):
dataset_class
(
**
cfg
)
@
classmethod
def
tearDownClass
(
cls
):
cls
.
tmpdir
.
cleanup
()
class
TestStanfordCars
(
TestBaseDataset
):
DATASET_TYPE
=
'StanfordCars'
def
test_initialize
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
with
patch
.
object
(
dataset_class
,
'load_annotations'
):
# Test with test_mode=False, ann_file is None
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'test_mode'
:
False
,
'ann_file'
:
None
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
dataset
.
CLASSES
,
dataset_class
.
CLASSES
)
self
.
assertFalse
(
dataset
.
test_mode
)
self
.
assertIsNone
(
dataset
.
ann_file
)
self
.
assertIsNotNone
(
dataset
.
train_ann_file
)
# Test with test_mode=False, ann_file is not None
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'test_mode'
:
False
,
'ann_file'
:
'train_ann_file.mat'
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
dataset
.
CLASSES
,
dataset_class
.
CLASSES
)
self
.
assertFalse
(
dataset
.
test_mode
)
self
.
assertIsNotNone
(
dataset
.
ann_file
)
self
.
assertEqual
(
dataset
.
ann_file
,
'train_ann_file.mat'
)
self
.
assertIsNotNone
(
dataset
.
train_ann_file
)
# Test with test_mode=True, ann_file is None
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'test_mode'
:
True
,
'ann_file'
:
None
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
dataset
.
CLASSES
,
dataset_class
.
CLASSES
)
self
.
assertTrue
(
dataset
.
test_mode
)
self
.
assertIsNone
(
dataset
.
ann_file
)
self
.
assertIsNotNone
(
dataset
.
test_ann_file
)
# Test with test_mode=True, ann_file is not None
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'test_mode'
:
True
,
'ann_file'
:
'test_ann_file.mat'
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
dataset
.
CLASSES
,
dataset_class
.
CLASSES
)
self
.
assertTrue
(
dataset
.
test_mode
)
self
.
assertIsNotNone
(
dataset
.
ann_file
)
self
.
assertEqual
(
dataset
.
ann_file
,
'test_ann_file.mat'
)
self
.
assertIsNotNone
(
dataset
.
test_ann_file
)
@
classmethod
def
setUpClass
(
cls
)
->
None
:
super
().
setUpClass
()
tmpdir
=
tempfile
.
TemporaryDirectory
()
cls
.
tmpdir
=
tmpdir
cls
.
data_prefix
=
tmpdir
.
name
cls
.
ann_file
=
None
devkit
=
osp
.
join
(
cls
.
data_prefix
,
'devkit'
)
if
not
osp
.
exists
(
devkit
):
os
.
mkdir
(
devkit
)
cls
.
train_ann_file
=
osp
.
join
(
devkit
,
'cars_train_annos.mat'
)
cls
.
test_ann_file
=
osp
.
join
(
devkit
,
'cars_test_annos_withlabels.mat'
)
cls
.
DEFAULT_ARGS
=
dict
(
data_prefix
=
cls
.
data_prefix
,
pipeline
=
[],
test_mode
=
False
)
try
:
import
scipy.io
as
sio
except
ImportError
:
raise
ImportError
(
'please run `pip install scipy` to install package `scipy`.'
)
sio
.
savemat
(
cls
.
train_ann_file
,
{
'annotations'
:
[(
(
np
.
array
([
1
]),
np
.
array
([
10
]),
np
.
array
(
[
20
]),
np
.
array
([
50
]),
15
,
np
.
array
([
'001.jpg'
])),
(
np
.
array
([
2
]),
np
.
array
([
15
]),
np
.
array
(
[
240
]),
np
.
array
([
250
]),
15
,
np
.
array
([
'002.jpg'
])),
(
np
.
array
([
89
]),
np
.
array
([
150
]),
np
.
array
(
[
278
]),
np
.
array
([
388
]),
150
,
np
.
array
([
'012.jpg'
])),
)]
})
sio
.
savemat
(
cls
.
test_ann_file
,
{
'annotations'
:
[((
np
.
array
([
89
]),
np
.
array
([
150
]),
np
.
array
(
[
278
]),
np
.
array
([
388
]),
150
,
np
.
array
([
'025.jpg'
])),
(
np
.
array
([
155
]),
np
.
array
([
10
]),
np
.
array
(
[
200
]),
np
.
array
([
233
]),
0
,
np
.
array
([
'111.jpg'
])),
(
np
.
array
([
25
]),
np
.
array
([
115
]),
np
.
array
(
[
240
]),
np
.
array
([
360
]),
15
,
np
.
array
([
'265.jpg'
])))]
})
def
test_load_annotations
(
self
):
dataset_class
=
DATASETS
.
get
(
self
.
DATASET_TYPE
)
# Test with test_mode=False and ann_file=None
dataset
=
dataset_class
(
**
self
.
DEFAULT_ARGS
)
self
.
assertEqual
(
len
(
dataset
),
3
)
self
.
assertEqual
(
dataset
.
CLASSES
,
dataset_class
.
CLASSES
)
data_info
=
dataset
[
0
]
np
.
testing
.
assert_equal
(
data_info
[
'img_prefix'
],
osp
.
join
(
self
.
data_prefix
,
'cars_train'
))
np
.
testing
.
assert_equal
(
data_info
[
'img_info'
],
{
'filename'
:
'001.jpg'
})
np
.
testing
.
assert_equal
(
data_info
[
'gt_label'
],
15
-
1
)
# Test with test_mode=True and ann_file=None
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'test_mode'
:
True
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
len
(
dataset
),
3
)
data_info
=
dataset
[
0
]
np
.
testing
.
assert_equal
(
data_info
[
'img_prefix'
],
osp
.
join
(
self
.
data_prefix
,
'cars_test'
))
np
.
testing
.
assert_equal
(
data_info
[
'img_info'
],
{
'filename'
:
'025.jpg'
})
np
.
testing
.
assert_equal
(
data_info
[
'gt_label'
],
150
-
1
)
# Test with test_mode=False, ann_file is not None
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'test_mode'
:
False
,
'ann_file'
:
self
.
train_ann_file
}
dataset
=
dataset_class
(
**
cfg
)
data_info
=
dataset
[
0
]
np
.
testing
.
assert_equal
(
data_info
[
'img_prefix'
],
osp
.
join
(
self
.
data_prefix
,
'cars_train'
))
np
.
testing
.
assert_equal
(
data_info
[
'img_info'
],
{
'filename'
:
'001.jpg'
})
np
.
testing
.
assert_equal
(
data_info
[
'gt_label'
],
15
-
1
)
# Test with test_mode=True, ann_file is not None
cfg
=
{
**
self
.
DEFAULT_ARGS
,
'test_mode'
:
True
,
'ann_file'
:
self
.
test_ann_file
}
dataset
=
dataset_class
(
**
cfg
)
self
.
assertEqual
(
len
(
dataset
),
3
)
data_info
=
dataset
[
0
]
np
.
testing
.
assert_equal
(
data_info
[
'img_prefix'
],
osp
.
join
(
self
.
data_prefix
,
'cars_test'
))
np
.
testing
.
assert_equal
(
data_info
[
'img_info'
],
{
'filename'
:
'025.jpg'
})
np
.
testing
.
assert_equal
(
data_info
[
'gt_label'
],
150
-
1
)
@
classmethod
def
tearDownClass
(
cls
):
cls
.
tmpdir
.
cleanup
()
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_datasets/test_dataset_utils.py
0 → 100644
View file @
0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
random
import
string
from
mmcls.datasets.utils
import
check_integrity
,
rm_suffix
def
test_dataset_utils
():
# test rm_suffix
assert
rm_suffix
(
'a.jpg'
)
==
'a'
assert
rm_suffix
(
'a.bak.jpg'
)
==
'a.bak'
assert
rm_suffix
(
'a.bak.jpg'
,
suffix
=
'.jpg'
)
==
'a.bak'
assert
rm_suffix
(
'a.bak.jpg'
,
suffix
=
'.bak.jpg'
)
==
'a'
# test check_integrity
rand_file
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
,
10
))
assert
not
check_integrity
(
rand_file
,
md5
=
None
)
assert
not
check_integrity
(
rand_file
,
md5
=
2333
)
test_file
=
osp
.
join
(
osp
.
dirname
(
__file__
),
'../../data/color.jpg'
)
assert
check_integrity
(
test_file
,
md5
=
'08252e5100cb321fe74e0e12a724ce14'
)
assert
not
check_integrity
(
test_file
,
md5
=
2333
)
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_datasets/test_dataset_wrapper.py
0 → 100644
View file @
0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import
bisect
import
math
from
collections
import
defaultdict
from
unittest.mock
import
MagicMock
,
patch
import
numpy
as
np
import
pytest
from
mmcls.datasets
import
(
BaseDataset
,
ClassBalancedDataset
,
ConcatDataset
,
KFoldDataset
,
RepeatDataset
)
def
mock_evaluate
(
results
,
metric
=
'accuracy'
,
metric_options
=
None
,
indices
=
None
,
logger
=
None
):
return
dict
(
results
=
results
,
metric
=
metric
,
metric_options
=
metric_options
,
indices
=
indices
,
logger
=
logger
)
@
patch
.
multiple
(
BaseDataset
,
__abstractmethods__
=
set
())
def
construct_toy_multi_label_dataset
(
length
):
BaseDataset
.
CLASSES
=
(
'foo'
,
'bar'
)
BaseDataset
.
__getitem__
=
MagicMock
(
side_effect
=
lambda
idx
:
idx
)
dataset
=
BaseDataset
(
data_prefix
=
''
,
pipeline
=
[],
test_mode
=
True
)
cat_ids_list
=
[
np
.
random
.
randint
(
0
,
80
,
num
).
tolist
()
for
num
in
np
.
random
.
randint
(
1
,
20
,
length
)
]
dataset
.
data_infos
=
MagicMock
()
dataset
.
data_infos
.
__len__
.
return_value
=
length
dataset
.
get_cat_ids
=
MagicMock
(
side_effect
=
lambda
idx
:
cat_ids_list
[
idx
])
dataset
.
get_gt_labels
=
\
MagicMock
(
side_effect
=
lambda
:
np
.
array
(
cat_ids_list
))
dataset
.
evaluate
=
MagicMock
(
side_effect
=
mock_evaluate
)
return
dataset
,
cat_ids_list
@
patch
.
multiple
(
BaseDataset
,
__abstractmethods__
=
set
())
def
construct_toy_single_label_dataset
(
length
):
BaseDataset
.
CLASSES
=
(
'foo'
,
'bar'
)
BaseDataset
.
__getitem__
=
MagicMock
(
side_effect
=
lambda
idx
:
idx
)
dataset
=
BaseDataset
(
data_prefix
=
''
,
pipeline
=
[],
test_mode
=
True
)
cat_ids_list
=
[[
np
.
random
.
randint
(
0
,
80
)]
for
_
in
range
(
length
)]
dataset
.
data_infos
=
MagicMock
()
dataset
.
data_infos
.
__len__
.
return_value
=
length
dataset
.
get_cat_ids
=
MagicMock
(
side_effect
=
lambda
idx
:
cat_ids_list
[
idx
])
dataset
.
get_gt_labels
=
\
MagicMock
(
side_effect
=
lambda
:
cat_ids_list
)
dataset
.
evaluate
=
MagicMock
(
side_effect
=
mock_evaluate
)
return
dataset
,
cat_ids_list
@
pytest
.
mark
.
parametrize
(
'construct_dataset'
,
[
'construct_toy_multi_label_dataset'
,
'construct_toy_single_label_dataset'
])
def
test_concat_dataset
(
construct_dataset
):
construct_toy_dataset
=
eval
(
construct_dataset
)
dataset_a
,
cat_ids_list_a
=
construct_toy_dataset
(
10
)
dataset_b
,
cat_ids_list_b
=
construct_toy_dataset
(
20
)
concat_dataset
=
ConcatDataset
([
dataset_a
,
dataset_b
])
assert
concat_dataset
[
5
]
==
5
assert
concat_dataset
[
25
]
==
15
assert
concat_dataset
.
get_cat_ids
(
5
)
==
cat_ids_list_a
[
5
]
assert
concat_dataset
.
get_cat_ids
(
25
)
==
cat_ids_list_b
[
15
]
assert
len
(
concat_dataset
)
==
len
(
dataset_a
)
+
len
(
dataset_b
)
assert
concat_dataset
.
CLASSES
==
BaseDataset
.
CLASSES
@
pytest
.
mark
.
parametrize
(
'construct_dataset'
,
[
'construct_toy_multi_label_dataset'
,
'construct_toy_single_label_dataset'
])
def
test_repeat_dataset
(
construct_dataset
):
construct_toy_dataset
=
eval
(
construct_dataset
)
dataset
,
cat_ids_list
=
construct_toy_dataset
(
10
)
repeat_dataset
=
RepeatDataset
(
dataset
,
10
)
assert
repeat_dataset
[
5
]
==
5
assert
repeat_dataset
[
15
]
==
5
assert
repeat_dataset
[
27
]
==
7
assert
repeat_dataset
.
get_cat_ids
(
5
)
==
cat_ids_list
[
5
]
assert
repeat_dataset
.
get_cat_ids
(
15
)
==
cat_ids_list
[
5
]
assert
repeat_dataset
.
get_cat_ids
(
27
)
==
cat_ids_list
[
7
]
assert
len
(
repeat_dataset
)
==
10
*
len
(
dataset
)
assert
repeat_dataset
.
CLASSES
==
BaseDataset
.
CLASSES
@
pytest
.
mark
.
parametrize
(
'construct_dataset'
,
[
'construct_toy_multi_label_dataset'
,
'construct_toy_single_label_dataset'
])
def
test_class_balanced_dataset
(
construct_dataset
):
construct_toy_dataset
=
eval
(
construct_dataset
)
dataset
,
cat_ids_list
=
construct_toy_dataset
(
10
)
category_freq
=
defaultdict
(
int
)
for
cat_ids
in
cat_ids_list
:
cat_ids
=
set
(
cat_ids
)
for
cat_id
in
cat_ids
:
category_freq
[
cat_id
]
+=
1
for
k
,
v
in
category_freq
.
items
():
category_freq
[
k
]
=
v
/
len
(
cat_ids_list
)
mean_freq
=
np
.
mean
(
list
(
category_freq
.
values
()))
repeat_thr
=
mean_freq
category_repeat
=
{
cat_id
:
max
(
1.0
,
math
.
sqrt
(
repeat_thr
/
cat_freq
))
for
cat_id
,
cat_freq
in
category_freq
.
items
()
}
repeat_factors
=
[]
for
cat_ids
in
cat_ids_list
:
cat_ids
=
set
(
cat_ids
)
repeat_factor
=
max
({
category_repeat
[
cat_id
]
for
cat_id
in
cat_ids
})
repeat_factors
.
append
(
math
.
ceil
(
repeat_factor
))
repeat_factors_cumsum
=
np
.
cumsum
(
repeat_factors
)
repeat_factor_dataset
=
ClassBalancedDataset
(
dataset
,
repeat_thr
)
assert
repeat_factor_dataset
.
CLASSES
==
BaseDataset
.
CLASSES
assert
len
(
repeat_factor_dataset
)
==
repeat_factors_cumsum
[
-
1
]
for
idx
in
np
.
random
.
randint
(
0
,
len
(
repeat_factor_dataset
),
3
):
assert
repeat_factor_dataset
[
idx
]
==
bisect
.
bisect_right
(
repeat_factors_cumsum
,
idx
)
@
pytest
.
mark
.
parametrize
(
'construct_dataset'
,
[
'construct_toy_multi_label_dataset'
,
'construct_toy_single_label_dataset'
])
def
test_kfold_dataset
(
construct_dataset
):
construct_toy_dataset
=
eval
(
construct_dataset
)
dataset
,
cat_ids_list
=
construct_toy_dataset
(
10
)
# test without random seed
train_datasets
=
[
KFoldDataset
(
dataset
,
fold
=
i
,
num_splits
=
3
,
test_mode
=
False
)
for
i
in
range
(
5
)
]
test_datasets
=
[
KFoldDataset
(
dataset
,
fold
=
i
,
num_splits
=
3
,
test_mode
=
True
)
for
i
in
range
(
5
)
]
assert
sum
([
i
.
indices
for
i
in
test_datasets
],
[])
==
list
(
range
(
10
))
for
train_set
,
test_set
in
zip
(
train_datasets
,
test_datasets
):
train_samples
=
[
train_set
[
i
]
for
i
in
range
(
len
(
train_set
))]
test_samples
=
[
test_set
[
i
]
for
i
in
range
(
len
(
test_set
))]
assert
set
(
train_samples
+
test_samples
)
==
set
(
range
(
10
))
# test with random seed
train_datasets
=
[
KFoldDataset
(
dataset
,
fold
=
i
,
num_splits
=
3
,
test_mode
=
False
,
seed
=
1
)
for
i
in
range
(
5
)
]
test_datasets
=
[
KFoldDataset
(
dataset
,
fold
=
i
,
num_splits
=
3
,
test_mode
=
True
,
seed
=
1
)
for
i
in
range
(
5
)
]
assert
sum
([
i
.
indices
for
i
in
test_datasets
],
[])
!=
list
(
range
(
10
))
assert
set
(
sum
([
i
.
indices
for
i
in
test_datasets
],
[]))
==
set
(
range
(
10
))
for
train_set
,
test_set
in
zip
(
train_datasets
,
test_datasets
):
train_samples
=
[
train_set
[
i
]
for
i
in
range
(
len
(
train_set
))]
test_samples
=
[
test_set
[
i
]
for
i
in
range
(
len
(
test_set
))]
assert
set
(
train_samples
+
test_samples
)
==
set
(
range
(
10
))
# test behavior of get_cat_ids method
for
train_set
,
test_set
in
zip
(
train_datasets
,
test_datasets
):
for
i
in
range
(
len
(
train_set
)):
cat_ids
=
train_set
.
get_cat_ids
(
i
)
assert
cat_ids
==
cat_ids_list
[
train_set
.
indices
[
i
]]
for
i
in
range
(
len
(
test_set
)):
cat_ids
=
test_set
.
get_cat_ids
(
i
)
assert
cat_ids
==
cat_ids_list
[
test_set
.
indices
[
i
]]
# test behavior of get_gt_labels method
for
train_set
,
test_set
in
zip
(
train_datasets
,
test_datasets
):
for
i
in
range
(
len
(
train_set
)):
gt_label
=
train_set
.
get_gt_labels
()[
i
]
assert
gt_label
==
cat_ids_list
[
train_set
.
indices
[
i
]]
for
i
in
range
(
len
(
test_set
)):
gt_label
=
test_set
.
get_gt_labels
()[
i
]
assert
gt_label
==
cat_ids_list
[
test_set
.
indices
[
i
]]
# test evaluate
for
test_set
in
test_datasets
:
eval_inputs
=
test_set
.
evaluate
(
None
)
assert
eval_inputs
[
'indices'
]
==
test_set
.
indices
openmmlab_test/mmclassification-0.24.1/tests/test_data/test_datasets/test_sampler.py
0 → 100644
View file @
0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
from
unittest.mock
import
MagicMock
,
patch
import
numpy
as
np
from
mmcls.datasets
import
BaseDataset
,
RepeatAugSampler
,
build_sampler
@
patch
.
multiple
(
BaseDataset
,
__abstractmethods__
=
set
())
def
construct_toy_single_label_dataset
(
length
):
BaseDataset
.
CLASSES
=
(
'foo'
,
'bar'
)
BaseDataset
.
__getitem__
=
MagicMock
(
side_effect
=
lambda
idx
:
idx
)
dataset
=
BaseDataset
(
data_prefix
=
''
,
pipeline
=
[],
test_mode
=
True
)
cat_ids_list
=
[[
np
.
random
.
randint
(
0
,
80
)]
for
_
in
range
(
length
)]
dataset
.
data_infos
=
MagicMock
()
dataset
.
data_infos
.
__len__
.
return_value
=
length
dataset
.
get_cat_ids
=
MagicMock
(
side_effect
=
lambda
idx
:
cat_ids_list
[
idx
])
return
dataset
,
cat_ids_list
@
patch
(
'mmcls.datasets.samplers.repeat_aug.get_dist_info'
,
return_value
=
(
0
,
1
))
def
test_sampler_builder
(
_
):
assert
build_sampler
(
None
)
is
None
dataset
=
construct_toy_single_label_dataset
(
1000
)[
0
]
build_sampler
(
dict
(
type
=
'RepeatAugSampler'
,
dataset
=
dataset
))
@
patch
(
'mmcls.datasets.samplers.repeat_aug.get_dist_info'
,
return_value
=
(
0
,
1
))
def
test_rep_aug
(
_
):
dataset
=
construct_toy_single_label_dataset
(
1000
)[
0
]
ra
=
RepeatAugSampler
(
dataset
,
selected_round
=
0
,
shuffle
=
False
)
ra
.
set_epoch
(
0
)
assert
len
(
ra
)
==
1000
ra
=
RepeatAugSampler
(
dataset
)
assert
len
(
ra
)
==
768
val
=
None
for
idx
,
content
in
enumerate
(
ra
):
if
idx
%
3
==
0
:
val
=
content
else
:
assert
val
is
not
None
assert
content
==
val
@
patch
(
'mmcls.datasets.samplers.repeat_aug.get_dist_info'
,
return_value
=
(
0
,
2
))
def
test_rep_aug_dist
(
_
):
dataset
=
construct_toy_single_label_dataset
(
1000
)[
0
]
ra
=
RepeatAugSampler
(
dataset
,
selected_round
=
0
,
shuffle
=
False
)
ra
.
set_epoch
(
0
)
assert
len
(
ra
)
==
1000
//
2
ra
=
RepeatAugSampler
(
dataset
)
assert
len
(
ra
)
==
768
//
2
openmmlab_test/mmclassification-
speed-benchmark/tests
/test_pipelines/test_auto_augment.py
→
openmmlab_test/mmclassification-
0.24.1/tests/test_data
/test_pipelines/test_auto_augment.py
View file @
0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
random
...
...
@@ -39,6 +40,47 @@ def construct_toy_data_photometric():
return
results
def
test_auto_augment
():
policies
=
[[
dict
(
type
=
'Posterize'
,
bits
=
4
,
prob
=
0.4
),
dict
(
type
=
'Rotate'
,
angle
=
30.
,
prob
=
0.6
)
]]
# test assertion for policies
with
pytest
.
raises
(
AssertionError
):
# policies shouldn't be empty
transform
=
dict
(
type
=
'AutoAugment'
,
policies
=
[])
build_from_cfg
(
transform
,
PIPELINES
)
with
pytest
.
raises
(
AssertionError
):
# policy should have type
invalid_policies
=
copy
.
deepcopy
(
policies
)
invalid_policies
[
0
][
0
].
pop
(
'type'
)
transform
=
dict
(
type
=
'AutoAugment'
,
policies
=
invalid_policies
)
build_from_cfg
(
transform
,
PIPELINES
)
with
pytest
.
raises
(
AssertionError
):
# sub policy should be a non-empty list
invalid_policies
=
copy
.
deepcopy
(
policies
)
invalid_policies
[
0
]
=
[]
transform
=
dict
(
type
=
'AutoAugment'
,
policies
=
invalid_policies
)
build_from_cfg
(
transform
,
PIPELINES
)
with
pytest
.
raises
(
AssertionError
):
# policy should be valid in PIPELINES registry.
invalid_policies
=
copy
.
deepcopy
(
policies
)
invalid_policies
.
append
([
dict
(
type
=
'Wrong_policy'
)])
transform
=
dict
(
type
=
'AutoAugment'
,
policies
=
invalid_policies
)
build_from_cfg
(
transform
,
PIPELINES
)
# test hparams
transform
=
dict
(
type
=
'AutoAugment'
,
policies
=
policies
,
hparams
=
dict
(
pad_val
=
15
,
interpolation
=
'nearest'
))
pipeline
=
build_from_cfg
(
transform
,
PIPELINES
)
# use hparams if not set in policies config
assert
pipeline
.
policies
[
0
][
1
][
'pad_val'
]
==
15
assert
pipeline
.
policies
[
0
][
1
][
'interpolation'
]
==
'nearest'
def
test_rand_augment
():
policies
=
[
dict
(
...
...
@@ -47,12 +89,13 @@ def test_rand_augment():
magnitude_range
=
(
0
,
1
),
pad_val
=
128
,
prob
=
1.
,
direction
=
'horizontal'
),
direction
=
'horizontal'
,
interpolation
=
'nearest'
),
dict
(
type
=
'Invert'
,
prob
=
1.
),
dict
(
type
=
'Rotate'
,
magnitude_key
=
'angle'
,
magnitude_range
=
(
0
,
3
0
),
magnitude_range
=
(
0
,
9
0
),
prob
=
0.
)
]
# test assertion for num_policies
...
...
@@ -136,6 +179,15 @@ def test_rand_augment():
num_policies
=
2
,
magnitude_level
=
12
)
build_from_cfg
(
transform
,
PIPELINES
)
with
pytest
.
raises
(
AssertionError
):
invalid_policies
=
copy
.
deepcopy
(
policies
)
invalid_policies
.
append
(
dict
(
type
=
'Wrong_policy'
))
transform
=
dict
(
type
=
'RandAugment'
,
policies
=
invalid_policies
,
num_policies
=
2
,
magnitude_level
=
12
)
build_from_cfg
(
transform
,
PIPELINES
)
with
pytest
.
raises
(
AssertionError
):
invalid_policies
=
copy
.
deepcopy
(
policies
)
invalid_policies
[
2
].
pop
(
'type'
)
...
...
@@ -306,7 +358,7 @@ def test_rand_augment():
axis
=-
1
)
np
.
testing
.
assert_array_equal
(
results
[
'img'
],
img_augmented
)
# test case where magnitude_std is negtive
# test case where magnitude_std is neg
a
tive
random
.
seed
(
3
)
np
.
random
.
seed
(
0
)
results
=
construct_toy_data
()
...
...
@@ -326,6 +378,32 @@ def test_rand_augment():
axis
=-
1
)
np
.
testing
.
assert_array_equal
(
results
[
'img'
],
img_augmented
)
# test hparams
random
.
seed
(
8
)
np
.
random
.
seed
(
0
)
results
=
construct_toy_data
()
policies
[
2
][
'prob'
]
=
1.0
transform
=
dict
(
type
=
'RandAugment'
,
policies
=
policies
,
num_policies
=
2
,
magnitude_level
=
12
,
magnitude_std
=-
1
,
hparams
=
dict
(
pad_val
=
15
,
interpolation
=
'nearest'
))
pipeline
=
build_from_cfg
(
transform
,
PIPELINES
)
# apply translate (magnitude=0.4) and rotate (angle=36)
results
=
pipeline
(
results
)
img_augmented
=
np
.
array
(
[[
128
,
128
,
128
,
15
],
[
128
,
128
,
5
,
2
],
[
15
,
9
,
9
,
6
]],
dtype
=
np
.
uint8
)
img_augmented
=
np
.
stack
([
img_augmented
,
img_augmented
,
img_augmented
],
axis
=-
1
)
np
.
testing
.
assert_array_equal
(
results
[
'img'
],
img_augmented
)
# hparams won't override setting in policies config
assert
pipeline
.
policies
[
0
][
'pad_val'
]
==
128
# use hparams if not set in policies config
assert
pipeline
.
policies
[
2
][
'pad_val'
]
==
15
assert
pipeline
.
policies
[
2
][
'interpolation'
]
==
'nearest'
def
test_shear
():
# test assertion for invalid type of magnitude
...
...
@@ -524,7 +602,7 @@ def test_rotate():
transform
=
dict
(
type
=
'Rotate'
,
angle
=
90.
,
center
=
0
)
build_from_cfg
(
transform
,
PIPELINES
)
# test assertion for invalid lenth of center
# test assertion for invalid len
g
th of center
with
pytest
.
raises
(
AssertionError
):
transform
=
dict
(
type
=
'Rotate'
,
angle
=
90.
,
center
=
(
0
,
))
build_from_cfg
(
transform
,
PIPELINES
)
...
...
@@ -682,7 +760,7 @@ def test_equalize(nb_rand_test=100):
def
_imequalize
(
img
):
# equalize the image using PIL.ImageOps.equalize
from
PIL
import
Image
Ops
,
Image
from
PIL
import
Image
,
Image
Ops
img
=
Image
.
fromarray
(
img
)
equalized_img
=
np
.
asarray
(
ImageOps
.
equalize
(
img
))
return
equalized_img
...
...
@@ -704,7 +782,7 @@ def test_equalize(nb_rand_test=100):
transform
=
dict
(
type
=
'Equalize'
,
prob
=
1.
)
pipeline
=
build_from_cfg
(
transform
,
PIPELINES
)
for
_
in
range
(
nb_rand_test
):
img
=
np
.
clip
(
np
.
random
.
normal
(
0
,
1
,
(
1000
,
1200
,
3
))
*
260
,
0
,
img
=
np
.
clip
(
np
.
random
.
normal
(
0
,
1
,
(
256
,
256
,
3
))
*
260
,
0
,
255
).
astype
(
np
.
uint8
)
results
[
'img'
]
=
img
results
=
pipeline
(
copy
.
deepcopy
(
results
))
...
...
@@ -854,8 +932,9 @@ def test_posterize():
def
test_contrast
(
nb_rand_test
=
100
):
def
_adjust_contrast
(
img
,
factor
):
from
PIL.ImageEnhance
import
Contrast
from
PIL
import
Image
from
PIL.ImageEnhance
import
Contrast
# Image.fromarray defaultly supports RGB, not BGR.
# convert from BGR to RGB
img
=
Image
.
fromarray
(
img
[...,
::
-
1
],
mode
=
'RGB'
)
...
...
@@ -903,7 +982,7 @@ def test_contrast(nb_rand_test=100):
prob
=
1.
,
random_negative_prob
=
0.
)
pipeline
=
build_from_cfg
(
transform
,
PIPELINES
)
img
=
np
.
clip
(
np
.
random
.
uniform
(
0
,
1
,
(
1200
,
1000
,
3
))
*
260
,
0
,
img
=
np
.
clip
(
np
.
random
.
uniform
(
0
,
1
,
(
256
,
256
,
3
))
*
260
,
0
,
255
).
astype
(
np
.
uint8
)
results
[
'img'
]
=
img
results
=
pipeline
(
copy
.
deepcopy
(
results
))
...
...
@@ -988,8 +1067,8 @@ def test_brightness(nb_rand_test=100):
def
_adjust_brightness
(
img
,
factor
):
# adjust the brightness of image using
# PIL.ImageEnhance.Brightness
from
PIL.ImageEnhance
import
Brightness
from
PIL
import
Image
from
PIL.ImageEnhance
import
Brightness
img
=
Image
.
fromarray
(
img
)
brightened_img
=
Brightness
(
img
).
enhance
(
factor
)
return
np
.
asarray
(
brightened_img
)
...
...
@@ -1034,7 +1113,7 @@ def test_brightness(nb_rand_test=100):
prob
=
1.
,
random_negative_prob
=
0.
)
pipeline
=
build_from_cfg
(
transform
,
PIPELINES
)
img
=
np
.
clip
(
np
.
random
.
uniform
(
0
,
1
,
(
1200
,
1000
,
3
))
*
260
,
0
,
img
=
np
.
clip
(
np
.
random
.
uniform
(
0
,
1
,
(
256
,
256
,
3
))
*
260
,
0
,
255
).
astype
(
np
.
uint8
)
results
[
'img'
]
=
img
results
=
pipeline
(
copy
.
deepcopy
(
results
))
...
...
@@ -1050,8 +1129,8 @@ def test_sharpness(nb_rand_test=100):
def
_adjust_sharpness
(
img
,
factor
):
# adjust the sharpness of image using
# PIL.ImageEnhance.Sharpness
from
PIL.ImageEnhance
import
Sharpness
from
PIL
import
Image
from
PIL.ImageEnhance
import
Sharpness
img
=
Image
.
fromarray
(
img
)
sharpened_img
=
Sharpness
(
img
).
enhance
(
factor
)
return
np
.
asarray
(
sharpened_img
)
...
...
@@ -1096,7 +1175,7 @@ def test_sharpness(nb_rand_test=100):
prob
=
1.
,
random_negative_prob
=
0.
)
pipeline
=
build_from_cfg
(
transform
,
PIPELINES
)
img
=
np
.
clip
(
np
.
random
.
uniform
(
0
,
1
,
(
1200
,
1000
,
3
))
*
260
,
0
,
img
=
np
.
clip
(
np
.
random
.
uniform
(
0
,
1
,
(
256
,
256
,
3
))
*
260
,
0
,
255
).
astype
(
np
.
uint8
)
results
[
'img'
]
=
img
results
=
pipeline
(
copy
.
deepcopy
(
results
))
...
...
openmmlab_test/mmclassification-
speed-benchmark/tests
/test_pipelines/test_loading.py
→
openmmlab_test/mmclassification-
0.24.1/tests/test_data
/test_pipelines/test_loading.py
View file @
0fd8347d
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
os.path
as
osp
...
...
@@ -10,7 +11,7 @@ class TestLoading(object):
@
classmethod
def
setup_class
(
cls
):
cls
.
data_prefix
=
osp
.
join
(
osp
.
dirname
(
__file__
),
'../data'
)
cls
.
data_prefix
=
osp
.
join
(
osp
.
dirname
(
__file__
),
'../
../
data'
)
def
test_load_img
(
self
):
results
=
dict
(
...
...
Prev
1
…
37
38
39
40
41
42
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment