Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmpretrain
Commits
1baf0566
Commit
1baf0566
authored
Jun 24, 2025
by
limm
Browse files
add tests part
parent
495d9ed9
Pipeline
#2800
canceled with stages
Changes
146
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
785 additions
and
0 deletions
+785
-0
tests/test_structures/test_utils.py
tests/test_structures/test_utils.py
+63
-0
tests/test_tools.py
tests/test_tools.py
+418
-0
tests/test_utils/test_analyze.py
tests/test_utils/test_analyze.py
+43
-0
tests/test_utils/test_setup_env.py
tests/test_utils/test_setup_env.py
+40
-0
tests/test_utils/test_version_utils.py
tests/test_utils/test_version_utils.py
+21
-0
tests/test_visualization/test_visualizer.py
tests/test_visualization/test_visualizer.py
+200
-0
No files found.
tests/test_structures/test_utils.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
from
unittest
import
TestCase
import
torch
from
mmpretrain.structures
import
(
batch_label_to_onehot
,
cat_batch_labels
,
tensor_split
)
class
TestStructureUtils
(
TestCase
):
def
test_tensor_split
(
self
):
tensor
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
4
,
5
,
6
])
split_indices
=
[
0
,
2
,
6
,
6
]
outs
=
tensor_split
(
tensor
,
split_indices
)
self
.
assertEqual
(
len
(
outs
),
len
(
split_indices
)
+
1
)
self
.
assertEqual
(
outs
[
0
].
size
(
0
),
0
)
self
.
assertEqual
(
outs
[
1
].
size
(
0
),
2
)
self
.
assertEqual
(
outs
[
2
].
size
(
0
),
4
)
self
.
assertEqual
(
outs
[
3
].
size
(
0
),
0
)
self
.
assertEqual
(
outs
[
4
].
size
(
0
),
1
)
tensor
=
torch
.
tensor
([])
split_indices
=
[
0
,
0
,
0
,
0
]
outs
=
tensor_split
(
tensor
,
split_indices
)
self
.
assertEqual
(
len
(
outs
),
len
(
split_indices
)
+
1
)
def
test_cat_batch_labels
(
self
):
labels
=
[
torch
.
tensor
([
1
]),
torch
.
tensor
([
3
,
2
]),
torch
.
tensor
([
0
,
1
,
4
]),
torch
.
tensor
([],
dtype
=
torch
.
int64
),
torch
.
tensor
([],
dtype
=
torch
.
int64
),
]
batch_label
,
split_indices
=
cat_batch_labels
(
labels
)
self
.
assertEqual
(
split_indices
,
[
1
,
3
,
6
,
6
])
self
.
assertEqual
(
len
(
batch_label
),
6
)
labels
=
tensor_split
(
batch_label
,
split_indices
)
self
.
assertEqual
(
labels
[
0
].
tolist
(),
[
1
])
self
.
assertEqual
(
labels
[
1
].
tolist
(),
[
3
,
2
])
self
.
assertEqual
(
labels
[
2
].
tolist
(),
[
0
,
1
,
4
])
self
.
assertEqual
(
labels
[
3
].
tolist
(),
[])
self
.
assertEqual
(
labels
[
4
].
tolist
(),
[])
def
test_batch_label_to_onehot
(
self
):
labels
=
[
torch
.
tensor
([
1
]),
torch
.
tensor
([
3
,
2
]),
torch
.
tensor
([
0
,
1
,
4
]),
torch
.
tensor
([],
dtype
=
torch
.
int64
),
torch
.
tensor
([],
dtype
=
torch
.
int64
),
]
batch_label
,
split_indices
=
cat_batch_labels
(
labels
)
batch_score
=
batch_label_to_onehot
(
batch_label
,
split_indices
,
num_classes
=
5
)
self
.
assertEqual
(
batch_score
[
0
].
tolist
(),
[
0
,
1
,
0
,
0
,
0
])
self
.
assertEqual
(
batch_score
[
1
].
tolist
(),
[
0
,
0
,
1
,
1
,
0
])
self
.
assertEqual
(
batch_score
[
2
].
tolist
(),
[
1
,
1
,
0
,
0
,
1
])
self
.
assertEqual
(
batch_score
[
3
].
tolist
(),
[
0
,
0
,
0
,
0
,
0
])
self
.
assertEqual
(
batch_score
[
4
].
tolist
(),
[
0
,
0
,
0
,
0
,
0
])
tests/test_tools.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
import
re
import
tempfile
from
collections
import
OrderedDict
from
pathlib
import
Path
from
subprocess
import
PIPE
,
Popen
from
unittest
import
TestCase
import
mmengine
import
torch
from
mmengine.config
import
Config
from
mmpretrain
import
ModelHub
,
get_model
from
mmpretrain.structures
import
DataSample
MMPRE_ROOT
=
Path
(
__file__
).
parent
.
parent
ASSETS_ROOT
=
Path
(
__file__
).
parent
/
'data'
class
TestImageDemo
(
TestCase
):
def
setUp
(
self
):
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
dir
=
Path
(
self
.
tmpdir
.
name
)
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'demo/image_demo.py'
,
'demo/demo.JPEG'
,
'mobilevit-xxsmall_3rdparty_in1k'
,
'--device'
,
'cpu'
,
]
p
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
)
out
,
_
=
p
.
communicate
()
self
.
assertIn
(
'sea snake'
,
out
.
decode
())
class
TestAnalyzeLogs
(
TestCase
):
def
setUp
(
self
):
self
.
log_file
=
ASSETS_ROOT
/
'vis_data.json'
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
out_file
=
Path
(
self
.
tmpdir
.
name
)
/
'out.png'
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'tools/analysis_tools/analyze_logs.py'
,
'cal_train_time'
,
str
(
self
.
log_file
),
]
p
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
)
out
,
_
=
p
.
communicate
()
self
.
assertIn
(
'slowest epoch 2, average time is 0.0219'
,
out
.
decode
())
command
=
[
'python'
,
'tools/analysis_tools/analyze_logs.py'
,
'plot_curve'
,
str
(
self
.
log_file
),
'--keys'
,
'accuracy/top1'
,
'--out'
,
str
(
self
.
out_file
),
]
p
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
)
out
,
_
=
p
.
communicate
()
self
.
assertIn
(
str
(
self
.
log_file
),
out
.
decode
())
self
.
assertIn
(
str
(
self
.
out_file
),
out
.
decode
())
self
.
assertTrue
(
self
.
out_file
.
exists
())
class
TestAnalyzeResults
(
TestCase
):
def
setUp
(
self
):
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
dir
=
Path
(
self
.
tmpdir
.
name
)
dataset_cfg
=
dict
(
type
=
'CustomDataset'
,
data_root
=
str
(
ASSETS_ROOT
/
'dataset'
),
)
config
=
Config
(
dict
(
test_dataloader
=
dict
(
dataset
=
dataset_cfg
)))
self
.
config_file
=
self
.
dir
/
'config.py'
config
.
dump
(
self
.
config_file
)
results
=
[{
'gt_label'
:
1
,
'pred_label'
:
0
,
'pred_score'
:
[
0.9
,
0.1
],
'sample_idx'
:
0
,
},
{
'gt_label'
:
0
,
'pred_label'
:
0
,
'pred_score'
:
[
0.9
,
0.1
],
'sample_idx'
:
1
,
}]
self
.
result_file
=
self
.
dir
/
'results.pkl'
mmengine
.
dump
(
results
,
self
.
result_file
)
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'tools/analysis_tools/analyze_results.py'
,
str
(
self
.
config_file
),
str
(
self
.
result_file
),
'--out-dir'
,
str
(
self
.
tmpdir
.
name
),
]
p
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
)
p
.
communicate
()
self
.
assertTrue
((
self
.
dir
/
'success/2.jpeg.png'
).
exists
())
self
.
assertTrue
((
self
.
dir
/
'fail/1.JPG.png'
).
exists
())
class
TestPrintConfig
(
TestCase
):
def
setUp
(
self
):
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
config_file
=
MMPRE_ROOT
/
'configs/resnet/resnet18_8xb32_in1k.py'
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'tools/misc/print_config.py'
,
str
(
self
.
config_file
),
]
p
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
)
out
,
_
=
p
.
communicate
()
out
=
out
.
decode
().
strip
().
replace
(
'
\r\n
'
,
'
\n
'
)
self
.
assertEqual
(
out
,
Config
.
fromfile
(
self
.
config_file
).
pretty_text
.
strip
())
class
TestVerifyDataset
(
TestCase
):
def
setUp
(
self
):
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
dir
=
Path
(
self
.
tmpdir
.
name
)
dataset_cfg
=
dict
(
type
=
'CustomDataset'
,
ann_file
=
str
(
self
.
dir
/
'ann.txt'
),
pipeline
=
[
dict
(
type
=
'LoadImageFromFile'
)],
data_root
=
str
(
ASSETS_ROOT
/
'dataset'
),
)
ann_file
=
'
\n
'
.
join
([
'a/2.JPG 0'
,
'b/2.jpeg 1'
,
'b/subb/3.jpg 1'
])
(
self
.
dir
/
'ann.txt'
).
write_text
(
ann_file
)
config
=
Config
(
dict
(
train_dataloader
=
dict
(
dataset
=
dataset_cfg
)))
self
.
config_file
=
Path
(
self
.
tmpdir
.
name
)
/
'config.py'
config
.
dump
(
self
.
config_file
)
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'tools/misc/verify_dataset.py'
,
str
(
self
.
config_file
),
'--out-path'
,
str
(
self
.
dir
/
'log.log'
),
]
p
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
)
out
,
_
=
p
.
communicate
()
self
.
assertIn
(
f
"
{
ASSETS_ROOT
/
'dataset/a/2.JPG'
}
cannot be read correctly"
,
out
.
decode
().
strip
())
self
.
assertTrue
((
self
.
dir
/
'log.log'
).
exists
())
class
TestEvalMetric
(
TestCase
):
def
setUp
(
self
):
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
dir
=
Path
(
self
.
tmpdir
.
name
)
results
=
[
DataSample
().
set_gt_label
(
1
).
set_pred_label
(
0
).
set_pred_score
(
[
0.6
,
0.3
,
0.1
]).
to_dict
(),
DataSample
().
set_gt_label
(
0
).
set_pred_label
(
0
).
set_pred_score
(
[
0.6
,
0.3
,
0.1
]).
to_dict
(),
]
self
.
result_file
=
self
.
dir
/
'results.pkl'
mmengine
.
dump
(
results
,
self
.
result_file
)
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'tools/analysis_tools/eval_metric.py'
,
str
(
self
.
result_file
),
'--metric'
,
'type=Accuracy'
,
'topk=1,2'
,
]
p
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
)
out
,
_
=
p
.
communicate
()
self
.
assertIn
(
'accuracy/top1'
,
out
.
decode
())
self
.
assertIn
(
'accuracy/top2'
,
out
.
decode
())
class
TestVisScheduler
(
TestCase
):
def
setUp
(
self
):
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
dir
=
Path
(
self
.
tmpdir
.
name
)
config
=
Config
.
fromfile
(
MMPRE_ROOT
/
'configs/resnet/resnet18_8xb32_in1k.py'
)
config
.
param_scheduler
=
[
dict
(
type
=
'LinearLR'
,
start_factor
=
0.01
,
by_epoch
=
True
,
end
=
1
,
convert_to_iter_based
=
True
),
dict
(
type
=
'CosineAnnealingLR'
,
by_epoch
=
True
,
begin
=
1
),
]
config
.
work_dir
=
str
(
self
.
dir
)
config
.
train_cfg
.
max_epochs
=
2
self
.
config_file
=
Path
(
self
.
tmpdir
.
name
)
/
'config.py'
config
.
dump
(
self
.
config_file
)
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'tools/visualization/vis_scheduler.py'
,
str
(
self
.
config_file
),
'--dataset-size'
,
'100'
,
'--not-show'
,
'--save-path'
,
str
(
self
.
dir
/
'out.png'
),
]
p
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
)
p
.
communicate
()
self
.
assertTrue
((
self
.
dir
/
'out.png'
).
exists
())
class
TestPublishModel
(
TestCase
):
def
setUp
(
self
):
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
dir
=
Path
(
self
.
tmpdir
.
name
)
ckpt
=
dict
(
state_dict
=
OrderedDict
({
'a'
:
torch
.
tensor
(
1.
),
}),
ema_state_dict
=
OrderedDict
({
'step'
:
1
,
'module.a'
:
torch
.
tensor
(
2.
),
}))
self
.
ckpt_file
=
self
.
dir
/
'ckpt.pth'
torch
.
save
(
ckpt
,
self
.
ckpt_file
)
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'tools/model_converters/publish_model.py'
,
str
(
self
.
ckpt_file
),
str
(
self
.
ckpt_file
),
'--dataset-type'
,
'ImageNet'
,
'--no-ema'
,
]
p
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
)
out
,
_
=
p
.
communicate
()
self
.
assertIn
(
'and drop the EMA weights.'
,
out
.
decode
())
self
.
assertIn
(
'Successfully generated'
,
out
.
decode
())
output_ckpt
=
re
.
findall
(
r
'ckpt_\d{8}-\w{8}.pth'
,
out
.
decode
())
self
.
assertGreater
(
len
(
output_ckpt
),
0
)
output_ckpt
=
output_ckpt
[
0
]
self
.
assertTrue
((
self
.
dir
/
output_ckpt
).
exists
())
# The input file won't be overridden.
self
.
assertTrue
(
self
.
ckpt_file
.
exists
())
class
TestVisCam
(
TestCase
):
def
setUp
(
self
):
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
dir
=
Path
(
self
.
tmpdir
.
name
)
model
=
get_model
(
'mobilevit-xxsmall_3rdparty_in1k'
)
self
.
config_file
=
self
.
dir
/
'config.py'
model
.
_config
.
dump
(
self
.
config_file
)
self
.
ckpt_file
=
self
.
dir
/
'ckpt.pth'
torch
.
save
(
model
.
state_dict
(),
self
.
ckpt_file
)
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'tools/visualization/vis_cam.py'
,
str
(
ASSETS_ROOT
/
'color.jpg'
),
str
(
self
.
config_file
),
str
(
self
.
ckpt_file
),
'--save-path'
,
str
(
self
.
dir
/
'cam.jpg'
),
]
p
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
)
out
,
_
=
p
.
communicate
()
self
.
assertIn
(
'backbone.conv_1x1_exp.bn'
,
out
.
decode
())
self
.
assertTrue
((
self
.
dir
/
'cam.jpg'
).
exists
())
class
TestConfusionMatrix
(
TestCase
):
def
setUp
(
self
):
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
dir
=
Path
(
self
.
tmpdir
.
name
)
self
.
config_file
=
MMPRE_ROOT
/
'configs/resnet/resnet18_8xb32_in1k.py'
results
=
[
DataSample
().
set_gt_label
(
1
).
set_pred_label
(
0
).
set_pred_score
(
[
0.6
,
0.3
,
0.1
]).
to_dict
(),
DataSample
().
set_gt_label
(
0
).
set_pred_label
(
0
).
set_pred_score
(
[
0.6
,
0.3
,
0.1
]).
to_dict
(),
]
self
.
result_file
=
self
.
dir
/
'results.pkl'
mmengine
.
dump
(
results
,
self
.
result_file
)
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'tools/analysis_tools/confusion_matrix.py'
,
str
(
self
.
config_file
),
str
(
self
.
result_file
),
'--out'
,
str
(
self
.
dir
/
'result.pkl'
),
]
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
).
wait
()
result
=
mmengine
.
load
(
self
.
dir
/
'result.pkl'
)
torch
.
testing
.
assert_allclose
(
result
,
torch
.
tensor
([
[
1
,
0
,
0
],
[
1
,
0
,
0
],
[
0
,
0
,
0
],
]))
class
TestVisTsne
(
TestCase
):
def
setUp
(
self
):
self
.
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
dir
=
Path
(
self
.
tmpdir
.
name
)
config
=
ModelHub
.
get
(
'mobilevit-xxsmall_3rdparty_in1k'
).
config
test_dataloader
=
dict
(
batch_size
=
1
,
dataset
=
dict
(
type
=
'CustomDataset'
,
data_root
=
str
(
ASSETS_ROOT
/
'dataset'
),
pipeline
=
config
.
test_dataloader
.
dataset
.
pipeline
,
),
sampler
=
dict
(
type
=
'DefaultSampler'
,
shuffle
=
False
),
)
config
.
test_dataloader
=
mmengine
.
ConfigDict
(
test_dataloader
)
self
.
config_file
=
self
.
dir
/
'config.py'
config
.
dump
(
self
.
config_file
)
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
def
test_run
(
self
):
command
=
[
'python'
,
'tools/visualization/vis_tsne.py'
,
str
(
self
.
config_file
),
'--work-dir'
,
str
(
self
.
dir
),
'--perplexity'
,
'2'
,
]
Popen
(
command
,
cwd
=
MMPRE_ROOT
,
stdout
=
PIPE
).
wait
()
self
.
assertTrue
(
len
(
list
(
self
.
dir
.
glob
(
'tsne_*/feat_*.png'
)))
>
0
)
class
TestGetFlops
(
TestCase
):
def
test_run
(
self
):
command
=
[
'python'
,
'tools/analysis_tools/get_flops.py'
,
'mobilevit-xxsmall_3rdparty_in1k'
,
]
ret_code
=
Popen
(
command
,
cwd
=
MMPRE_ROOT
).
wait
()
self
.
assertEqual
(
ret_code
,
0
)
tests/test_utils/test_analyze.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
tempfile
from
mmpretrain.utils
import
load_json_log
def
test_load_json_log
():
demo_log
=
"""
\
{"lr": 0.0001, "data_time": 0.003, "loss": 2.29, "time": 0.010, "epoch": 1, "step": 150}
{"lr": 0.0001, "data_time": 0.002, "loss": 2.28, "time": 0.007, "epoch": 1, "step": 300}
{"lr": 0.0001, "data_time": 0.001, "loss": 2.27, "time": 0.008, "epoch": 1, "step": 450}
{"accuracy/top1": 23.98, "accuracy/top5": 66.05, "step": 1}
{"lr": 0.0001, "data_time": 0.001, "loss": 2.25, "time": 0.014, "epoch": 2, "step": 619}
{"lr": 0.0001, "data_time": 0.000, "loss": 2.24, "time": 0.012, "epoch": 2, "step": 769}
{"lr": 0.0001, "data_time": 0.003, "loss": 2.23, "time": 0.009, "epoch": 2, "step": 919}
{"accuracy/top1": 41.82, "accuracy/top5": 81.26, "step": 2}
{"lr": 0.0001, "data_time": 0.002, "loss": 2.21, "time": 0.007, "epoch": 3, "step": 1088}
{"lr": 0.0001, "data_time": 0.005, "loss": 2.18, "time": 0.009, "epoch": 3, "step": 1238}
{"lr": 0.0001, "data_time": 0.002, "loss": 2.16, "time": 0.008, "epoch": 3, "step": 1388}
{"accuracy/top1": 54.07, "accuracy/top5": 89.80, "step": 3}
"""
# noqa: E501
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
json_log
=
osp
.
join
(
tmpdir
,
'scalars.json'
)
with
open
(
json_log
,
'w'
)
as
f
:
f
.
write
(
demo_log
)
log_dict
=
load_json_log
(
json_log
)
assert
log_dict
.
keys
()
==
{
'train'
,
'val'
}
assert
log_dict
[
'train'
][
3
]
==
{
'lr'
:
0.0001
,
'data_time'
:
0.001
,
'loss'
:
2.25
,
'time'
:
0.014
,
'epoch'
:
2
,
'step'
:
619
}
assert
log_dict
[
'val'
][
2
]
==
{
'accuracy/top1'
:
54.07
,
'accuracy/top5'
:
89.80
,
'step'
:
3
}
tests/test_utils/test_setup_env.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
import
datetime
import
sys
from
unittest
import
TestCase
from
mmengine
import
DefaultScope
from
mmpretrain.utils
import
register_all_modules
class
TestSetupEnv
(
TestCase
):
def
test_register_all_modules
(
self
):
from
mmpretrain.registry
import
DATASETS
# not init default scope
sys
.
modules
.
pop
(
'mmpretrain.datasets'
,
None
)
sys
.
modules
.
pop
(
'mmpretrain.datasets.custom'
,
None
)
DATASETS
.
_module_dict
.
pop
(
'CustomDataset'
,
None
)
self
.
assertFalse
(
'CustomDataset'
in
DATASETS
.
module_dict
)
register_all_modules
(
init_default_scope
=
False
)
self
.
assertTrue
(
'CustomDataset'
in
DATASETS
.
module_dict
)
# init default scope
sys
.
modules
.
pop
(
'mmpretrain.datasets'
)
sys
.
modules
.
pop
(
'mmpretrain.datasets.custom'
)
DATASETS
.
_module_dict
.
pop
(
'CustomDataset'
,
None
)
self
.
assertFalse
(
'CustomDataset'
in
DATASETS
.
module_dict
)
register_all_modules
(
init_default_scope
=
True
)
self
.
assertTrue
(
'CustomDataset'
in
DATASETS
.
module_dict
)
self
.
assertEqual
(
DefaultScope
.
get_current_instance
().
scope_name
,
'mmpretrain'
)
# init default scope when another scope is init
name
=
f
'test-
{
datetime
.
datetime
.
now
()
}
'
DefaultScope
.
get_instance
(
name
,
scope_name
=
'test'
)
with
self
.
assertWarnsRegex
(
Warning
,
'The current default scope "test" is not "mmpretrain"'
):
register_all_modules
(
init_default_scope
=
True
)
tests/test_utils/test_version_utils.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
from
mmpretrain
import
digit_version
def
test_digit_version
():
assert
digit_version
(
'0.2.16'
)
==
(
0
,
2
,
16
,
0
,
0
,
0
)
assert
digit_version
(
'1.2.3'
)
==
(
1
,
2
,
3
,
0
,
0
,
0
)
assert
digit_version
(
'1.2.3rc0'
)
==
(
1
,
2
,
3
,
0
,
-
1
,
0
)
assert
digit_version
(
'1.2.3rc1'
)
==
(
1
,
2
,
3
,
0
,
-
1
,
1
)
assert
digit_version
(
'1.0rc0'
)
==
(
1
,
0
,
0
,
0
,
-
1
,
0
)
assert
digit_version
(
'1.0'
)
==
digit_version
(
'1.0.0'
)
assert
digit_version
(
'1.5.0+cuda90_cudnn7.6.3_lms'
)
==
digit_version
(
'1.5'
)
assert
digit_version
(
'1.0.0dev'
)
<
digit_version
(
'1.0.0a'
)
assert
digit_version
(
'1.0.0a'
)
<
digit_version
(
'1.0.0a1'
)
assert
digit_version
(
'1.0.0a'
)
<
digit_version
(
'1.0.0b'
)
assert
digit_version
(
'1.0.0b'
)
<
digit_version
(
'1.0.0rc'
)
assert
digit_version
(
'1.0.0rc1'
)
<
digit_version
(
'1.0.0'
)
assert
digit_version
(
'1.0.0'
)
<
digit_version
(
'1.0.0post'
)
assert
digit_version
(
'1.0.0post'
)
<
digit_version
(
'1.0.0post1'
)
assert
digit_version
(
'v1'
)
==
(
1
,
0
,
0
,
0
,
0
,
0
)
assert
digit_version
(
'v1.1.5'
)
==
(
1
,
1
,
5
,
0
,
0
,
0
)
tests/test_visualization/test_visualizer.py
0 → 100644
View file @
1baf0566
# Copyright (c) Open-MMLab. All rights reserved.
import
os.path
as
osp
import
tempfile
from
unittest
import
TestCase
from
unittest.mock
import
patch
import
numpy
as
np
import
torch
from
mmpretrain.structures
import
DataSample
from
mmpretrain.visualization
import
UniversalVisualizer
class
TestUniversalVisualizer
(
TestCase
):
def
setUp
(
self
)
->
None
:
super
().
setUp
()
tmpdir
=
tempfile
.
TemporaryDirectory
()
self
.
tmpdir
=
tmpdir
self
.
vis
=
UniversalVisualizer
(
save_dir
=
tmpdir
.
name
,
vis_backends
=
[
dict
(
type
=
'LocalVisBackend'
)],
)
def
test_visualize_cls
(
self
):
image
=
np
.
ones
((
10
,
10
,
3
),
np
.
uint8
)
data_sample
=
DataSample
().
set_gt_label
(
1
).
set_pred_label
(
1
).
\
set_pred_score
(
torch
.
tensor
([
0.1
,
0.8
,
0.1
]))
# Test show
def
mock_show
(
drawn_img
,
win_name
,
wait_time
):
self
.
assertFalse
((
image
==
drawn_img
).
all
())
self
.
assertEqual
(
win_name
,
'test_cls'
)
self
.
assertEqual
(
wait_time
,
0
)
with
patch
.
object
(
self
.
vis
,
'show'
,
mock_show
):
self
.
vis
.
visualize_cls
(
image
=
image
,
data_sample
=
data_sample
,
show
=
True
,
name
=
'test_cls'
,
step
=
1
)
# Test storage backend.
save_file
=
osp
.
join
(
self
.
tmpdir
.
name
,
'vis_data/vis_image/test_cls_1.png'
)
self
.
assertTrue
(
osp
.
exists
(
save_file
))
# Test out_file
out_file
=
osp
.
join
(
self
.
tmpdir
.
name
,
'results.png'
)
self
.
vis
.
visualize_cls
(
image
=
image
,
data_sample
=
data_sample
,
out_file
=
out_file
)
self
.
assertTrue
(
osp
.
exists
(
out_file
))
# Test with dataset_meta
self
.
vis
.
dataset_meta
=
{
'classes'
:
[
'cat'
,
'bird'
,
'dog'
]}
def
patch_texts
(
text
,
*
_
,
**
__
):
self
.
assertEqual
(
text
,
'
\n
'
.
join
([
'Ground truth: 1 (bird)'
,
'Prediction: 1, 0.80 (bird)'
,
]))
with
patch
.
object
(
self
.
vis
,
'draw_texts'
,
patch_texts
):
self
.
vis
.
visualize_cls
(
image
,
data_sample
)
# Test without pred_label
def
patch_texts
(
text
,
*
_
,
**
__
):
self
.
assertEqual
(
text
,
'
\n
'
.
join
([
'Ground truth: 1 (bird)'
,
]))
with
patch
.
object
(
self
.
vis
,
'draw_texts'
,
patch_texts
):
self
.
vis
.
visualize_cls
(
image
,
data_sample
,
draw_pred
=
False
)
# Test without gt_label
def
patch_texts
(
text
,
*
_
,
**
__
):
self
.
assertEqual
(
text
,
'
\n
'
.
join
([
'Prediction: 1, 0.80 (bird)'
,
]))
with
patch
.
object
(
self
.
vis
,
'draw_texts'
,
patch_texts
):
self
.
vis
.
visualize_cls
(
image
,
data_sample
,
draw_gt
=
False
)
# Test without score
del
data_sample
.
pred_score
def
patch_texts
(
text
,
*
_
,
**
__
):
self
.
assertEqual
(
text
,
'
\n
'
.
join
([
'Ground truth: 1 (bird)'
,
'Prediction: 1 (bird)'
,
]))
with
patch
.
object
(
self
.
vis
,
'draw_texts'
,
patch_texts
):
self
.
vis
.
visualize_cls
(
image
,
data_sample
)
# Test adaptive font size
def
assert_font_size
(
target_size
):
def
draw_texts
(
text
,
font_sizes
,
*
_
,
**
__
):
self
.
assertEqual
(
font_sizes
,
target_size
)
return
draw_texts
with
patch
.
object
(
self
.
vis
,
'draw_texts'
,
assert_font_size
(
7
)):
self
.
vis
.
visualize_cls
(
np
.
ones
((
224
,
384
,
3
),
np
.
uint8
),
data_sample
)
with
patch
.
object
(
self
.
vis
,
'draw_texts'
,
assert_font_size
(
2
)):
self
.
vis
.
visualize_cls
(
np
.
ones
((
10
,
384
,
3
),
np
.
uint8
),
data_sample
)
with
patch
.
object
(
self
.
vis
,
'draw_texts'
,
assert_font_size
(
21
)):
self
.
vis
.
visualize_cls
(
np
.
ones
((
1000
,
1000
,
3
),
np
.
uint8
),
data_sample
)
# Test rescale image
with
patch
.
object
(
self
.
vis
,
'draw_texts'
,
assert_font_size
(
14
)):
self
.
vis
.
visualize_cls
(
np
.
ones
((
224
,
384
,
3
),
np
.
uint8
),
data_sample
,
rescale_factor
=
2.
)
def
test_visualize_image_retrieval
(
self
):
image
=
np
.
ones
((
10
,
10
,
3
),
np
.
uint8
)
data_sample
=
DataSample
().
set_pred_score
([
0.1
,
0.8
,
0.1
])
class
ToyPrototype
:
def
get_data_info
(
self
,
idx
):
img_path
=
osp
.
join
(
osp
.
dirname
(
__file__
),
'../data/color.jpg'
)
return
{
'img_path'
:
img_path
,
'sample_idx'
:
idx
}
prototype_dataset
=
ToyPrototype
()
# Test show
def
mock_show
(
drawn_img
,
win_name
,
wait_time
):
if
image
.
shape
==
drawn_img
.
shape
:
self
.
assertFalse
((
image
==
drawn_img
).
all
())
self
.
assertEqual
(
win_name
,
'test_retrieval'
)
self
.
assertEqual
(
wait_time
,
0
)
with
patch
.
object
(
self
.
vis
,
'show'
,
mock_show
):
self
.
vis
.
visualize_image_retrieval
(
image
,
data_sample
,
prototype_dataset
,
show
=
True
,
name
=
'test_retrieval'
,
step
=
1
)
# Test storage backend.
save_file
=
osp
.
join
(
self
.
tmpdir
.
name
,
'vis_data/vis_image/test_retrieval_1.png'
)
self
.
assertTrue
(
osp
.
exists
(
save_file
))
# Test out_file
out_file
=
osp
.
join
(
self
.
tmpdir
.
name
,
'results.png'
)
self
.
vis
.
visualize_image_retrieval
(
image
,
data_sample
,
prototype_dataset
,
out_file
=
out_file
,
)
self
.
assertTrue
(
osp
.
exists
(
out_file
))
def
test_visualize_masked_image
(
self
):
image
=
np
.
ones
((
10
,
10
,
3
),
np
.
uint8
)
data_sample
=
DataSample
().
set_mask
(
torch
.
tensor
([
[
0
,
0
,
1
,
1
],
[
0
,
1
,
1
,
0
],
[
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
],
]))
# Test show
def
mock_show
(
drawn_img
,
win_name
,
wait_time
):
self
.
assertTupleEqual
(
drawn_img
.
shape
,
(
224
,
224
,
3
))
self
.
assertEqual
(
win_name
,
'test_mask'
)
self
.
assertEqual
(
wait_time
,
0
)
with
patch
.
object
(
self
.
vis
,
'show'
,
mock_show
):
self
.
vis
.
visualize_masked_image
(
image
,
data_sample
,
show
=
True
,
name
=
'test_mask'
,
step
=
1
)
# Test storage backend.
save_file
=
osp
.
join
(
self
.
tmpdir
.
name
,
'vis_data/vis_image/test_mask_1.png'
)
self
.
assertTrue
(
osp
.
exists
(
save_file
))
# Test out_file
out_file
=
osp
.
join
(
self
.
tmpdir
.
name
,
'results.png'
)
self
.
vis
.
visualize_masked_image
(
image
,
data_sample
,
out_file
=
out_file
)
self
.
assertTrue
(
osp
.
exists
(
out_file
))
def
tearDown
(
self
):
self
.
tmpdir
.
cleanup
()
Prev
1
…
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment