OpenDAS / mmpretrain · Commit 1ac2e802

add tools code

Authored Jun 24, 2025 by limm
Parent: b6df0d33
Pipeline #2803: canceled
Showing 20 of 71 changed files, with 1473 additions and 0 deletions.
tools/analysis_tools/analyze_logs.py                  +218  -0
tools/analysis_tools/analyze_results.py               +121  -0
tools/analysis_tools/confusion_matrix.py              +108  -0
tools/analysis_tools/eval_metric.py                    +62  -0
tools/analysis_tools/get_flops.py                      +61  -0
tools/analysis_tools/shape_bias.py                    +284  -0
tools/analysis_tools/utils.py                         +277  -0
tools/benchmarks/mmdetection/mim_dist_test.sh          +16  -0
tools/benchmarks/mmdetection/mim_dist_train_c4.sh      +19  -0
tools/benchmarks/mmdetection/mim_dist_train_fpn.sh     +16  -0
tools/benchmarks/mmdetection/mim_slurm_test.sh         +23  -0
tools/benchmarks/mmdetection/mim_slurm_train_c4.sh     +27  -0
tools/benchmarks/mmdetection/mim_slurm_train_fpn.sh    +24  -0
tools/benchmarks/mmsegmentation/mim_dist_test.sh       +16  -0
tools/benchmarks/mmsegmentation/mim_dist_train.sh      +17  -0
tools/benchmarks/mmsegmentation/mim_slurm_test.sh      +23  -0
tools/benchmarks/mmsegmentation/mim_slurm_train.sh     +25  -0
tools/dataset_converters/convert_flickr30k_ann.py      +56  -0
tools/dataset_converters/convert_imagenet_subsets.py   +48  -0
tools/dataset_converters/convert_inaturalist.py        +32  -0
tools/analysis_tools/analyze_logs.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import re
from itertools import groupby

import matplotlib.pyplot as plt
import numpy as np

from mmpretrain.utils import load_json_log


def cal_train_time(log_dicts, args):
    """Compute the average time per training iteration."""
    for i, log_dict in enumerate(log_dicts):
        print(f'{"-" * 5}Analyze train time of {args.json_logs[i]}{"-" * 5}')
        train_logs = log_dict['train']

        if 'epoch' in train_logs[0]:
            epoch_ave_times = []
            for _, logs in groupby(train_logs, lambda log: log['epoch']):
                if args.include_outliers:
                    all_time = np.array([log['time'] for log in logs])
                else:
                    all_time = np.array([log['time'] for log in logs])[1:]
                epoch_ave_times.append(all_time.mean())
            epoch_ave_times = np.array(epoch_ave_times)
            slowest_epoch = epoch_ave_times.argmax()
            fastest_epoch = epoch_ave_times.argmin()
            std_over_epoch = epoch_ave_times.std()
            print(f'slowest epoch {slowest_epoch + 1}, '
                  f'average time is {epoch_ave_times[slowest_epoch]:.4f}')
            print(f'fastest epoch {fastest_epoch + 1}, '
                  f'average time is {epoch_ave_times[fastest_epoch]:.4f}')
            print(f'time std over epochs is {std_over_epoch:.4f}')

        avg_iter_time = np.array([log['time'] for log in train_logs]).mean()
        print(f'average iter time: {avg_iter_time:.4f} s/iter')
        print()


def get_legends(args):
    """if legend is None, use {filename}_{key} as legend."""
    legend = args.legend
    if legend is None:
        legend = []
        for json_log in args.json_logs:
            for metric in args.keys:
                # remove '.json' in the end of log names
                basename = os.path.basename(json_log)[:-5]
                if basename.endswith('.log'):
                    basename = basename[:-4]
                legend.append(f'{basename}_{metric}')
    assert len(legend) == (len(args.json_logs) * len(args.keys))
    return legend


def plot_phase_train(metric, train_logs, curve_label):
    """plot phase of train curve."""
    xs = np.array([log['step'] for log in train_logs])
    ys = np.array([log[metric] for log in train_logs])

    if 'epoch' in train_logs[0]:
        scale_factor = train_logs[-1]['step'] / train_logs[-1]['epoch']
        xs = xs / scale_factor
        plt.xlabel('Epochs')
    else:
        plt.xlabel('Iters')
    plt.plot(xs, ys, label=curve_label, linewidth=0.75)


def plot_phase_val(metric, val_logs, curve_label):
    """plot phase of val curve."""
    xs = np.array([log['step'] for log in val_logs])
    ys = np.array([log[metric] for log in val_logs])
    plt.xlabel('Steps')
    plt.plot(xs, ys, label=curve_label, linewidth=0.75)


def plot_curve_helper(log_dicts, metrics, args, legend):
    """plot curves from log_dicts by metrics."""
    num_metrics = len(metrics)
    for i, log_dict in enumerate(log_dicts):
        for j, key in enumerate(metrics):
            json_log = args.json_logs[i]
            print(f'plot curve of {json_log}, metric is {key}')
            curve_label = legend[i * num_metrics + j]

            train_keys = {} if len(log_dict['train']) == 0 else set(
                log_dict['train'][0].keys()) - {'step', 'epoch'}
            val_keys = {} if len(log_dict['val']) == 0 else set(
                log_dict['val'][0].keys()) - {'step'}

            if key in val_keys:
                plot_phase_val(key, log_dict['val'], curve_label)
            elif key in train_keys:
                plot_phase_train(key, log_dict['train'], curve_label)
            else:
                raise ValueError(
                    f'Invalid key "{key}", please choose from '
                    f'{set.union(set(train_keys), set(val_keys))}.')
    plt.legend()


def plot_curve(log_dicts, args):
    """Plot train metric-iter graph."""
    # set style
    try:
        import seaborn as sns
        sns.set_style(args.style)
    except ImportError:
        pass

    # set plot window size
    wind_w, wind_h = args.window_size.split('*')
    wind_w, wind_h = int(wind_w), int(wind_h)
    plt.figure(figsize=(wind_w, wind_h))

    # get legends and metrics
    legends = get_legends(args)
    metrics = args.keys

    # plot curves from log_dicts by metrics
    plot_curve_helper(log_dicts, metrics, args, legends)

    # set title and show or save
    if args.title is not None:
        plt.title(args.title)
    if args.out is None:
        plt.show()
    else:
        print(f'save curve to: {args.out}')
        plt.savefig(args.out)
        plt.cla()


def add_plot_parser(subparsers):
    parser_plt = subparsers.add_parser(
        'plot_curve', help='parser for plotting curves')
    parser_plt.add_argument(
        'json_logs',
        type=str,
        nargs='+',
        help='path of train log in json format')
    parser_plt.add_argument(
        '--keys',
        type=str,
        nargs='+',
        default=['loss'],
        help='the metric that you want to plot')
    parser_plt.add_argument('--title', type=str, help='title of figure')
    parser_plt.add_argument(
        '--legend',
        type=str,
        nargs='+',
        default=None,
        help='legend of each plot')
    parser_plt.add_argument(
        '--style',
        type=str,
        default='whitegrid',
        help='style of the figure, need `seaborn` package.')
    parser_plt.add_argument('--out', type=str, default=None)
    parser_plt.add_argument(
        '--window-size',
        default='12*7',
        help='size of the window to display images, in format of "$W*$H".')


def add_time_parser(subparsers):
    parser_time = subparsers.add_parser(
        'cal_train_time',
        help='parser for computing the average time per training iteration')
    parser_time.add_argument(
        'json_logs',
        type=str,
        nargs='+',
        help='path of train log in json format')
    parser_time.add_argument(
        '--include-outliers',
        action='store_true',
        help='include the first value of every epoch when computing '
        'the average time')


def parse_args():
    parser = argparse.ArgumentParser(description='Analyze Json Log')
    # currently only support plot curve and calculate average train time
    subparsers = parser.add_subparsers(dest='task', help='task parser')
    add_plot_parser(subparsers)
    add_time_parser(subparsers)
    args = parser.parse_args()

    if hasattr(args, 'window_size') and args.window_size != '':
        assert re.match(r'\d+\*\d+', args.window_size), \
            "'window-size' must be in format 'W*H'."

    return args


def main():
    args = parse_args()
    json_logs = args.json_logs
    for json_log in json_logs:
        assert json_log.endswith('.json')
    log_dicts = [load_json_log(json_log) for json_log in json_logs]

    if args.task == 'cal_train_time':
        cal_train_time(log_dicts, args)
    elif args.task == 'plot_curve':
        plot_curve(log_dicts, args)


if __name__ == '__main__':
    main()
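A usage sketch for both subcommands; the scalars.json path is hypothetical and should be replaced with the JSON log produced under your own work_dirs:

    python tools/analysis_tools/analyze_logs.py plot_curve \
        work_dirs/my_run/vis_data/scalars.json \
        --keys loss --legend loss --out loss_curve.png

    python tools/analysis_tools/analyze_logs.py cal_train_time \
        work_dirs/my_run/vis_data/scalars.json --include-outliers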
tools/analysis_tools/analyze_results.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from pathlib import Path

import mmcv
import mmengine
import torch
from mmengine import DictAction

from mmpretrain.datasets import build_dataset
from mmpretrain.structures import DataSample
from mmpretrain.visualization import UniversalVisualizer


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMPreTrain evaluate prediction success/fail')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('result', help='test result json/pkl file')
    parser.add_argument(
        '--out-dir', required=True, help='dir to store output files')
    parser.add_argument(
        '--topk',
        default=20,
        type=int,
        help='Number of images to select for success/fail')
    parser.add_argument(
        '--rescale-factor',
        '-r',
        type=float,
        help='image rescale factor, which is useful if the output is too '
        'large or too small.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None):
    full_dir = osp.join(result_dir, folder_name)
    vis = UniversalVisualizer()
    vis.dataset_meta = {'classes': dataset.CLASSES}

    # save imgs
    dump_infos = []
    for data_sample in results:
        data_info = dataset.get_data_info(data_sample.sample_idx)
        if 'img' in data_info:
            img = data_info['img']
            name = str(data_sample.sample_idx)
        elif 'img_path' in data_info:
            img = mmcv.imread(data_info['img_path'], channel_order='rgb')
            name = Path(data_info['img_path']).name
        else:
            raise ValueError('Cannot load images from the dataset infos.')
        if rescale_factor is not None:
            img = mmcv.imrescale(img, rescale_factor)
        vis.visualize_cls(
            img, data_sample, out_file=osp.join(full_dir, name + '.png'))

        dump = dict()
        for k, v in data_sample.items():
            if isinstance(v, torch.Tensor):
                dump[k] = v.tolist()
            else:
                dump[k] = v
        dump_infos.append(dump)

    mmengine.dump(dump_infos, osp.join(full_dir, folder_name + '.json'))


def main():
    args = parse_args()

    cfg = mmengine.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # build the dataloader
    cfg.test_dataloader.dataset.pipeline = []
    dataset = build_dataset(cfg.test_dataloader.dataset)

    results = list()
    for result in mmengine.load(args.result):
        data_sample = DataSample()
        data_sample.set_metainfo({'sample_idx': result['sample_idx']})
        data_sample.set_gt_label(result['gt_label'])
        data_sample.set_pred_label(result['pred_label'])
        data_sample.set_pred_score(result['pred_score'])
        results.append(data_sample)

    # sort result
    results = sorted(results, key=lambda x: torch.max(x.pred_score))

    success = list()
    fail = list()
    for data_sample in results:
        if (data_sample.pred_label == data_sample.gt_label).all():
            success.append(data_sample)
        else:
            fail.append(data_sample)

    success = success[:args.topk]
    fail = fail[:args.topk]

    save_imgs(args.out_dir, 'success', success, dataset, args.rescale_factor)
    save_imgs(args.out_dir, 'fail', fail, dataset, args.rescale_factor)


if __name__ == '__main__':
    main()
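A hedged usage sketch, assuming results.pkl was dumped beforehand by the test script (the config name and paths are illustrative, not part of this commit):

    python tools/analysis_tools/analyze_results.py \
        configs/resnet/resnet50_8xb32_in1k.py results.pkl \
        --out-dir analysis_out --topk 10 --rescale-factor 0.5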
tools/analysis_tools/confusion_matrix.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile

import mmengine
from mmengine.config import Config, DictAction
from mmengine.evaluator import Evaluator
from mmengine.runner import Runner

from mmpretrain.evaluation import ConfusionMatrix
from mmpretrain.registry import DATASETS
from mmpretrain.utils import register_all_modules


def parse_args():
    parser = argparse.ArgumentParser(
        description='Eval a checkpoint and draw the confusion matrix.')
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        'ckpt_or_result',
        type=str,
        help='The checkpoint file (.pth) or '
        'dumped predictions pickle file (.pkl).')
    parser.add_argument(
        '--out', help='the file to save the confusion matrix.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='whether to display the metric result by matplotlib if '
        'supported.')
    parser.add_argument(
        '--show-path', type=str, help='Path to save the visualization image.')
    parser.add_argument(
        '--include-values',
        action='store_true',
        help='To draw the values in the figure.')
    parser.add_argument(
        '--cmap',
        type=str,
        default='viridis',
        help='The color map to use. Defaults to "viridis".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # register all modules in mmpretrain into the registries
    # do not init the default scope here because it will be init in the runner
    register_all_modules(init_default_scope=False)

    # load config
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if args.ckpt_or_result.endswith('.pth'):
        # Set confusion matrix as the metric.
        cfg.test_evaluator = dict(type='ConfusionMatrix')

        cfg.load_from = str(args.ckpt_or_result)

        with tempfile.TemporaryDirectory() as tmpdir:
            cfg.work_dir = tmpdir
            runner = Runner.from_cfg(cfg)
            classes = runner.test_loop.dataloader.dataset.metainfo.get(
                'classes')
            cm = runner.test()['confusion_matrix/result']
    else:
        predictions = mmengine.load(args.ckpt_or_result)
        evaluator = Evaluator(ConfusionMatrix())
        metrics = evaluator.offline_evaluate(predictions, None)
        cm = metrics['confusion_matrix/result']
        try:
            # Try to build the dataset.
            dataset = DATASETS.build(
                {**cfg.test_dataloader.dataset, 'pipeline': []})
            classes = dataset.metainfo.get('classes')
        except Exception:
            classes = None

    if args.out is not None:
        mmengine.dump(cm, args.out)

    if args.show or args.show_path is not None:
        fig = ConfusionMatrix.plot(
            cm,
            show=args.show,
            classes=classes,
            include_values=args.include_values,
            cmap=args.cmap)
        if args.show_path is not None:
            fig.savefig(args.show_path)
            print(f'The confusion matrix is saved at {args.show_path}.')


if __name__ == '__main__':
    main()
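A usage sketch for both input types (config and file names are illustrative):

    # from a checkpoint: runs a full test pass first
    python tools/analysis_tools/confusion_matrix.py \
        configs/resnet/resnet50_8xb32_in1k.py work_dirs/my_run/epoch_100.pth \
        --show --include-values

    # from predictions dumped as a pickle file
    python tools/analysis_tools/confusion_matrix.py \
        configs/resnet/resnet50_8xb32_in1k.py results.pkl --out cm.pkl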
tools/analysis_tools/eval_metric.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import mmengine
import rich
from mmengine import DictAction
from mmengine.evaluator import Evaluator

from mmpretrain.registry import METRICS

HELP_URL = (
    'https://mmpretrain.readthedocs.io/en/latest/useful_tools/'
    'log_result_analysis.html#how-to-conduct-offline-metric-evaluation')

prog_description = f"""\
Evaluate metric of the results saved in pkl format.
The detailed usage can be found in {HELP_URL}
"""


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('pkl_results', help='Results in pickle format')
    parser.add_argument(
        '--metric',
        nargs='+',
        action='append',
        dest='metric_options',
        help='The metric config, the key-value pair in xxx=yyy format will be '
        'parsed as the metric config items. You can specify multiple metrics '
        'by using multiple `--metric`. For list type value, you can use '
        '"key=[a,b]" or "key=a,b", and it also allows nested list/tuple '
        'values, e.g. "key=[(a,b),(c,d)]".')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if args.metric_options is None:
        raise ValueError('Please specify at least one `--metric`. '
                         f'The detailed usage can be found in {HELP_URL}')

    test_metrics = []
    for metric_option in args.metric_options:
        metric_cfg = {}
        for kv in metric_option:
            k, v = kv.split('=', maxsplit=1)
            metric_cfg[k] = DictAction._parse_iterable(v)
        test_metrics.append(METRICS.build(metric_cfg))

    predictions = mmengine.load(args.pkl_results)
    evaluator = Evaluator(test_metrics)
    eval_results = evaluator.offline_evaluate(predictions, None)
    rich.print(eval_results)


if __name__ == '__main__':
    main()
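A usage sketch; each --metric group becomes one metric config (results.pkl is hypothetical):

    python tools/analysis_tools/eval_metric.py results.pkl \
        --metric type=Accuracy topk=1,5 \
        --metric type=SingleLabelMetric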
tools/analysis_tools/get_flops.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse

from mmengine.analysis import get_model_complexity_info

from mmpretrain import get_model


def parse_args():
    parser = argparse.ArgumentParser(description='Get model flops and params')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    model = get_model(args.config)
    model.eval()

    if hasattr(model, 'extract_feat'):
        model.forward = model.extract_feat
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    analysis_results = get_model_complexity_info(model, input_shape)

    flops = analysis_results['flops_str']
    params = analysis_results['params_str']
    activations = analysis_results['activations_str']
    out_table = analysis_results['out_table']
    out_arch = analysis_results['out_arch']

    print(out_arch)
    print(out_table)
    split_line = '=' * 30
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n'
          f'Activation: {activations}\n{split_line}')
    print('!!!Only the backbone network is counted in FLOPs analysis.')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()
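A usage sketch (config name illustrative); --shape accepts one value for square inputs or two for H and W:

    python tools/analysis_tools/get_flops.py \
        configs/resnet/resnet50_8xb32_in1k.py --shape 224 224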
tools/analysis_tools/shape_bias.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
import argparse
import os
import os.path as osp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mmengine.logging import MMLogger
from utils import FormatStrFormatter, ShapeBias

# global default boundary settings for thin gray transparent
# boundaries to avoid not being able to see the difference
# between two partially overlapping datapoints of the same color:
PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3)
PLOTTING_EDGE_WIDTH = 0.02
ICONS_DIR = osp.join(
    osp.dirname(__file__), '..', '..', 'resources', 'shape_bias_icons')

parser = argparse.ArgumentParser()
parser.add_argument('--csv-dir', type=str, help='directory of csv files')
parser.add_argument(
    '--result-dir', type=str, help='directory to save plotting results')
parser.add_argument('--model-names', nargs='+', default=[], help='model name')
parser.add_argument(
    '--colors',
    nargs='+',
    type=float,
    default=[],
    help='the colors for the plots of each model, and they should be in the '
    'same order as model_names')
parser.add_argument(
    '--markers',
    nargs='+',
    type=str,
    default=[],
    help='the markers for the plots of each model, and they should be in '
    'the same order as model_names')
parser.add_argument(
    '--plotting-names',
    nargs='+',
    default=[],
    help='the plotting names for the plots of each model, and they should '
    'be in the same order as model_names')
parser.add_argument(
    '--delete-icons',
    action='store_true',
    help='whether to delete the icons after plotting')

humans = [
    'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
    'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
]

icon_names = [
    'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
    'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png',
    'clock.png', 'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
    'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
    'knife.png', 'response_icons_vertical.png', 'truck.png'
]


def read_csvs(csv_dir: str) -> pd.DataFrame:
    """Reads all csv files in a directory and returns a single dataframe.

    Args:
        csv_dir (str): directory of csv files.

    Returns:
        pd.DataFrame: dataframe containing all csv files
    """
    df = pd.DataFrame()
    for csv in os.listdir(csv_dir):
        if csv.endswith('.csv'):
            cur_df = pd.read_csv(osp.join(csv_dir, csv))
            cur_df.columns = [c.lower() for c in cur_df.columns]
            df = df.append(cur_df)
    df.condition = df.condition.astype(str)
    return df


def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
    """Plots a matrixplot of shape bias.

    Args:
        args (argparse.Namespace): arguments.
        analysis (ShapeBias): shape bias analysis. Defaults to ShapeBias().
    """
    mpl.rcParams['font.family'] = ['serif']
    mpl.rcParams['font.serif'] = ['Times New Roman']

    plt.figure(figsize=(9, 7))
    df = read_csvs(args.csv_dir)

    fontsize = 15
    ticklength = 10
    markersize = 250
    label_size = 20

    classes = df['category'].unique()
    num_classes = len(classes)

    # plot setup
    fig = plt.figure(1, figsize=(12, 12), dpi=300.)
    ax = plt.gca()

    ax.set_xlim([0, 1])
    ax.set_ylim([-.5, num_classes - 0.5])

    # secondary reversed x axis
    ax_top = ax.secondary_xaxis(
        'top', functions=(lambda x: 1 - x, lambda x: 1 - x))

    # labels, ticks
    plt.tick_params(
        axis='y', which='both', left=False, right=False, labelleft=False)
    ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size)
    ax.set_xlabel(
        "Fraction of 'texture' decisions", fontsize=label_size, labelpad=25)
    ax_top.set_xlabel(
        "Fraction of 'shape' decisions", fontsize=label_size, labelpad=25)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1))
    ax_top.set_ticks(np.arange(0, 1.1, 0.1))
    ax.tick_params(
        axis='both', which='major', labelsize=fontsize, length=ticklength)
    ax_top.tick_params(
        axis='both', which='major', labelsize=fontsize, length=ticklength)

    # arrows on x axes
    plt.arrow(
        x=0, y=-1.75, dx=1, dy=0, fc='black', head_width=0.4,
        head_length=0.03, clip_on=False, length_includes_head=True,
        overhang=0.5)
    plt.arrow(
        x=1, y=num_classes + 0.75, dx=-1, dy=0, fc='black', head_width=0.4,
        head_length=0.03, clip_on=False, length_includes_head=True,
        overhang=0.5)

    # icons besides y axis
    # determine order of icons
    df_selection = df.loc[(df['subj'].isin(humans))]
    class_avgs = []
    for cl in classes:
        df_class_selection = df_selection.query(
            "category == '{}'".format(cl))
        class_avgs.append(
            1 - analysis.analysis(df=df_class_selection)['shape-bias'])
    sorted_indices = np.argsort(class_avgs)
    classes = classes[sorted_indices]

    # icon placement is calculated in axis coordinates
    WIDTH = 1 / num_classes
    # placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH)
    XPOS = -1.25 * WIDTH
    YPOS = -0.5
    HEIGHT = 1
    MARGINX = 1 / 10 * WIDTH  # vertical whitespace between icons
    MARGINY = 1 / 10 * HEIGHT  # horizontal whitespace between icons

    left = XPOS + MARGINX
    right = XPOS + WIDTH - MARGINX

    for i in range(num_classes):
        bottom = i + MARGINY + YPOS
        top = (i + 1) - MARGINY + YPOS
        iconpath = osp.join(ICONS_DIR, '{}.png'.format(classes[i]))
        plt.imshow(
            plt.imread(iconpath),
            extent=[left, right, bottom, top],
            aspect='auto',
            clip_on=False)

    # plot horizontal intersection lines
    for i in range(num_classes - 1):
        plt.plot([0, 1], [i + .5, i + .5],
                 c='gray',
                 linestyle='dotted',
                 alpha=0.4)

    # plot average shapebias + scatter points
    for i in range(len(args.model_names)):
        df_selection = df.loc[(df['subj'].isin(args.model_names[i]))]
        result_df = analysis.analysis(df=df_selection)
        avg = 1 - result_df['shape-bias']
        ax.plot([avg, avg], [-1, num_classes], color=args.colors[i])
        class_avgs = []
        for cl in classes:
            df_class_selection = df_selection.query(
                "category == '{}'".format(cl))
            class_avgs.append(
                1 - analysis.analysis(df=df_class_selection)['shape-bias'])

        ax.scatter(
            class_avgs,
            classes,
            color=args.colors[i],
            marker=args.markers[i],
            label=args.plotting_names[i],
            s=markersize,
            clip_on=False,
            edgecolors=PLOTTING_EDGE_COLOR,
            linewidths=PLOTTING_EDGE_WIDTH,
            zorder=3)
    plt.legend(frameon=True, labelspacing=1, loc=9)

    figure_path = osp.join(args.result_dir,
                           'cue-conflict_shape-bias_matrixplot.pdf')
    fig.savefig(figure_path, bbox_inches='tight')
    plt.close()


def check_icons() -> bool:
    """Check if icons are present, if not download them."""
    if not osp.exists(ICONS_DIR):
        return False
    for icon_name in icon_names:
        if not osp.exists(osp.join(ICONS_DIR, icon_name)):
            return False
    return True


if __name__ == '__main__':

    if not check_icons():
        root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons'  # noqa: E501
        os.makedirs(ICONS_DIR, exist_ok=True)
        MMLogger.get_current_instance().info(
            f'Downloading icons to {ICONS_DIR}')
        for icon_name in icon_names:
            url = osp.join(root_url, icon_name)
            os.system('wget -O {} {}'.format(
                osp.join(ICONS_DIR, icon_name), url))

    args = parser.parse_args()
    assert len(args.model_names) * 3 == len(args.colors), \
        'Number of colors must be 3 times the number of models. ' \
        'Every three colors are the RGB values for one model.'

    # preprocess colors
    args.colors = [c / 255. for c in args.colors]
    colors = []
    for i in range(len(args.model_names)):
        colors.append(args.colors[3 * i:3 * i + 3])
    args.colors = colors
    args.colors.append([165 / 255., 30 / 255., 55 / 255.])  # human color

    # if plotting names are not specified, use model names
    if len(args.plotting_names) == 0:
        args.plotting_names = args.model_names

    # preprocess markers
    args.markers.append('D')  # human marker

    # preprocess model names
    args.model_names = [[m] for m in args.model_names]
    args.model_names.append(humans)

    # preprocess plotting names
    args.plotting_names.append('Humans')

    plot_shape_bias_matrixplot(args)
    if args.delete_icons:
        os.system('rm -rf {}'.format(ICONS_DIR))
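A usage sketch; the directories are hypothetical, and --colors takes three RGB values (0-255) per model, matching the assertion above:

    python tools/analysis_tools/shape_bias.py \
        --csv-dir work_dirs/shape_bias/csvs \
        --result-dir work_dirs/shape_bias/plots \
        --model-names my_model \
        --colors 31 119 180 --markers o --plotting-names MyModel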
tools/analysis_tools/utils.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
from typing import Any, Dict, List, Optional

import matplotlib as mpl
import pandas as pd
from matplotlib import _api
from matplotlib import transforms as mtransforms


class _DummyAxis:
    """Define the minimal interface for a dummy axis.

    Args:
        minpos (float): The minimum positive value for the axis. Defaults
            to 0.
    """
    __name__ = 'dummy'

    # Once the deprecation elapses, replace dataLim and viewLim by plain
    # _view_interval and _data_interval private tuples.
    dataLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_data_interval() and set_data_interval()')
    viewLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_view_interval() and set_view_interval()')

    def __init__(self, minpos: float = 0) -> None:
        self._dataLim = mtransforms.Bbox.unit()
        self._viewLim = mtransforms.Bbox.unit()
        self._minpos = minpos

    def get_view_interval(self) -> Dict:
        """Return the view interval as a tuple (*vmin*, *vmax*)."""
        return self._viewLim.intervalx

    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self._viewLim.intervalx = vmin, vmax

    def get_minpos(self) -> float:
        """Return the minimum positive value for the axis."""
        return self._minpos

    def get_data_interval(self) -> Dict:
        """Return the data interval as a tuple (*vmin*, *vmax*)."""
        return self._dataLim.intervalx

    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self._dataLim.intervalx = vmin, vmax

    def get_tick_space(self) -> int:
        """Return the number of ticks to use."""
        # Just use the long-standing default of nbins==9
        return 9


class TickHelper:
    """A helper class for ticks and tick labels."""
    axis = None

    def set_axis(self, axis: Any) -> None:
        """Set the axis instance."""
        self.axis = axis

    def create_dummy_axis(self, **kwargs) -> None:
        """Create a dummy axis if no axis is set."""
        if self.axis is None:
            self.axis = _DummyAxis(**kwargs)

    @_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self.axis.set_view_interval(vmin, vmax)

    @_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self.axis.set_data_interval(vmin, vmax)

    @_api.deprecated(
        '3.5',
        alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
    def set_bounds(self, vmin: float, vmax: float) -> None:
        """Set the view and data interval to (*vmin*, *vmax*)."""
        self.set_view_interval(vmin, vmax)
        self.set_data_interval(vmin, vmax)


class Formatter(TickHelper):
    """Create a string based on a tick value and location."""
    # some classes want to see all the locs to help format
    # individual ones
    locs = []

    def __call__(self, x: str, pos: Optional[Any] = None) -> str:
        """Return the format for tick value *x* at position pos.

        ``pos=None`` indicates an unspecified location.
        This method must be overridden in the derived class.

        Args:
            x (str): The tick value.
            pos (Optional[Any]): The tick position. Defaults to None.
        """
        raise NotImplementedError('Derived must override')

    def format_ticks(self, values: pd.Series) -> List[str]:
        """Return the tick labels for all the ticks at once.

        Args:
            values (pd.Series): The tick values.

        Returns:
            List[str]: The tick labels.
        """
        self.set_locs(values)
        return [self(value, i) for i, value in enumerate(values)]

    def format_data(self, value: Any) -> str:
        """Return the full string representation of the value with the
        position unspecified.

        Args:
            value (Any): The tick value.

        Returns:
            str: The full string representation of the value.
        """
        return self.__call__(value)

    def format_data_short(self, value: Any) -> str:
        """Return a short string version of the tick value.

        Defaults to the position-independent long value.

        Args:
            value (Any): The tick value.

        Returns:
            str: The short string representation of the value.
        """
        return self.format_data(value)

    def get_offset(self) -> str:
        """Return the offset string."""
        return ''

    def set_locs(self, locs: List[Any]) -> None:
        """Set the locations of the ticks.

        This method is called before computing the tick labels because some
        formatters need to know all tick locations to do so.
        """
        self.locs = locs

    @staticmethod
    def fix_minus(s: str) -> str:
        """Some classes may want to replace a hyphen for minus with the
        proper Unicode symbol (U+2212) for typographical correctness.

        This is a helper method to perform such a replacement when it is
        enabled via :rc:`axes.unicode_minus`.

        Args:
            s (str): The string to replace the hyphen with the Unicode
                symbol.
        """
        return (s.replace('-', '\N{MINUS SIGN}')
                if mpl.rcParams['axes.unicode_minus'] else s)

    def _set_locator(self, locator: Any) -> None:
        """Subclasses may want to override this to set a locator."""
        pass


class FormatStrFormatter(Formatter):
    """Use an old-style ('%' operator) format string to format the tick.

    The format string should have a single variable format (%) in it.
    It will be applied to the value (not the position) of the tick.

    Negative numeric values will use a dash, not a Unicode minus; use
    mathtext to get a Unicode minus by wrapping the format specifier with $
    (e.g. "$%g$").

    Args:
        fmt (str): Format string.
    """

    def __init__(self, fmt: str) -> None:
        self.fmt = fmt

    def __call__(self, x: str, pos: Optional[Any]) -> str:
        """Return the formatted label string.

        Only the value *x* is formatted. The position is ignored.

        Args:
            x (str): The value to format.
            pos (Any): The position of the tick. Ignored.
        """
        return self.fmt % x


class ShapeBias:
    """Compute the shape bias of a model.

    Reference: `ImageNet-trained CNNs are biased towards texture;
    increasing shape bias improves accuracy and robustness
    <https://arxiv.org/abs/1811.12231>`_.
    """
    num_input_models = 1

    def __init__(self) -> None:
        super().__init__()
        self.plotting_name = 'shape-bias'

    @staticmethod
    def _check_dataframe(df: pd.DataFrame) -> None:
        """Check that the dataframe is valid."""
        assert len(df) > 0, 'empty dataframe'

    def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
        """Compute the shape bias of a model.

        Args:
            df (pd.DataFrame): The dataframe containing the data.

        Returns:
            Dict[str, float]: The shape bias.
        """
        self._check_dataframe(df)

        df = df.copy()
        df['correct_texture'] = df['imagename'].apply(
            self.get_texture_category)
        df['correct_shape'] = df['category']

        # remove those rows where shape = texture, i.e. no cue conflict
        # present
        df2 = df.loc[df.correct_shape != df.correct_texture]
        fraction_correct_shape = len(
            df2.loc[df2.object_response == df2.correct_shape]) / len(df)
        fraction_correct_texture = len(
            df2.loc[df2.object_response == df2.correct_texture]) / len(df)
        shape_bias = fraction_correct_shape / (
            fraction_correct_shape + fraction_correct_texture)

        result_dict = {
            'fraction-correct-shape': fraction_correct_shape,
            'fraction-correct-texture': fraction_correct_texture,
            'shape-bias': shape_bias
        }
        return result_dict

    def get_texture_category(self, imagename: str) -> str:
        """Return texture category from imagename.

        e.g. 'XXX_dog10-bird2.png' -> 'bird'

        Args:
            imagename (str): Name of the image.

        Returns:
            str: Texture category.
        """
        assert type(imagename) is str

        # remove unnecessary words
        a = imagename.split('_')[-1]
        # remove .png etc.
        b = a.split('.')[0]
        # get texture category (last word)
        c = b.split('-')[-1]
        # remove number, e.g. 'bird2' -> 'bird'
        d = ''.join([i for i in c if not i.isdigit()])
        return d
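A minimal sketch of the ShapeBias metric on a toy dataframe; the column names follow what analysis() expects (imagename, category, object_response), and the values are made up:

    import pandas as pd
    from utils import ShapeBias  # assumes this module is on the path

    df = pd.DataFrame({
        'imagename': ['x_dog1-cat2.png', 'x_cat3-dog1.png', 'x_dog2-dog3.png'],
        'category': ['dog', 'cat', 'dog'],         # ground-truth shape label
        'object_response': ['dog', 'dog', 'dog'],  # model decision
    })
    # the third row has shape == texture, so it is dropped as a no-conflict
    # trial; one correct-shape and one correct-texture decision remain,
    # giving a shape bias of 0.5
    print(ShapeBias().analysis(df))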
tools/benchmarks/mmdetection/mim_dist_test.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

CFG=$1
CHECKPOINT=$2
GPUS=$3
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
mim test mmdet \
    $CFG \
    --checkpoint $CHECKPOINT \
    --launcher pytorch \
    -G $GPUS \
    $PY_ARGS
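An example invocation (requires mim and mmdet to be installed; the config and checkpoint names are illustrative):

    bash tools/benchmarks/mmdetection/mim_dist_test.sh \
        my_detector_cfg.py work_dirs/detector/epoch_12.pth 8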
tools/benchmarks/mmdetection/mim_dist_train_c4.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

CFG=$1
PRETRAIN=$2  # pretrained model
GPUS=$3
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
mim train mmdet $CFG \
    --launcher pytorch -G $GPUS \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint=$PRETRAIN \
    model.backbone.init_cfg.prefix="backbone." \
    model.roi_head.shared_head.init_cfg.type=Pretrained \
    model.roi_head.shared_head.init_cfg.checkpoint=$PRETRAIN \
    model.roi_head.shared_head.init_cfg.prefix="backbone." \
    $PY_ARGS
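An example invocation; the pretrained checkpoint is loaded (with the "backbone." prefix stripped) into both model.backbone and, for the C4 variant, model.roi_head.shared_head (paths illustrative):

    bash tools/benchmarks/mmdetection/mim_dist_train_c4.sh \
        my_mask_rcnn_c4_cfg.py work_dirs/selfsup/backbone.pth 8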
tools/benchmarks/mmdetection/mim_dist_train_fpn.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

CFG=$1
PRETRAIN=$2  # pretrained model
GPUS=$3
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
mim train mmdet $CFG \
    --launcher pytorch -G $GPUS \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint=$PRETRAIN \
    model.backbone.init_cfg.prefix="backbone." \
    $PY_ARGS
tools/benchmarks/mmdetection/mim_slurm_test.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

PARTITION=$1
CFG=$2
CHECKPOINT=$3
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
mim test mmdet \
    $CFG \
    --checkpoint $CHECKPOINT \
    --launcher slurm -G $GPUS \
    --gpus-per-node $GPUS_PER_NODE \
    --cpus-per-task $CPUS_PER_TASK \
    --partition $PARTITION \
    --srun-args "$SRUN_ARGS" \
    $PY_ARGS
tools/benchmarks/mmdetection/mim_slurm_train_c4.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

PARTITION=$1
CFG=$2
PRETRAIN=$3  # pretrained model
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
mim train mmdet $CFG \
    --launcher slurm -G $GPUS \
    --gpus-per-node $GPUS_PER_NODE \
    --cpus-per-task $CPUS_PER_TASK \
    --partition $PARTITION \
    --srun-args "$SRUN_ARGS" \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint=$PRETRAIN \
    model.backbone.init_cfg.prefix="backbone." \
    model.roi_head.shared_head.init_cfg.type=Pretrained \
    model.roi_head.shared_head.init_cfg.checkpoint=$PRETRAIN \
    model.roi_head.shared_head.init_cfg.prefix="backbone." \
    $PY_ARGS
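An example invocation of the Slurm variant; GPUS, GPUS_PER_NODE, CPUS_PER_TASK, and SRUN_ARGS are environment-variable overrides with the defaults shown above (partition and paths illustrative):

    GPUS=16 GPUS_PER_NODE=8 \
    bash tools/benchmarks/mmdetection/mim_slurm_train_c4.sh \
        my_partition my_mask_rcnn_c4_cfg.py work_dirs/selfsup/backbone.pth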
tools/benchmarks/mmdetection/mim_slurm_train_fpn.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

PARTITION=$1
CFG=$2
PRETRAIN=$3  # pretrained model
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
mim train mmdet $CFG \
    --launcher slurm -G $GPUS \
    --gpus-per-node $GPUS_PER_NODE \
    --cpus-per-task $CPUS_PER_TASK \
    --partition $PARTITION \
    --srun-args "$SRUN_ARGS" \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint=$PRETRAIN \
    model.backbone.init_cfg.prefix="backbone." \
    $PY_ARGS
tools/benchmarks/mmsegmentation/mim_dist_test.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

CFG=$1
CHECKPOINT=$2
GPUS=$3
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
mim test mmseg \
    $CFG \
    --checkpoint $CHECKPOINT \
    --launcher pytorch \
    -G $GPUS \
    $PY_ARGS
tools/benchmarks/mmsegmentation/mim_dist_train.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

CFG=$1
PRETRAIN=$2  # pretrained model
GPUS=$3
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
mim train mmseg $CFG \
    --launcher pytorch -G $GPUS \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint=$PRETRAIN \
    model.backbone.init_cfg.prefix="backbone." \
    model.pretrained=None \
    $PY_ARGS
tools/benchmarks/mmsegmentation/mim_slurm_test.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

PARTITION=$1
CFG=$2
CHECKPOINT=$3
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
mim test mmseg \
    $CFG \
    --checkpoint $CHECKPOINT \
    --launcher slurm -G $GPUS \
    --gpus-per-node $GPUS_PER_NODE \
    --cpus-per-task $CPUS_PER_TASK \
    --partition $PARTITION \
    --srun-args "$SRUN_ARGS" \
    $PY_ARGS
tools/benchmarks/mmsegmentation/mim_slurm_train.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

PARTITION=$1
CFG=$2
PRETRAIN=$3  # pretrained model
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
mim train mmseg $CFG \
    --launcher slurm -G $GPUS \
    --gpus-per-node $GPUS_PER_NODE \
    --cpus-per-task $CPUS_PER_TASK \
    --partition $PARTITION \
    --srun-args "$SRUN_ARGS" \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint=$PRETRAIN \
    model.backbone.init_cfg.prefix="backbone." \
    model.pretrained=None \
    $PY_ARGS
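Example invocations of the mmsegmentation benchmarks (require mim and mmseg; config and paths illustrative). The dist variant takes a GPU count as the third argument, while the Slurm variant takes a partition first and reads GPU counts from the environment variables defaulted above:

    bash tools/benchmarks/mmsegmentation/mim_dist_train.sh \
        my_upernet_cfg.py work_dirs/selfsup/backbone.pth 4

    GPUS=8 bash tools/benchmarks/mmsegmentation/mim_slurm_train.sh \
        my_partition my_upernet_cfg.py work_dirs/selfsup/backbone.pth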
tools/dataset_converters/convert_flickr30k_ann.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
"""Create COCO-style GT annotations based on the raw annotations of
Flickr30k.

GT annotations are used for evaluation in the image caption task.
"""
import json


def main():
    with open('dataset_flickr30k.json', 'r') as f:
        annotations = json.load(f)

    splits = ['train', 'val', 'test']
    for split in splits:
        # reset per split so that each GT file only contains its own split
        ann_list = []
        img_list = []
        for img in annotations['images']:
            # img_example = {
            #     "sentids": [0, 1, 2],
            #     "imgid": 0,
            #     "sentences": [
            #         {"raw": "Two men in green shirts standing in a yard.",
            #          "imgid": 0, "sentid": 0},
            #         {"raw": "A man in a blue shirt standing in a garden.",
            #          "imgid": 0, "sentid": 1},
            #         {"raw": "Two friends enjoy time spent together.",
            #          "imgid": 0, "sentid": 2}
            #     ],
            #     "split": "train",
            #     "filename": "1000092795.jpg"
            # }
            if img['split'] != split:
                continue
            img_list.append({'id': img['imgid']})
            for sentence in img['sentences']:
                ann_info = {
                    'image_id': img['imgid'],
                    'id': sentence['sentid'],
                    'caption': sentence['raw']
                }
                ann_list.append(ann_info)

        json_file = {'annotations': ann_list, 'images': img_list}

        # generate flickr30k_train_gt.json, flickr30k_val_gt.json
        # and flickr30k_test_gt.json
        with open(f'flickr30k_{split}_gt.json', 'w') as f:
            json.dump(json_file, f)


if __name__ == '__main__':
    main()
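A usage sketch; the script reads dataset_flickr30k.json from, and writes the three *_gt.json files to, the current working directory, so run it from wherever the raw annotation lives (the directory layout below is an assumption):

    cd data/flickr30k
    python /path/to/mmpretrain/tools/dataset_converters/convert_flickr30k_ann.py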
tools/dataset_converters/convert_imagenet_subsets.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
"""SimCLR provides list files for semi-supervised benchmarks:
https://github.com/google-research/simclr/tree/master/imagenet_subsets/"""
import argparse


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert ImageNet subset lists provided by SimCLR into '
        'the required format in MMPretrain.')
    parser.add_argument(
        'input', help='Input list file, downloaded from SimCLR github repo.')
    parser.add_argument(
        'output', help='Output list file with the required format.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # create dict with full imagenet annotation file
    with open('data/imagenet/meta/train.txt', 'r') as f:
        lines = f.readlines()
    keys = [line.split('/')[0] for line in lines]
    labels = [line.strip().split()[1] for line in lines]
    mapping = {}
    for k, l in zip(keys, labels):
        if k not in mapping:
            mapping[k] = l
        else:
            assert mapping[k] == l

    # convert
    with open(args.input, 'r') as f:
        lines = f.readlines()
    fns = [line.strip() for line in lines]
    sample_keys = [line.split('_')[0] for line in lines]
    sample_labels = [mapping[k] for k in sample_keys]
    output_lines = [
        f'{k}/{fn} {l}\n'
        for k, fn, l in zip(sample_keys, fns, sample_labels)
    ]

    with open(args.output, 'w+') as f:
        f.writelines(output_lines)


if __name__ == '__main__':
    main()
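A usage sketch; note that the full-ImageNet list data/imagenet/meta/train.txt is read relative to the current directory, so run from the repository root. The subset file name is illustrative:

    python tools/dataset_converters/convert_imagenet_subsets.py \
        10percent.txt data/imagenet/meta/train_semi_10.txt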
tools/dataset_converters/convert_inaturalist.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import mmcv


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert iNaturalist2018 annotations to MMPretrain '
        'format.')
    parser.add_argument('input', type=str, help='Input annotation json file.')
    parser.add_argument('output', type=str, help='Output list file.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    data = mmcv.load(args.input)

    # match every image with its annotation by id (note this nested scan is
    # quadratic in the dataset size)
    output_lines = []
    for img_item in data['images']:
        for ann_item in data['annotations']:
            if ann_item['image_id'] == img_item['id']:
                output_lines.append(
                    f"{img_item['file_name']} {ann_item['category_id']}\n")

    # every image should have exactly one annotation
    assert len(output_lines) == len(data['images'])

    with open(args.output, 'w') as f:
        f.writelines(output_lines)


if __name__ == '__main__':
    main()
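A usage sketch (paths illustrative):

    python tools/dataset_converters/convert_inaturalist.py \
        data/inaturalist2018/train2018.json data/inaturalist2018/meta/train.txt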