Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
RepViT-optimize_pytorch
Commits
c218d1c5
Commit
c218d1c5
authored
Jun 12, 2024
by
chenzk
Browse files
v1.0
parents
Pipeline
#1192
canceled with stages
Changes
195
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2568 additions
and
0 deletions
+2568
-0
segmentation/tools/analyze_logs.py
segmentation/tools/analyze_logs.py
+130
-0
segmentation/tools/benchmark.py
segmentation/tools/benchmark.py
+86
-0
segmentation/tools/browse_dataset.py
segmentation/tools/browse_dataset.py
+167
-0
segmentation/tools/convert_datasets/chase_db1.py
segmentation/tools/convert_datasets/chase_db1.py
+88
-0
segmentation/tools/convert_datasets/cityscapes.py
segmentation/tools/convert_datasets/cityscapes.py
+56
-0
segmentation/tools/convert_datasets/coco_stuff10k.py
segmentation/tools/convert_datasets/coco_stuff10k.py
+306
-0
segmentation/tools/convert_datasets/coco_stuff164k.py
segmentation/tools/convert_datasets/coco_stuff164k.py
+263
-0
segmentation/tools/convert_datasets/drive.py
segmentation/tools/convert_datasets/drive.py
+113
-0
segmentation/tools/convert_datasets/hrf.py
segmentation/tools/convert_datasets/hrf.py
+111
-0
segmentation/tools/convert_datasets/pascal_context.py
segmentation/tools/convert_datasets/pascal_context.py
+87
-0
segmentation/tools/convert_datasets/stare.py
segmentation/tools/convert_datasets/stare.py
+166
-0
segmentation/tools/convert_datasets/voc_aug.py
segmentation/tools/convert_datasets/voc_aug.py
+92
-0
segmentation/tools/deploy_test.py
segmentation/tools/deploy_test.py
+296
-0
segmentation/tools/dist_test.sh
segmentation/tools/dist_test.sh
+10
-0
segmentation/tools/dist_train.sh
segmentation/tools/dist_train.sh
+20
-0
segmentation/tools/get_flops.py
segmentation/tools/get_flops.py
+62
-0
segmentation/tools/model_converters/mit2mmseg.py
segmentation/tools/model_converters/mit2mmseg.py
+82
-0
segmentation/tools/model_converters/swin2mmseg.py
segmentation/tools/model_converters/swin2mmseg.py
+87
-0
segmentation/tools/model_converters/vit2mmseg.py
segmentation/tools/model_converters/vit2mmseg.py
+70
-0
segmentation/tools/onnx2tensorrt.py
segmentation/tools/onnx2tensorrt.py
+276
-0
No files found.
segmentation/tools/analyze_logs.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/open-
mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
import
argparse
import
json
from
collections
import
defaultdict
import
matplotlib.pyplot
as
plt
import
seaborn
as
sns
def
plot_curve
(
log_dicts
,
args
):
if
args
.
backend
is
not
None
:
plt
.
switch_backend
(
args
.
backend
)
sns
.
set_style
(
args
.
style
)
# if legend is None, use {filename}_{key} as legend
legend
=
args
.
legend
if
legend
is
None
:
legend
=
[]
for
json_log
in
args
.
json_logs
:
for
metric
in
args
.
keys
:
legend
.
append
(
f
'
{
json_log
}
_
{
metric
}
'
)
assert
len
(
legend
)
==
(
len
(
args
.
json_logs
)
*
len
(
args
.
keys
))
metrics
=
args
.
keys
num_metrics
=
len
(
metrics
)
for
i
,
log_dict
in
enumerate
(
log_dicts
):
epochs
=
list
(
log_dict
.
keys
())
for
j
,
metric
in
enumerate
(
metrics
):
print
(
f
'plot curve of
{
args
.
json_logs
[
i
]
}
, metric is
{
metric
}
'
)
plot_epochs
=
[]
plot_iters
=
[]
plot_values
=
[]
# In some log files, iters number is not correct, `pre_iter` is
# used to prevent generate wrong lines.
pre_iter
=
-
1
for
epoch
in
epochs
:
epoch_logs
=
log_dict
[
epoch
]
if
metric
not
in
epoch_logs
.
keys
():
continue
if
metric
in
[
'mIoU'
,
'mAcc'
,
'aAcc'
]:
plot_epochs
.
append
(
epoch
)
plot_values
.
append
(
epoch_logs
[
metric
][
0
])
else
:
for
idx
in
range
(
len
(
epoch_logs
[
metric
])):
if
pre_iter
>
epoch_logs
[
'iter'
][
idx
]:
continue
pre_iter
=
epoch_logs
[
'iter'
][
idx
]
plot_iters
.
append
(
epoch_logs
[
'iter'
][
idx
])
plot_values
.
append
(
epoch_logs
[
metric
][
idx
])
ax
=
plt
.
gca
()
label
=
legend
[
i
*
num_metrics
+
j
]
if
metric
in
[
'mIoU'
,
'mAcc'
,
'aAcc'
]:
ax
.
set_xticks
(
plot_epochs
)
plt
.
xlabel
(
'epoch'
)
plt
.
plot
(
plot_epochs
,
plot_values
,
label
=
label
,
marker
=
'o'
)
else
:
plt
.
xlabel
(
'iter'
)
plt
.
plot
(
plot_iters
,
plot_values
,
label
=
label
,
linewidth
=
0.5
)
plt
.
legend
()
if
args
.
title
is
not
None
:
plt
.
title
(
args
.
title
)
if
args
.
out
is
None
:
plt
.
show
()
else
:
print
(
f
'save curve to:
{
args
.
out
}
'
)
plt
.
savefig
(
args
.
out
)
plt
.
cla
()
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Analyze Json Log'
)
parser
.
add_argument
(
'json_logs'
,
type
=
str
,
nargs
=
'+'
,
help
=
'path of train log in json format'
)
parser
.
add_argument
(
'--keys'
,
type
=
str
,
nargs
=
'+'
,
default
=
[
'mIoU'
],
help
=
'the metric that you want to plot'
)
parser
.
add_argument
(
'--title'
,
type
=
str
,
help
=
'title of figure'
)
parser
.
add_argument
(
'--legend'
,
type
=
str
,
nargs
=
'+'
,
default
=
None
,
help
=
'legend of each plot'
)
parser
.
add_argument
(
'--backend'
,
type
=
str
,
default
=
None
,
help
=
'backend of plt'
)
parser
.
add_argument
(
'--style'
,
type
=
str
,
default
=
'dark'
,
help
=
'style of plt'
)
parser
.
add_argument
(
'--out'
,
type
=
str
,
default
=
None
)
args
=
parser
.
parse_args
()
return
args
def
load_json_logs
(
json_logs
):
# load and convert json_logs to log_dict, key is epoch, value is a sub dict
# keys of sub dict is different metrics
# value of sub dict is a list of corresponding values of all iterations
log_dicts
=
[
dict
()
for
_
in
json_logs
]
for
json_log
,
log_dict
in
zip
(
json_logs
,
log_dicts
):
with
open
(
json_log
,
'r'
)
as
log_file
:
for
line
in
log_file
:
log
=
json
.
loads
(
line
.
strip
())
# skip lines without `epoch` field
if
'epoch'
not
in
log
:
continue
epoch
=
log
.
pop
(
'epoch'
)
if
epoch
not
in
log_dict
:
log_dict
[
epoch
]
=
defaultdict
(
list
)
for
k
,
v
in
log
.
items
():
log_dict
[
epoch
][
k
].
append
(
v
)
return
log_dicts
def
main
():
args
=
parse_args
()
json_logs
=
args
.
json_logs
for
json_log
in
json_logs
:
assert
json_log
.
endswith
(
'.json'
)
log_dicts
=
load_json_logs
(
json_logs
)
plot_curve
(
log_dicts
,
args
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/benchmark.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
time
import
torch
from
mmcv
import
Config
from
mmcv.parallel
import
MMDataParallel
from
mmcv.runner
import
load_checkpoint
,
wrap_fp16_model
from
mmseg.datasets
import
build_dataloader
,
build_dataset
from
mmseg.models
import
build_segmentor
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'MMSeg benchmark a model'
)
parser
.
add_argument
(
'config'
,
help
=
'test config file path'
)
parser
.
add_argument
(
'checkpoint'
,
help
=
'checkpoint file'
)
parser
.
add_argument
(
'--log-interval'
,
type
=
int
,
default
=
50
,
help
=
'interval of logging'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
cfg
=
Config
.
fromfile
(
args
.
config
)
# set cudnn_benchmark
torch
.
backends
.
cudnn
.
benchmark
=
False
cfg
.
model
.
pretrained
=
None
cfg
.
data
.
test
.
test_mode
=
True
# build the dataloader
# TODO: support multiple images per gpu (only minor changes are needed)
dataset
=
build_dataset
(
cfg
.
data
.
test
)
data_loader
=
build_dataloader
(
dataset
,
samples_per_gpu
=
1
,
workers_per_gpu
=
cfg
.
data
.
workers_per_gpu
,
dist
=
False
,
shuffle
=
False
)
# build the model and load checkpoint
cfg
.
model
.
train_cfg
=
None
model
=
build_segmentor
(
cfg
.
model
,
test_cfg
=
cfg
.
get
(
'test_cfg'
))
fp16_cfg
=
cfg
.
get
(
'fp16'
,
None
)
if
fp16_cfg
is
not
None
:
wrap_fp16_model
(
model
)
load_checkpoint
(
model
,
args
.
checkpoint
,
map_location
=
'cpu'
)
model
=
MMDataParallel
(
model
,
device_ids
=
[
0
])
model
.
eval
()
# the first several iterations may be very slow so skip them
num_warmup
=
5
pure_inf_time
=
0
total_iters
=
200
# benchmark with 200 image and take the average
for
i
,
data
in
enumerate
(
data_loader
):
torch
.
cuda
.
synchronize
()
start_time
=
time
.
perf_counter
()
with
torch
.
no_grad
():
model
(
return_loss
=
False
,
rescale
=
True
,
**
data
)
torch
.
cuda
.
synchronize
()
elapsed
=
time
.
perf_counter
()
-
start_time
if
i
>=
num_warmup
:
pure_inf_time
+=
elapsed
if
(
i
+
1
)
%
args
.
log_interval
==
0
:
fps
=
(
i
+
1
-
num_warmup
)
/
pure_inf_time
print
(
f
'Done image [
{
i
+
1
:
<
3
}
/
{
total_iters
}
], '
f
'fps:
{
fps
:.
2
f
}
img / s'
)
if
(
i
+
1
)
==
total_iters
:
fps
=
(
i
+
1
-
num_warmup
)
/
pure_inf_time
print
(
f
'Overall fps:
{
fps
:.
2
f
}
img / s'
)
break
if
__name__
==
'__main__'
:
main
()
segmentation/tools/browse_dataset.py
0 → 100644
View file @
c218d1c5
import
argparse
import
os
import
warnings
from
pathlib
import
Path
import
mmcv
import
numpy
as
np
from
mmcv
import
Config
from
mmseg.datasets.builder
import
build_dataset
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Browse a dataset'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
parser
.
add_argument
(
'--show-origin'
,
default
=
False
,
action
=
'store_true'
,
help
=
'if True, omit all augmentation in pipeline,'
' show origin image and seg map'
)
parser
.
add_argument
(
'--skip-type'
,
type
=
str
,
nargs
=
'+'
,
default
=
[
'DefaultFormatBundle'
,
'Normalize'
,
'Collect'
],
help
=
'skip some useless pipeline,if `show-origin` is true, '
'all pipeline except `Load` will be skipped'
)
parser
.
add_argument
(
'--output-dir'
,
default
=
'./output'
,
type
=
str
,
help
=
'If there is no display interface, you can save it'
)
parser
.
add_argument
(
'--show'
,
default
=
False
,
action
=
'store_true'
)
parser
.
add_argument
(
'--show-interval'
,
type
=
int
,
default
=
999
,
help
=
'the interval of show (ms)'
)
parser
.
add_argument
(
'--opacity'
,
type
=
float
,
default
=
0.5
,
help
=
'the opacity of semantic map'
)
args
=
parser
.
parse_args
()
return
args
def
imshow_semantic
(
img
,
seg
,
class_names
,
palette
=
None
,
win_name
=
''
,
show
=
False
,
wait_time
=
0
,
out_file
=
None
,
opacity
=
0.5
):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
seg (Tensor): The semantic segmentation results to draw over
`img`.
class_names (list[str]): Names of each classes.
palette (list[list[int]]] | np.ndarray | None): The palette of
segmentation map. If None is given, random palette will be
generated. Default: None
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
Returns:
img (Tensor): Only if not `show` or `out_file`
"""
img
=
mmcv
.
imread
(
img
)
img
=
img
.
copy
()
if
palette
is
None
:
palette
=
np
.
random
.
randint
(
0
,
255
,
size
=
(
len
(
class_names
),
3
))
palette
=
np
.
array
(
palette
)
assert
palette
.
shape
[
0
]
==
len
(
class_names
)
assert
palette
.
shape
[
1
]
==
3
assert
len
(
palette
.
shape
)
==
2
assert
0
<
opacity
<=
1.0
color_seg
=
np
.
zeros
((
seg
.
shape
[
0
],
seg
.
shape
[
1
],
3
),
dtype
=
np
.
uint8
)
for
label
,
color
in
enumerate
(
palette
):
color_seg
[
seg
==
label
,
:]
=
color
# convert to BGR
color_seg
=
color_seg
[...,
::
-
1
]
img
=
img
*
(
1
-
opacity
)
+
color_seg
*
opacity
img
=
img
.
astype
(
np
.
uint8
)
# if out_file specified, do not show image in window
if
out_file
is
not
None
:
show
=
False
if
show
:
mmcv
.
imshow
(
img
,
win_name
,
wait_time
)
if
out_file
is
not
None
:
mmcv
.
imwrite
(
img
,
out_file
)
if
not
(
show
or
out_file
):
warnings
.
warn
(
'show==False and out_file is not specified, only '
'result image will be returned'
)
return
img
def
_retrieve_data_cfg
(
_data_cfg
,
skip_type
,
show_origin
):
if
show_origin
is
True
:
# only keep pipeline of Loading data and ann
_data_cfg
[
'pipeline'
]
=
[
x
for
x
in
_data_cfg
.
pipeline
if
'Load'
in
x
[
'type'
]
]
else
:
_data_cfg
[
'pipeline'
]
=
[
x
for
x
in
_data_cfg
.
pipeline
if
x
[
'type'
]
not
in
skip_type
]
def
retrieve_data_cfg
(
config_path
,
skip_type
,
show_origin
=
False
):
cfg
=
Config
.
fromfile
(
config_path
)
train_data_cfg
=
cfg
.
data
.
train
if
isinstance
(
train_data_cfg
,
list
):
for
_data_cfg
in
train_data_cfg
:
if
'pipeline'
in
_data_cfg
:
_retrieve_data_cfg
(
_data_cfg
,
skip_type
,
show_origin
)
elif
'dataset'
in
_data_cfg
:
_retrieve_data_cfg
(
_data_cfg
[
'dataset'
],
skip_type
,
show_origin
)
else
:
raise
ValueError
elif
'dataset'
in
train_data_cfg
:
_retrieve_data_cfg
(
train_data_cfg
[
'dataset'
],
skip_type
,
show_origin
)
else
:
_retrieve_data_cfg
(
train_data_cfg
,
skip_type
,
show_origin
)
return
cfg
def
main
():
args
=
parse_args
()
cfg
=
retrieve_data_cfg
(
args
.
config
,
args
.
skip_type
,
args
.
show_origin
)
dataset
=
build_dataset
(
cfg
.
data
.
train
)
progress_bar
=
mmcv
.
ProgressBar
(
len
(
dataset
))
for
item
in
dataset
:
filename
=
os
.
path
.
join
(
args
.
output_dir
,
Path
(
item
[
'filename'
]).
name
)
if
args
.
output_dir
is
not
None
else
None
imshow_semantic
(
item
[
'img'
],
item
[
'gt_semantic_seg'
],
dataset
.
CLASSES
,
dataset
.
PALETTE
,
show
=
args
.
show
,
wait_time
=
args
.
show_interval
,
out_file
=
filename
,
opacity
=
args
.
opacity
,
)
progress_bar
.
update
()
if
__name__
==
'__main__'
:
main
()
segmentation/tools/convert_datasets/chase_db1.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os
import
os.path
as
osp
import
tempfile
import
zipfile
import
mmcv
CHASE_DB1_LEN
=
28
*
3
TRAINING_LEN
=
60
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert CHASE_DB1 dataset to mmsegmentation format'
)
parser
.
add_argument
(
'dataset_path'
,
help
=
'path of CHASEDB1.zip'
)
parser
.
add_argument
(
'--tmp_dir'
,
help
=
'path of the temporary directory'
)
parser
.
add_argument
(
'-o'
,
'--out_dir'
,
help
=
'output path'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
dataset_path
=
args
.
dataset_path
if
args
.
out_dir
is
None
:
out_dir
=
osp
.
join
(
'data'
,
'CHASE_DB1'
)
else
:
out_dir
=
args
.
out_dir
print
(
'Making directories...'
)
mmcv
.
mkdir_or_exist
(
out_dir
)
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
,
'training'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
,
'validation'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
,
'training'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
,
'validation'
))
with
tempfile
.
TemporaryDirectory
(
dir
=
args
.
tmp_dir
)
as
tmp_dir
:
print
(
'Extracting CHASEDB1.zip...'
)
zip_file
=
zipfile
.
ZipFile
(
dataset_path
)
zip_file
.
extractall
(
tmp_dir
)
print
(
'Generating training dataset...'
)
assert
len
(
os
.
listdir
(
tmp_dir
))
==
CHASE_DB1_LEN
,
\
'len(os.listdir(tmp_dir)) != {}'
.
format
(
CHASE_DB1_LEN
)
for
img_name
in
sorted
(
os
.
listdir
(
tmp_dir
))[:
TRAINING_LEN
]:
img
=
mmcv
.
imread
(
osp
.
join
(
tmp_dir
,
img_name
))
if
osp
.
splitext
(
img_name
)[
1
]
==
'.jpg'
:
mmcv
.
imwrite
(
img
,
osp
.
join
(
out_dir
,
'images'
,
'training'
,
osp
.
splitext
(
img_name
)[
0
]
+
'.png'
))
else
:
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a
# threshold to convert the nonstandard annotation imgs. The
# value divided by 128 is equivalent to '1 if value >= 128
# else 0'
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'training'
,
osp
.
splitext
(
img_name
)[
0
]
+
'.png'
))
for
img_name
in
sorted
(
os
.
listdir
(
tmp_dir
))[
TRAINING_LEN
:]:
img
=
mmcv
.
imread
(
osp
.
join
(
tmp_dir
,
img_name
))
if
osp
.
splitext
(
img_name
)[
1
]
==
'.jpg'
:
mmcv
.
imwrite
(
img
,
osp
.
join
(
out_dir
,
'images'
,
'validation'
,
osp
.
splitext
(
img_name
)[
0
]
+
'.png'
))
else
:
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'validation'
,
osp
.
splitext
(
img_name
)[
0
]
+
'.png'
))
print
(
'Removing the temporary files...'
)
print
(
'Done!'
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/convert_datasets/cityscapes.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os.path
as
osp
import
mmcv
from
cityscapesscripts.preparation.json2labelImg
import
json2labelImg
def
convert_json_to_label
(
json_file
):
label_file
=
json_file
.
replace
(
'_polygons.json'
,
'_labelTrainIds.png'
)
json2labelImg
(
json_file
,
label_file
,
'trainIds'
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert Cityscapes annotations to TrainIds'
)
parser
.
add_argument
(
'cityscapes_path'
,
help
=
'cityscapes data path'
)
parser
.
add_argument
(
'--gt-dir'
,
default
=
'gtFine'
,
type
=
str
)
parser
.
add_argument
(
'-o'
,
'--out-dir'
,
help
=
'output path'
)
parser
.
add_argument
(
'--nproc'
,
default
=
1
,
type
=
int
,
help
=
'number of process'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
cityscapes_path
=
args
.
cityscapes_path
out_dir
=
args
.
out_dir
if
args
.
out_dir
else
cityscapes_path
mmcv
.
mkdir_or_exist
(
out_dir
)
gt_dir
=
osp
.
join
(
cityscapes_path
,
args
.
gt_dir
)
poly_files
=
[]
for
poly
in
mmcv
.
scandir
(
gt_dir
,
'_polygons.json'
,
recursive
=
True
):
poly_file
=
osp
.
join
(
gt_dir
,
poly
)
poly_files
.
append
(
poly_file
)
if
args
.
nproc
>
1
:
mmcv
.
track_parallel_progress
(
convert_json_to_label
,
poly_files
,
args
.
nproc
)
else
:
mmcv
.
track_progress
(
convert_json_to_label
,
poly_files
)
split_names
=
[
'train'
,
'val'
,
'test'
]
for
split
in
split_names
:
filenames
=
[]
for
poly
in
mmcv
.
scandir
(
osp
.
join
(
gt_dir
,
split
),
'_polygons.json'
,
recursive
=
True
):
filenames
.
append
(
poly
.
replace
(
'_gtFine_polygons.json'
,
''
))
with
open
(
osp
.
join
(
out_dir
,
f
'
{
split
}
.txt'
),
'w'
)
as
f
:
f
.
writelines
(
f
+
'
\n
'
for
f
in
filenames
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/convert_datasets/coco_stuff10k.py
0 → 100644
View file @
c218d1c5
import
argparse
import
os.path
as
osp
import
shutil
from
functools
import
partial
import
mmcv
import
numpy
as
np
from
PIL
import
Image
from
scipy.io
import
loadmat
COCO_LEN
=
10000
clsID_to_trID
=
{
0
:
0
,
1
:
1
,
2
:
2
,
3
:
3
,
4
:
4
,
5
:
5
,
6
:
6
,
7
:
7
,
8
:
8
,
9
:
9
,
10
:
10
,
11
:
11
,
13
:
12
,
14
:
13
,
15
:
14
,
16
:
15
,
17
:
16
,
18
:
17
,
19
:
18
,
20
:
19
,
21
:
20
,
22
:
21
,
23
:
22
,
24
:
23
,
25
:
24
,
27
:
25
,
28
:
26
,
31
:
27
,
32
:
28
,
33
:
29
,
34
:
30
,
35
:
31
,
36
:
32
,
37
:
33
,
38
:
34
,
39
:
35
,
40
:
36
,
41
:
37
,
42
:
38
,
43
:
39
,
44
:
40
,
46
:
41
,
47
:
42
,
48
:
43
,
49
:
44
,
50
:
45
,
51
:
46
,
52
:
47
,
53
:
48
,
54
:
49
,
55
:
50
,
56
:
51
,
57
:
52
,
58
:
53
,
59
:
54
,
60
:
55
,
61
:
56
,
62
:
57
,
63
:
58
,
64
:
59
,
65
:
60
,
67
:
61
,
70
:
62
,
72
:
63
,
73
:
64
,
74
:
65
,
75
:
66
,
76
:
67
,
77
:
68
,
78
:
69
,
79
:
70
,
80
:
71
,
81
:
72
,
82
:
73
,
84
:
74
,
85
:
75
,
86
:
76
,
87
:
77
,
88
:
78
,
89
:
79
,
90
:
80
,
92
:
81
,
93
:
82
,
94
:
83
,
95
:
84
,
96
:
85
,
97
:
86
,
98
:
87
,
99
:
88
,
100
:
89
,
101
:
90
,
102
:
91
,
103
:
92
,
104
:
93
,
105
:
94
,
106
:
95
,
107
:
96
,
108
:
97
,
109
:
98
,
110
:
99
,
111
:
100
,
112
:
101
,
113
:
102
,
114
:
103
,
115
:
104
,
116
:
105
,
117
:
106
,
118
:
107
,
119
:
108
,
120
:
109
,
121
:
110
,
122
:
111
,
123
:
112
,
124
:
113
,
125
:
114
,
126
:
115
,
127
:
116
,
128
:
117
,
129
:
118
,
130
:
119
,
131
:
120
,
132
:
121
,
133
:
122
,
134
:
123
,
135
:
124
,
136
:
125
,
137
:
126
,
138
:
127
,
139
:
128
,
140
:
129
,
141
:
130
,
142
:
131
,
143
:
132
,
144
:
133
,
145
:
134
,
146
:
135
,
147
:
136
,
148
:
137
,
149
:
138
,
150
:
139
,
151
:
140
,
152
:
141
,
153
:
142
,
154
:
143
,
155
:
144
,
156
:
145
,
157
:
146
,
158
:
147
,
159
:
148
,
160
:
149
,
161
:
150
,
162
:
151
,
163
:
152
,
164
:
153
,
165
:
154
,
166
:
155
,
167
:
156
,
168
:
157
,
169
:
158
,
170
:
159
,
171
:
160
,
172
:
161
,
173
:
162
,
174
:
163
,
175
:
164
,
176
:
165
,
177
:
166
,
178
:
167
,
179
:
168
,
180
:
169
,
181
:
170
,
182
:
171
}
def
convert_to_trainID
(
tuple_path
,
in_img_dir
,
in_ann_dir
,
out_img_dir
,
out_mask_dir
,
is_train
):
imgpath
,
maskpath
=
tuple_path
shutil
.
copyfile
(
osp
.
join
(
in_img_dir
,
imgpath
),
osp
.
join
(
out_img_dir
,
'train2014'
,
imgpath
)
if
is_train
else
osp
.
join
(
out_img_dir
,
'test2014'
,
imgpath
))
annotate
=
loadmat
(
osp
.
join
(
in_ann_dir
,
maskpath
))
mask
=
annotate
[
'S'
].
astype
(
np
.
uint8
)
mask_copy
=
mask
.
copy
()
for
clsID
,
trID
in
clsID_to_trID
.
items
():
mask_copy
[
mask
==
clsID
]
=
trID
seg_filename
=
osp
.
join
(
out_mask_dir
,
'train2014'
,
maskpath
.
split
(
'.'
)[
0
]
+
'_labelTrainIds.png'
)
if
is_train
else
osp
.
join
(
out_mask_dir
,
'test2014'
,
maskpath
.
split
(
'.'
)[
0
]
+
'_labelTrainIds.png'
)
Image
.
fromarray
(
mask_copy
).
save
(
seg_filename
,
'PNG'
)
def
generate_coco_list
(
folder
):
train_list
=
osp
.
join
(
folder
,
'imageLists'
,
'train.txt'
)
test_list
=
osp
.
join
(
folder
,
'imageLists'
,
'test.txt'
)
train_paths
=
[]
test_paths
=
[]
with
open
(
train_list
)
as
f
:
for
filename
in
f
:
basename
=
filename
.
strip
()
imgpath
=
basename
+
'.jpg'
maskpath
=
basename
+
'.mat'
train_paths
.
append
((
imgpath
,
maskpath
))
with
open
(
test_list
)
as
f
:
for
filename
in
f
:
basename
=
filename
.
strip
()
imgpath
=
basename
+
'.jpg'
maskpath
=
basename
+
'.mat'
test_paths
.
append
((
imgpath
,
maskpath
))
return
train_paths
,
test_paths
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
\
'Convert COCO Stuff 10k annotations to mmsegmentation format'
)
# noqa
parser
.
add_argument
(
'coco_path'
,
help
=
'coco stuff path'
)
parser
.
add_argument
(
'-o'
,
'--out_dir'
,
help
=
'output path'
)
parser
.
add_argument
(
'--nproc'
,
default
=
16
,
type
=
int
,
help
=
'number of process'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
coco_path
=
args
.
coco_path
nproc
=
args
.
nproc
out_dir
=
args
.
out_dir
or
coco_path
out_img_dir
=
osp
.
join
(
out_dir
,
'images'
)
out_mask_dir
=
osp
.
join
(
out_dir
,
'annotations'
)
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_img_dir
,
'train2014'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_img_dir
,
'test2014'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_mask_dir
,
'train2014'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_mask_dir
,
'test2014'
))
train_list
,
test_list
=
generate_coco_list
(
coco_path
)
assert
(
len
(
train_list
)
+
len
(
test_list
))
==
COCO_LEN
,
'Wrong length of list {} & {}'
.
format
(
len
(
train_list
),
len
(
test_list
))
if
args
.
nproc
>
1
:
mmcv
.
track_parallel_progress
(
partial
(
convert_to_trainID
,
in_img_dir
=
osp
.
join
(
coco_path
,
'images'
),
in_ann_dir
=
osp
.
join
(
coco_path
,
'annotations'
),
out_img_dir
=
out_img_dir
,
out_mask_dir
=
out_mask_dir
,
is_train
=
True
),
train_list
,
nproc
=
nproc
)
mmcv
.
track_parallel_progress
(
partial
(
convert_to_trainID
,
in_img_dir
=
osp
.
join
(
coco_path
,
'images'
),
in_ann_dir
=
osp
.
join
(
coco_path
,
'annotations'
),
out_img_dir
=
out_img_dir
,
out_mask_dir
=
out_mask_dir
,
is_train
=
False
),
test_list
,
nproc
=
nproc
)
else
:
mmcv
.
track_progress
(
partial
(
convert_to_trainID
,
in_img_dir
=
osp
.
join
(
coco_path
,
'images'
),
in_ann_dir
=
osp
.
join
(
coco_path
,
'annotations'
),
out_img_dir
=
out_img_dir
,
out_mask_dir
=
out_mask_dir
,
is_train
=
True
),
train_list
)
mmcv
.
track_progress
(
partial
(
convert_to_trainID
,
in_img_dir
=
osp
.
join
(
coco_path
,
'images'
),
in_ann_dir
=
osp
.
join
(
coco_path
,
'annotations'
),
out_img_dir
=
out_img_dir
,
out_mask_dir
=
out_mask_dir
,
is_train
=
False
),
test_list
)
print
(
'Done!'
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/convert_datasets/coco_stuff164k.py
0 → 100644
View file @
c218d1c5
import
argparse
import
os.path
as
osp
import
shutil
from
functools
import
partial
from
glob
import
glob
import
mmcv
import
numpy
as
np
from
PIL
import
Image
COCO_LEN
=
123287
clsID_to_trID
=
{
0
:
0
,
1
:
1
,
2
:
2
,
3
:
3
,
4
:
4
,
5
:
5
,
6
:
6
,
7
:
7
,
8
:
8
,
9
:
9
,
10
:
10
,
12
:
11
,
13
:
12
,
14
:
13
,
15
:
14
,
16
:
15
,
17
:
16
,
18
:
17
,
19
:
18
,
20
:
19
,
21
:
20
,
22
:
21
,
23
:
22
,
24
:
23
,
26
:
24
,
27
:
25
,
30
:
26
,
31
:
27
,
32
:
28
,
33
:
29
,
34
:
30
,
35
:
31
,
36
:
32
,
37
:
33
,
38
:
34
,
39
:
35
,
40
:
36
,
41
:
37
,
42
:
38
,
43
:
39
,
45
:
40
,
46
:
41
,
47
:
42
,
48
:
43
,
49
:
44
,
50
:
45
,
51
:
46
,
52
:
47
,
53
:
48
,
54
:
49
,
55
:
50
,
56
:
51
,
57
:
52
,
58
:
53
,
59
:
54
,
60
:
55
,
61
:
56
,
62
:
57
,
63
:
58
,
64
:
59
,
66
:
60
,
69
:
61
,
71
:
62
,
72
:
63
,
73
:
64
,
74
:
65
,
75
:
66
,
76
:
67
,
77
:
68
,
78
:
69
,
79
:
70
,
80
:
71
,
81
:
72
,
83
:
73
,
84
:
74
,
85
:
75
,
86
:
76
,
87
:
77
,
88
:
78
,
89
:
79
,
91
:
80
,
92
:
81
,
93
:
82
,
94
:
83
,
95
:
84
,
96
:
85
,
97
:
86
,
98
:
87
,
99
:
88
,
100
:
89
,
101
:
90
,
102
:
91
,
103
:
92
,
104
:
93
,
105
:
94
,
106
:
95
,
107
:
96
,
108
:
97
,
109
:
98
,
110
:
99
,
111
:
100
,
112
:
101
,
113
:
102
,
114
:
103
,
115
:
104
,
116
:
105
,
117
:
106
,
118
:
107
,
119
:
108
,
120
:
109
,
121
:
110
,
122
:
111
,
123
:
112
,
124
:
113
,
125
:
114
,
126
:
115
,
127
:
116
,
128
:
117
,
129
:
118
,
130
:
119
,
131
:
120
,
132
:
121
,
133
:
122
,
134
:
123
,
135
:
124
,
136
:
125
,
137
:
126
,
138
:
127
,
139
:
128
,
140
:
129
,
141
:
130
,
142
:
131
,
143
:
132
,
144
:
133
,
145
:
134
,
146
:
135
,
147
:
136
,
148
:
137
,
149
:
138
,
150
:
139
,
151
:
140
,
152
:
141
,
153
:
142
,
154
:
143
,
155
:
144
,
156
:
145
,
157
:
146
,
158
:
147
,
159
:
148
,
160
:
149
,
161
:
150
,
162
:
151
,
163
:
152
,
164
:
153
,
165
:
154
,
166
:
155
,
167
:
156
,
168
:
157
,
169
:
158
,
170
:
159
,
171
:
160
,
172
:
161
,
173
:
162
,
174
:
163
,
175
:
164
,
176
:
165
,
177
:
166
,
178
:
167
,
179
:
168
,
180
:
169
,
181
:
170
,
255
:
255
}
def
convert_to_trainID
(
maskpath
,
out_mask_dir
,
is_train
):
mask
=
np
.
array
(
Image
.
open
(
maskpath
))
mask_copy
=
mask
.
copy
()
for
clsID
,
trID
in
clsID_to_trID
.
items
():
mask_copy
[
mask
==
clsID
]
=
trID
seg_filename
=
osp
.
join
(
out_mask_dir
,
'train2017'
,
osp
.
basename
(
maskpath
).
split
(
'.'
)[
0
]
+
'_labelTrainIds.png'
)
if
is_train
else
osp
.
join
(
out_mask_dir
,
'val2017'
,
osp
.
basename
(
maskpath
).
split
(
'.'
)[
0
]
+
'_labelTrainIds.png'
)
Image
.
fromarray
(
mask_copy
).
save
(
seg_filename
,
'PNG'
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
\
'Convert COCO Stuff 164k annotations to mmsegmentation format'
)
# noqa
parser
.
add_argument
(
'coco_path'
,
help
=
'coco stuff path'
)
parser
.
add_argument
(
'-o'
,
'--out_dir'
,
help
=
'output path'
)
parser
.
add_argument
(
'--nproc'
,
default
=
16
,
type
=
int
,
help
=
'number of process'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
coco_path
=
args
.
coco_path
nproc
=
args
.
nproc
out_dir
=
args
.
out_dir
or
coco_path
out_img_dir
=
osp
.
join
(
out_dir
,
'images'
)
out_mask_dir
=
osp
.
join
(
out_dir
,
'annotations'
)
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_mask_dir
,
'train2017'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_mask_dir
,
'val2017'
))
if
out_dir
!=
coco_path
:
shutil
.
copytree
(
osp
.
join
(
coco_path
,
'images'
),
out_img_dir
)
train_list
=
glob
(
osp
.
join
(
coco_path
,
'annotations'
,
'train2017'
,
'*.png'
))
train_list
=
[
file
for
file
in
train_list
if
'_labelTrainIds'
not
in
file
]
test_list
=
glob
(
osp
.
join
(
coco_path
,
'annotations'
,
'val2017'
,
'*.png'
))
test_list
=
[
file
for
file
in
test_list
if
'_labelTrainIds'
not
in
file
]
assert
(
len
(
train_list
)
+
len
(
test_list
))
==
COCO_LEN
,
'Wrong length of list {} & {}'
.
format
(
len
(
train_list
),
len
(
test_list
))
if
args
.
nproc
>
1
:
mmcv
.
track_parallel_progress
(
partial
(
convert_to_trainID
,
out_mask_dir
=
out_mask_dir
,
is_train
=
True
),
train_list
,
nproc
=
nproc
)
mmcv
.
track_parallel_progress
(
partial
(
convert_to_trainID
,
out_mask_dir
=
out_mask_dir
,
is_train
=
False
),
test_list
,
nproc
=
nproc
)
else
:
mmcv
.
track_progress
(
partial
(
convert_to_trainID
,
out_mask_dir
=
out_mask_dir
,
is_train
=
True
),
train_list
)
mmcv
.
track_progress
(
partial
(
convert_to_trainID
,
out_mask_dir
=
out_mask_dir
,
is_train
=
False
),
test_list
)
print
(
'Done!'
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/convert_datasets/drive.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os
import
os.path
as
osp
import
tempfile
import
zipfile
import
cv2
import
mmcv
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert DRIVE dataset to mmsegmentation format'
)
parser
.
add_argument
(
'training_path'
,
help
=
'the training part of DRIVE dataset'
)
parser
.
add_argument
(
'testing_path'
,
help
=
'the testing part of DRIVE dataset'
)
parser
.
add_argument
(
'--tmp_dir'
,
help
=
'path of the temporary directory'
)
parser
.
add_argument
(
'-o'
,
'--out_dir'
,
help
=
'output path'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
training_path
=
args
.
training_path
testing_path
=
args
.
testing_path
if
args
.
out_dir
is
None
:
out_dir
=
osp
.
join
(
'data'
,
'DRIVE'
)
else
:
out_dir
=
args
.
out_dir
print
(
'Making directories...'
)
mmcv
.
mkdir_or_exist
(
out_dir
)
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
,
'training'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
,
'validation'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
,
'training'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
,
'validation'
))
with
tempfile
.
TemporaryDirectory
(
dir
=
args
.
tmp_dir
)
as
tmp_dir
:
print
(
'Extracting training.zip...'
)
zip_file
=
zipfile
.
ZipFile
(
training_path
)
zip_file
.
extractall
(
tmp_dir
)
print
(
'Generating training dataset...'
)
now_dir
=
osp
.
join
(
tmp_dir
,
'training'
,
'images'
)
for
img_name
in
os
.
listdir
(
now_dir
):
img
=
mmcv
.
imread
(
osp
.
join
(
now_dir
,
img_name
))
mmcv
.
imwrite
(
img
,
osp
.
join
(
out_dir
,
'images'
,
'training'
,
osp
.
splitext
(
img_name
)[
0
].
replace
(
'_training'
,
''
)
+
'.png'
))
now_dir
=
osp
.
join
(
tmp_dir
,
'training'
,
'1st_manual'
)
for
img_name
in
os
.
listdir
(
now_dir
):
cap
=
cv2
.
VideoCapture
(
osp
.
join
(
now_dir
,
img_name
))
ret
,
img
=
cap
.
read
()
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'training'
,
osp
.
splitext
(
img_name
)[
0
]
+
'.png'
))
print
(
'Extracting test.zip...'
)
zip_file
=
zipfile
.
ZipFile
(
testing_path
)
zip_file
.
extractall
(
tmp_dir
)
print
(
'Generating validation dataset...'
)
now_dir
=
osp
.
join
(
tmp_dir
,
'test'
,
'images'
)
for
img_name
in
os
.
listdir
(
now_dir
):
img
=
mmcv
.
imread
(
osp
.
join
(
now_dir
,
img_name
))
mmcv
.
imwrite
(
img
,
osp
.
join
(
out_dir
,
'images'
,
'validation'
,
osp
.
splitext
(
img_name
)[
0
].
replace
(
'_test'
,
''
)
+
'.png'
))
now_dir
=
osp
.
join
(
tmp_dir
,
'test'
,
'1st_manual'
)
if
osp
.
exists
(
now_dir
):
for
img_name
in
os
.
listdir
(
now_dir
):
cap
=
cv2
.
VideoCapture
(
osp
.
join
(
now_dir
,
img_name
))
ret
,
img
=
cap
.
read
()
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a
# threshold to convert the nonstandard annotation imgs. The
# value divided by 128 is equivalent to '1 if value >= 128
# else 0'
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'validation'
,
osp
.
splitext
(
img_name
)[
0
]
+
'.png'
))
now_dir
=
osp
.
join
(
tmp_dir
,
'test'
,
'2nd_manual'
)
if
osp
.
exists
(
now_dir
):
for
img_name
in
os
.
listdir
(
now_dir
):
cap
=
cv2
.
VideoCapture
(
osp
.
join
(
now_dir
,
img_name
))
ret
,
img
=
cap
.
read
()
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'validation'
,
osp
.
splitext
(
img_name
)[
0
]
+
'.png'
))
print
(
'Removing the temporary files...'
)
print
(
'Done!'
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/convert_datasets/hrf.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os
import
os.path
as
osp
import
tempfile
import
zipfile
import
mmcv
HRF_LEN
=
15
TRAINING_LEN
=
5
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert HRF dataset to mmsegmentation format'
)
parser
.
add_argument
(
'healthy_path'
,
help
=
'the path of healthy.zip'
)
parser
.
add_argument
(
'healthy_manualsegm_path'
,
help
=
'the path of healthy_manualsegm.zip'
)
parser
.
add_argument
(
'glaucoma_path'
,
help
=
'the path of glaucoma.zip'
)
parser
.
add_argument
(
'glaucoma_manualsegm_path'
,
help
=
'the path of glaucoma_manualsegm.zip'
)
parser
.
add_argument
(
'diabetic_retinopathy_path'
,
help
=
'the path of diabetic_retinopathy.zip'
)
parser
.
add_argument
(
'diabetic_retinopathy_manualsegm_path'
,
help
=
'the path of diabetic_retinopathy_manualsegm.zip'
)
parser
.
add_argument
(
'--tmp_dir'
,
help
=
'path of the temporary directory'
)
parser
.
add_argument
(
'-o'
,
'--out_dir'
,
help
=
'output path'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
images_path
=
[
args
.
healthy_path
,
args
.
glaucoma_path
,
args
.
diabetic_retinopathy_path
]
annotations_path
=
[
args
.
healthy_manualsegm_path
,
args
.
glaucoma_manualsegm_path
,
args
.
diabetic_retinopathy_manualsegm_path
]
if
args
.
out_dir
is
None
:
out_dir
=
osp
.
join
(
'data'
,
'HRF'
)
else
:
out_dir
=
args
.
out_dir
print
(
'Making directories...'
)
mmcv
.
mkdir_or_exist
(
out_dir
)
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
,
'training'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
,
'validation'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
,
'training'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
,
'validation'
))
print
(
'Generating images...'
)
for
now_path
in
images_path
:
with
tempfile
.
TemporaryDirectory
(
dir
=
args
.
tmp_dir
)
as
tmp_dir
:
zip_file
=
zipfile
.
ZipFile
(
now_path
)
zip_file
.
extractall
(
tmp_dir
)
assert
len
(
os
.
listdir
(
tmp_dir
))
==
HRF_LEN
,
\
'len(os.listdir(tmp_dir)) != {}'
.
format
(
HRF_LEN
)
for
filename
in
sorted
(
os
.
listdir
(
tmp_dir
))[:
TRAINING_LEN
]:
img
=
mmcv
.
imread
(
osp
.
join
(
tmp_dir
,
filename
))
mmcv
.
imwrite
(
img
,
osp
.
join
(
out_dir
,
'images'
,
'training'
,
osp
.
splitext
(
filename
)[
0
]
+
'.png'
))
for
filename
in
sorted
(
os
.
listdir
(
tmp_dir
))[
TRAINING_LEN
:]:
img
=
mmcv
.
imread
(
osp
.
join
(
tmp_dir
,
filename
))
mmcv
.
imwrite
(
img
,
osp
.
join
(
out_dir
,
'images'
,
'validation'
,
osp
.
splitext
(
filename
)[
0
]
+
'.png'
))
print
(
'Generating annotations...'
)
for
now_path
in
annotations_path
:
with
tempfile
.
TemporaryDirectory
(
dir
=
args
.
tmp_dir
)
as
tmp_dir
:
zip_file
=
zipfile
.
ZipFile
(
now_path
)
zip_file
.
extractall
(
tmp_dir
)
assert
len
(
os
.
listdir
(
tmp_dir
))
==
HRF_LEN
,
\
'len(os.listdir(tmp_dir)) != {}'
.
format
(
HRF_LEN
)
for
filename
in
sorted
(
os
.
listdir
(
tmp_dir
))[:
TRAINING_LEN
]:
img
=
mmcv
.
imread
(
osp
.
join
(
tmp_dir
,
filename
))
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a
# threshold to convert the nonstandard annotation imgs. The
# value divided by 128 is equivalent to '1 if value >= 128
# else 0'
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'training'
,
osp
.
splitext
(
filename
)[
0
]
+
'.png'
))
for
filename
in
sorted
(
os
.
listdir
(
tmp_dir
))[
TRAINING_LEN
:]:
img
=
mmcv
.
imread
(
osp
.
join
(
tmp_dir
,
filename
))
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'validation'
,
osp
.
splitext
(
filename
)[
0
]
+
'.png'
))
print
(
'Done!'
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/convert_datasets/pascal_context.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os.path
as
osp
from
functools
import
partial
import
mmcv
import
numpy
as
np
from
detail
import
Detail
from
PIL
import
Image
_mapping
=
np
.
sort
(
np
.
array
([
0
,
2
,
259
,
260
,
415
,
324
,
9
,
258
,
144
,
18
,
19
,
22
,
23
,
397
,
25
,
284
,
158
,
159
,
416
,
33
,
162
,
420
,
454
,
295
,
296
,
427
,
44
,
45
,
46
,
308
,
59
,
440
,
445
,
31
,
232
,
65
,
354
,
424
,
68
,
326
,
72
,
458
,
34
,
207
,
80
,
355
,
85
,
347
,
220
,
349
,
360
,
98
,
187
,
104
,
105
,
366
,
189
,
368
,
113
,
115
]))
_key
=
np
.
array
(
range
(
len
(
_mapping
))).
astype
(
'uint8'
)
def
generate_labels
(
img_id
,
detail
,
out_dir
):
def
_class_to_index
(
mask
,
_mapping
,
_key
):
# assert the values
values
=
np
.
unique
(
mask
)
for
i
in
range
(
len
(
values
)):
assert
(
values
[
i
]
in
_mapping
)
index
=
np
.
digitize
(
mask
.
ravel
(),
_mapping
,
right
=
True
)
return
_key
[
index
].
reshape
(
mask
.
shape
)
mask
=
Image
.
fromarray
(
_class_to_index
(
detail
.
getMask
(
img_id
),
_mapping
=
_mapping
,
_key
=
_key
))
filename
=
img_id
[
'file_name'
]
mask
.
save
(
osp
.
join
(
out_dir
,
filename
.
replace
(
'jpg'
,
'png'
)))
return
osp
.
splitext
(
osp
.
basename
(
filename
))[
0
]
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert PASCAL VOC annotations to mmsegmentation format'
)
parser
.
add_argument
(
'devkit_path'
,
help
=
'pascal voc devkit path'
)
parser
.
add_argument
(
'json_path'
,
help
=
'annoation json filepath'
)
parser
.
add_argument
(
'-o'
,
'--out_dir'
,
help
=
'output path'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
devkit_path
=
args
.
devkit_path
if
args
.
out_dir
is
None
:
out_dir
=
osp
.
join
(
devkit_path
,
'VOC2010'
,
'SegmentationClassContext'
)
else
:
out_dir
=
args
.
out_dir
json_path
=
args
.
json_path
mmcv
.
mkdir_or_exist
(
out_dir
)
img_dir
=
osp
.
join
(
devkit_path
,
'VOC2010'
,
'JPEGImages'
)
train_detail
=
Detail
(
json_path
,
img_dir
,
'train'
)
train_ids
=
train_detail
.
getImgs
()
val_detail
=
Detail
(
json_path
,
img_dir
,
'val'
)
val_ids
=
val_detail
.
getImgs
()
mmcv
.
mkdir_or_exist
(
osp
.
join
(
devkit_path
,
'VOC2010/ImageSets/SegmentationContext'
))
train_list
=
mmcv
.
track_progress
(
partial
(
generate_labels
,
detail
=
train_detail
,
out_dir
=
out_dir
),
train_ids
)
with
open
(
osp
.
join
(
devkit_path
,
'VOC2010/ImageSets/SegmentationContext'
,
'train.txt'
),
'w'
)
as
f
:
f
.
writelines
(
line
+
'
\n
'
for
line
in
sorted
(
train_list
))
val_list
=
mmcv
.
track_progress
(
partial
(
generate_labels
,
detail
=
val_detail
,
out_dir
=
out_dir
),
val_ids
)
with
open
(
osp
.
join
(
devkit_path
,
'VOC2010/ImageSets/SegmentationContext'
,
'val.txt'
),
'w'
)
as
f
:
f
.
writelines
(
line
+
'
\n
'
for
line
in
sorted
(
val_list
))
print
(
'Done!'
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/convert_datasets/stare.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
gzip
import
os
import
os.path
as
osp
import
tarfile
import
tempfile
import
mmcv
STARE_LEN
=
20
TRAINING_LEN
=
10
def
un_gz
(
src
,
dst
):
g_file
=
gzip
.
GzipFile
(
src
)
with
open
(
dst
,
'wb+'
)
as
f
:
f
.
write
(
g_file
.
read
())
g_file
.
close
()
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert STARE dataset to mmsegmentation format'
)
parser
.
add_argument
(
'image_path'
,
help
=
'the path of stare-images.tar'
)
parser
.
add_argument
(
'labels_ah'
,
help
=
'the path of labels-ah.tar'
)
parser
.
add_argument
(
'labels_vk'
,
help
=
'the path of labels-vk.tar'
)
parser
.
add_argument
(
'--tmp_dir'
,
help
=
'path of the temporary directory'
)
parser
.
add_argument
(
'-o'
,
'--out_dir'
,
help
=
'output path'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
image_path
=
args
.
image_path
labels_ah
=
args
.
labels_ah
labels_vk
=
args
.
labels_vk
if
args
.
out_dir
is
None
:
out_dir
=
osp
.
join
(
'data'
,
'STARE'
)
else
:
out_dir
=
args
.
out_dir
print
(
'Making directories...'
)
mmcv
.
mkdir_or_exist
(
out_dir
)
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
,
'training'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'images'
,
'validation'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
,
'training'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
out_dir
,
'annotations'
,
'validation'
))
with
tempfile
.
TemporaryDirectory
(
dir
=
args
.
tmp_dir
)
as
tmp_dir
:
mmcv
.
mkdir_or_exist
(
osp
.
join
(
tmp_dir
,
'gz'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
tmp_dir
,
'files'
))
print
(
'Extracting stare-images.tar...'
)
with
tarfile
.
open
(
image_path
)
as
f
:
f
.
extractall
(
osp
.
join
(
tmp_dir
,
'gz'
))
for
filename
in
os
.
listdir
(
osp
.
join
(
tmp_dir
,
'gz'
)):
un_gz
(
osp
.
join
(
tmp_dir
,
'gz'
,
filename
),
osp
.
join
(
tmp_dir
,
'files'
,
osp
.
splitext
(
filename
)[
0
]))
now_dir
=
osp
.
join
(
tmp_dir
,
'files'
)
assert
len
(
os
.
listdir
(
now_dir
))
==
STARE_LEN
,
\
'len(os.listdir(now_dir)) != {}'
.
format
(
STARE_LEN
)
for
filename
in
sorted
(
os
.
listdir
(
now_dir
))[:
TRAINING_LEN
]:
img
=
mmcv
.
imread
(
osp
.
join
(
now_dir
,
filename
))
mmcv
.
imwrite
(
img
,
osp
.
join
(
out_dir
,
'images'
,
'training'
,
osp
.
splitext
(
filename
)[
0
]
+
'.png'
))
for
filename
in
sorted
(
os
.
listdir
(
now_dir
))[
TRAINING_LEN
:]:
img
=
mmcv
.
imread
(
osp
.
join
(
now_dir
,
filename
))
mmcv
.
imwrite
(
img
,
osp
.
join
(
out_dir
,
'images'
,
'validation'
,
osp
.
splitext
(
filename
)[
0
]
+
'.png'
))
print
(
'Removing the temporary files...'
)
with
tempfile
.
TemporaryDirectory
(
dir
=
args
.
tmp_dir
)
as
tmp_dir
:
mmcv
.
mkdir_or_exist
(
osp
.
join
(
tmp_dir
,
'gz'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
tmp_dir
,
'files'
))
print
(
'Extracting labels-ah.tar...'
)
with
tarfile
.
open
(
labels_ah
)
as
f
:
f
.
extractall
(
osp
.
join
(
tmp_dir
,
'gz'
))
for
filename
in
os
.
listdir
(
osp
.
join
(
tmp_dir
,
'gz'
)):
un_gz
(
osp
.
join
(
tmp_dir
,
'gz'
,
filename
),
osp
.
join
(
tmp_dir
,
'files'
,
osp
.
splitext
(
filename
)[
0
]))
now_dir
=
osp
.
join
(
tmp_dir
,
'files'
)
assert
len
(
os
.
listdir
(
now_dir
))
==
STARE_LEN
,
\
'len(os.listdir(now_dir)) != {}'
.
format
(
STARE_LEN
)
for
filename
in
sorted
(
os
.
listdir
(
now_dir
))[:
TRAINING_LEN
]:
img
=
mmcv
.
imread
(
osp
.
join
(
now_dir
,
filename
))
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a threshold
# to convert the nonstandard annotation imgs. The value divided by
# 128 equivalent to '1 if value >= 128 else 0'
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'training'
,
osp
.
splitext
(
filename
)[
0
]
+
'.png'
))
for
filename
in
sorted
(
os
.
listdir
(
now_dir
))[
TRAINING_LEN
:]:
img
=
mmcv
.
imread
(
osp
.
join
(
now_dir
,
filename
))
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'validation'
,
osp
.
splitext
(
filename
)[
0
]
+
'.png'
))
print
(
'Removing the temporary files...'
)
with
tempfile
.
TemporaryDirectory
(
dir
=
args
.
tmp_dir
)
as
tmp_dir
:
mmcv
.
mkdir_or_exist
(
osp
.
join
(
tmp_dir
,
'gz'
))
mmcv
.
mkdir_or_exist
(
osp
.
join
(
tmp_dir
,
'files'
))
print
(
'Extracting labels-vk.tar...'
)
with
tarfile
.
open
(
labels_vk
)
as
f
:
f
.
extractall
(
osp
.
join
(
tmp_dir
,
'gz'
))
for
filename
in
os
.
listdir
(
osp
.
join
(
tmp_dir
,
'gz'
)):
un_gz
(
osp
.
join
(
tmp_dir
,
'gz'
,
filename
),
osp
.
join
(
tmp_dir
,
'files'
,
osp
.
splitext
(
filename
)[
0
]))
now_dir
=
osp
.
join
(
tmp_dir
,
'files'
)
assert
len
(
os
.
listdir
(
now_dir
))
==
STARE_LEN
,
\
'len(os.listdir(now_dir)) != {}'
.
format
(
STARE_LEN
)
for
filename
in
sorted
(
os
.
listdir
(
now_dir
))[:
TRAINING_LEN
]:
img
=
mmcv
.
imread
(
osp
.
join
(
now_dir
,
filename
))
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'training'
,
osp
.
splitext
(
filename
)[
0
]
+
'.png'
))
for
filename
in
sorted
(
os
.
listdir
(
now_dir
))[
TRAINING_LEN
:]:
img
=
mmcv
.
imread
(
osp
.
join
(
now_dir
,
filename
))
mmcv
.
imwrite
(
img
[:,
:,
0
]
//
128
,
osp
.
join
(
out_dir
,
'annotations'
,
'validation'
,
osp
.
splitext
(
filename
)[
0
]
+
'.png'
))
print
(
'Removing the temporary files...'
)
print
(
'Done!'
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/convert_datasets/voc_aug.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os.path
as
osp
from
functools
import
partial
import
mmcv
import
numpy
as
np
from
PIL
import
Image
from
scipy.io
import
loadmat
AUG_LEN
=
10582
def
convert_mat
(
mat_file
,
in_dir
,
out_dir
):
data
=
loadmat
(
osp
.
join
(
in_dir
,
mat_file
))
mask
=
data
[
'GTcls'
][
0
][
'Segmentation'
][
0
].
astype
(
np
.
uint8
)
seg_filename
=
osp
.
join
(
out_dir
,
mat_file
.
replace
(
'.mat'
,
'.png'
))
Image
.
fromarray
(
mask
).
save
(
seg_filename
,
'PNG'
)
def
generate_aug_list
(
merged_list
,
excluded_list
):
return
list
(
set
(
merged_list
)
-
set
(
excluded_list
))
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert PASCAL VOC annotations to mmsegmentation format'
)
parser
.
add_argument
(
'devkit_path'
,
help
=
'pascal voc devkit path'
)
parser
.
add_argument
(
'aug_path'
,
help
=
'pascal voc aug path'
)
parser
.
add_argument
(
'-o'
,
'--out_dir'
,
help
=
'output path'
)
parser
.
add_argument
(
'--nproc'
,
default
=
1
,
type
=
int
,
help
=
'number of process'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
devkit_path
=
args
.
devkit_path
aug_path
=
args
.
aug_path
nproc
=
args
.
nproc
if
args
.
out_dir
is
None
:
out_dir
=
osp
.
join
(
devkit_path
,
'VOC2012'
,
'SegmentationClassAug'
)
else
:
out_dir
=
args
.
out_dir
mmcv
.
mkdir_or_exist
(
out_dir
)
in_dir
=
osp
.
join
(
aug_path
,
'dataset'
,
'cls'
)
mmcv
.
track_parallel_progress
(
partial
(
convert_mat
,
in_dir
=
in_dir
,
out_dir
=
out_dir
),
list
(
mmcv
.
scandir
(
in_dir
,
suffix
=
'.mat'
)),
nproc
=
nproc
)
full_aug_list
=
[]
with
open
(
osp
.
join
(
aug_path
,
'dataset'
,
'train.txt'
))
as
f
:
full_aug_list
+=
[
line
.
strip
()
for
line
in
f
]
with
open
(
osp
.
join
(
aug_path
,
'dataset'
,
'val.txt'
))
as
f
:
full_aug_list
+=
[
line
.
strip
()
for
line
in
f
]
with
open
(
osp
.
join
(
devkit_path
,
'VOC2012/ImageSets/Segmentation'
,
'train.txt'
))
as
f
:
ori_train_list
=
[
line
.
strip
()
for
line
in
f
]
with
open
(
osp
.
join
(
devkit_path
,
'VOC2012/ImageSets/Segmentation'
,
'val.txt'
))
as
f
:
val_list
=
[
line
.
strip
()
for
line
in
f
]
aug_train_list
=
generate_aug_list
(
ori_train_list
+
full_aug_list
,
val_list
)
assert
len
(
aug_train_list
)
==
AUG_LEN
,
'len(aug_train_list) != {}'
.
format
(
AUG_LEN
)
with
open
(
osp
.
join
(
devkit_path
,
'VOC2012/ImageSets/Segmentation'
,
'trainaug.txt'
),
'w'
)
as
f
:
f
.
writelines
(
line
+
'
\n
'
for
line
in
aug_train_list
)
aug_list
=
generate_aug_list
(
full_aug_list
,
ori_train_list
+
val_list
)
assert
len
(
aug_list
)
==
AUG_LEN
-
len
(
ori_train_list
),
'len(aug_list) != {}'
.
format
(
AUG_LEN
-
len
(
ori_train_list
))
with
open
(
osp
.
join
(
devkit_path
,
'VOC2012/ImageSets/Segmentation'
,
'aug.txt'
),
'w'
)
as
f
:
f
.
writelines
(
line
+
'
\n
'
for
line
in
aug_list
)
print
(
'Done!'
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/deploy_test.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os
import
os.path
as
osp
import
shutil
import
warnings
from
typing
import
Any
,
Iterable
import
mmcv
import
numpy
as
np
import
torch
from
mmcv.parallel
import
MMDataParallel
from
mmcv.runner
import
get_dist_info
from
mmcv.utils
import
DictAction
from
mmseg.apis
import
single_gpu_test
from
mmseg.datasets
import
build_dataloader
,
build_dataset
from
mmseg.models.segmentors.base
import
BaseSegmentor
from
mmseg.ops
import
resize
class
ONNXRuntimeSegmentor
(
BaseSegmentor
):
def
__init__
(
self
,
onnx_file
:
str
,
cfg
:
Any
,
device_id
:
int
):
super
(
ONNXRuntimeSegmentor
,
self
).
__init__
()
import
onnxruntime
as
ort
# get the custom op path
ort_custom_op_path
=
''
try
:
from
mmcv.ops
import
get_onnxruntime_op_path
ort_custom_op_path
=
get_onnxruntime_op_path
()
except
(
ImportError
,
ModuleNotFoundError
):
warnings
.
warn
(
'If input model has custom op from mmcv,
\
you may have to build mmcv with ONNXRuntime from source.'
)
session_options
=
ort
.
SessionOptions
()
# register custom op for onnxruntime
if
osp
.
exists
(
ort_custom_op_path
):
session_options
.
register_custom_ops_library
(
ort_custom_op_path
)
sess
=
ort
.
InferenceSession
(
onnx_file
,
session_options
)
providers
=
[
'CPUExecutionProvider'
]
options
=
[{}]
is_cuda_available
=
ort
.
get_device
()
==
'GPU'
if
is_cuda_available
:
providers
.
insert
(
0
,
'CUDAExecutionProvider'
)
options
.
insert
(
0
,
{
'device_id'
:
device_id
})
sess
.
set_providers
(
providers
,
options
)
self
.
sess
=
sess
self
.
device_id
=
device_id
self
.
io_binding
=
sess
.
io_binding
()
self
.
output_names
=
[
_
.
name
for
_
in
sess
.
get_outputs
()]
for
name
in
self
.
output_names
:
self
.
io_binding
.
bind_output
(
name
)
self
.
cfg
=
cfg
self
.
test_mode
=
cfg
.
model
.
test_cfg
.
mode
self
.
is_cuda_available
=
is_cuda_available
def
extract_feat
(
self
,
imgs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
encode_decode
(
self
,
img
,
img_metas
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
forward_train
(
self
,
imgs
,
img_metas
,
**
kwargs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
simple_test
(
self
,
img
:
torch
.
Tensor
,
img_meta
:
Iterable
,
**
kwargs
)
->
list
:
if
not
self
.
is_cuda_available
:
img
=
img
.
detach
().
cpu
()
elif
self
.
device_id
>=
0
:
img
=
img
.
cuda
(
self
.
device_id
)
device_type
=
img
.
device
.
type
self
.
io_binding
.
bind_input
(
name
=
'input'
,
device_type
=
device_type
,
device_id
=
self
.
device_id
,
element_type
=
np
.
float32
,
shape
=
img
.
shape
,
buffer_ptr
=
img
.
data_ptr
())
self
.
sess
.
run_with_iobinding
(
self
.
io_binding
)
seg_pred
=
self
.
io_binding
.
copy_outputs_to_cpu
()[
0
]
# whole might support dynamic reshape
ori_shape
=
img_meta
[
0
][
'ori_shape'
]
if
not
(
ori_shape
[
0
]
==
seg_pred
.
shape
[
-
2
]
and
ori_shape
[
1
]
==
seg_pred
.
shape
[
-
1
]):
seg_pred
=
torch
.
from_numpy
(
seg_pred
).
float
()
seg_pred
=
resize
(
seg_pred
,
size
=
tuple
(
ori_shape
[:
2
]),
mode
=
'nearest'
)
seg_pred
=
seg_pred
.
long
().
detach
().
cpu
().
numpy
()
seg_pred
=
seg_pred
[
0
]
seg_pred
=
list
(
seg_pred
)
return
seg_pred
def
aug_test
(
self
,
imgs
,
img_metas
,
**
kwargs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
class
TensorRTSegmentor
(
BaseSegmentor
):
def
__init__
(
self
,
trt_file
:
str
,
cfg
:
Any
,
device_id
:
int
):
super
(
TensorRTSegmentor
,
self
).
__init__
()
from
mmcv.tensorrt
import
TRTWraper
,
load_tensorrt_plugin
try
:
load_tensorrt_plugin
()
except
(
ImportError
,
ModuleNotFoundError
):
warnings
.
warn
(
'If input model has custom op from mmcv,
\
you may have to build mmcv with TensorRT from source.'
)
model
=
TRTWraper
(
trt_file
,
input_names
=
[
'input'
],
output_names
=
[
'output'
])
self
.
model
=
model
self
.
device_id
=
device_id
self
.
cfg
=
cfg
self
.
test_mode
=
cfg
.
model
.
test_cfg
.
mode
def
extract_feat
(
self
,
imgs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
encode_decode
(
self
,
img
,
img_metas
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
forward_train
(
self
,
imgs
,
img_metas
,
**
kwargs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
simple_test
(
self
,
img
:
torch
.
Tensor
,
img_meta
:
Iterable
,
**
kwargs
)
->
list
:
with
torch
.
cuda
.
device
(
self
.
device_id
),
torch
.
no_grad
():
seg_pred
=
self
.
model
({
'input'
:
img
})[
'output'
]
seg_pred
=
seg_pred
.
detach
().
cpu
().
numpy
()
# whole might support dynamic reshape
ori_shape
=
img_meta
[
0
][
'ori_shape'
]
if
not
(
ori_shape
[
0
]
==
seg_pred
.
shape
[
-
2
]
and
ori_shape
[
1
]
==
seg_pred
.
shape
[
-
1
]):
seg_pred
=
torch
.
from_numpy
(
seg_pred
).
float
()
seg_pred
=
resize
(
seg_pred
,
size
=
tuple
(
ori_shape
[:
2
]),
mode
=
'nearest'
)
seg_pred
=
seg_pred
.
long
().
detach
().
cpu
().
numpy
()
seg_pred
=
seg_pred
[
0
]
seg_pred
=
list
(
seg_pred
)
return
seg_pred
def
aug_test
(
self
,
imgs
,
img_metas
,
**
kwargs
):
raise
NotImplementedError
(
'This method is not implemented.'
)
def
parse_args
()
->
argparse
.
Namespace
:
parser
=
argparse
.
ArgumentParser
(
description
=
'mmseg backend test (and eval)'
)
parser
.
add_argument
(
'config'
,
help
=
'test config file path'
)
parser
.
add_argument
(
'model'
,
help
=
'Input model file'
)
parser
.
add_argument
(
'--backend'
,
help
=
'Backend of the model.'
,
choices
=
[
'onnxruntime'
,
'tensorrt'
])
parser
.
add_argument
(
'--out'
,
help
=
'output result file in pickle format'
)
parser
.
add_argument
(
'--format-only'
,
action
=
'store_true'
,
help
=
'Format the output results without perform evaluation. It is'
'useful when you want to format the result to a specific format and '
'submit it to the test server'
)
parser
.
add_argument
(
'--eval'
,
type
=
str
,
nargs
=
'+'
,
help
=
'evaluation metrics, which depends on the dataset, e.g., "mIoU"'
' for generic datasets, and "cityscapes" for Cityscapes'
)
parser
.
add_argument
(
'--show'
,
action
=
'store_true'
,
help
=
'show results'
)
parser
.
add_argument
(
'--show-dir'
,
help
=
'directory where painted images will be saved'
)
parser
.
add_argument
(
'--options'
,
nargs
=
'+'
,
action
=
DictAction
,
help
=
'custom options'
)
parser
.
add_argument
(
'--eval-options'
,
nargs
=
'+'
,
action
=
DictAction
,
help
=
'custom options for evaluation'
)
parser
.
add_argument
(
'--opacity'
,
type
=
float
,
default
=
0.5
,
help
=
'Opacity of painted segmentation map. In (0, 1] range.'
)
parser
.
add_argument
(
'--local_rank'
,
type
=
int
,
default
=
0
)
args
=
parser
.
parse_args
()
if
'LOCAL_RANK'
not
in
os
.
environ
:
os
.
environ
[
'LOCAL_RANK'
]
=
str
(
args
.
local_rank
)
return
args
def
main
():
args
=
parse_args
()
assert
args
.
out
or
args
.
eval
or
args
.
format_only
or
args
.
show
\
or
args
.
show_dir
,
\
(
'Please specify at least one operation (save/eval/format/show the '
'results / save the results) with the argument "--out", "--eval"'
', "--format-only", "--show" or "--show-dir"'
)
if
args
.
eval
and
args
.
format_only
:
raise
ValueError
(
'--eval and --format_only cannot be both specified'
)
if
args
.
out
is
not
None
and
not
args
.
out
.
endswith
((
'.pkl'
,
'.pickle'
)):
raise
ValueError
(
'The output file must be a pkl file.'
)
cfg
=
mmcv
.
Config
.
fromfile
(
args
.
config
)
if
args
.
options
is
not
None
:
cfg
.
merge_from_dict
(
args
.
options
)
cfg
.
model
.
pretrained
=
None
cfg
.
data
.
test
.
test_mode
=
True
# init distributed env first, since logger depends on the dist info.
distributed
=
False
# build the dataloader
# TODO: support multiple images per gpu (only minor changes are needed)
dataset
=
build_dataset
(
cfg
.
data
.
test
)
data_loader
=
build_dataloader
(
dataset
,
samples_per_gpu
=
1
,
workers_per_gpu
=
cfg
.
data
.
workers_per_gpu
,
dist
=
distributed
,
shuffle
=
False
)
# load onnx config and meta
cfg
.
model
.
train_cfg
=
None
if
args
.
backend
==
'onnxruntime'
:
model
=
ONNXRuntimeSegmentor
(
args
.
model
,
cfg
=
cfg
,
device_id
=
0
)
elif
args
.
backend
==
'tensorrt'
:
model
=
TensorRTSegmentor
(
args
.
model
,
cfg
=
cfg
,
device_id
=
0
)
model
.
CLASSES
=
dataset
.
CLASSES
model
.
PALETTE
=
dataset
.
PALETTE
# clean gpu memory when starting a new evaluation.
torch
.
cuda
.
empty_cache
()
eval_kwargs
=
{}
if
args
.
eval_options
is
None
else
args
.
eval_options
# Deprecated
efficient_test
=
eval_kwargs
.
get
(
'efficient_test'
,
False
)
if
efficient_test
:
warnings
.
warn
(
'``efficient_test=True`` does not have effect in tools/test.py, '
'the evaluation and format results are CPU memory efficient by '
'default'
)
eval_on_format_results
=
(
args
.
eval
is
not
None
and
'cityscapes'
in
args
.
eval
)
if
eval_on_format_results
:
assert
len
(
args
.
eval
)
==
1
,
'eval on format results is not '
\
'applicable for metrics other than '
\
'cityscapes'
if
args
.
format_only
or
eval_on_format_results
:
if
'imgfile_prefix'
in
eval_kwargs
:
tmpdir
=
eval_kwargs
[
'imgfile_prefix'
]
else
:
tmpdir
=
'.format_cityscapes'
eval_kwargs
.
setdefault
(
'imgfile_prefix'
,
tmpdir
)
mmcv
.
mkdir_or_exist
(
tmpdir
)
else
:
tmpdir
=
None
model
=
MMDataParallel
(
model
,
device_ids
=
[
0
])
results
=
single_gpu_test
(
model
,
data_loader
,
args
.
show
,
args
.
show_dir
,
False
,
args
.
opacity
,
pre_eval
=
args
.
eval
is
not
None
and
not
eval_on_format_results
,
format_only
=
args
.
format_only
or
eval_on_format_results
,
format_args
=
eval_kwargs
)
rank
,
_
=
get_dist_info
()
if
rank
==
0
:
if
args
.
out
:
warnings
.
warn
(
'The behavior of ``args.out`` has been changed since MMSeg '
'v0.16, the pickled outputs could be seg map as type of '
'np.array, pre-eval results or file paths for '
'``dataset.format_results()``.'
)
print
(
f
'
\n
writing results to
{
args
.
out
}
'
)
mmcv
.
dump
(
results
,
args
.
out
)
if
args
.
eval
:
dataset
.
evaluate
(
results
,
args
.
eval
,
**
eval_kwargs
)
if
tmpdir
is
not
None
and
eval_on_format_results
:
# remove tmp dir when cityscapes evaluation
shutil
.
rmtree
(
tmpdir
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/dist_test.sh
0 → 100644
View file @
c218d1c5
#!/usr/bin/env bash
CONFIG
=
$1
CHECKPOINT
=
$2
GPUS
=
$3
PORT
=
${
PORT
:-
29500
}
PYTHONPATH
=
"
$(
dirname
$0
)
/.."
:
$PYTHONPATH
\
NCCL_P2P_DISABLE
=
1
\
python
-m
torch.distributed.launch
--nproc_per_node
=
$GPUS
--master_port
=
$PORT
\
$(
dirname
"
$0
"
)
/test.py
$CONFIG
$CHECKPOINT
--launcher
pytorch
${
@
:4
}
segmentation/tools/dist_train.sh
0 → 100644
View file @
c218d1c5
#!/usr/bin/env bash
CONFIG
=
$1
GPUS
=
$2
NNODES
=
${
NNODES
:-
1
}
NODE_RANK
=
${
NODE_RANK
:-
0
}
PORT
=
${
PORT
:-
29500
}
MASTER_ADDR
=
${
MASTER_ADDR
:-
"127.0.0.1"
}
PYTHONPATH
=
"
$(
dirname
$0
)
/.."
:
$PYTHONPATH
\
NCCL_P2P_DISABLE
=
1
\
python
-m
torch.distributed.launch
\
--nnodes
=
$NNODES
\
--node_rank
=
$NODE_RANK
\
--master_addr
=
$MASTER_ADDR
\
--nproc_per_node
=
$GPUS
\
--master_port
=
$PORT
\
$(
dirname
"
$0
"
)
/train.py
\
$CONFIG
\
--launcher
pytorch
${
@
:3
}
segmentation/tools/get_flops.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
from
mmcv
import
Config
from
mmcv.cnn
import
get_model_complexity_info
from
mmseg.models
import
build_segmentor
import
sys
sys
.
path
.
append
(
".."
)
import
xformer
import
pvt
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a segmentor'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
parser
.
add_argument
(
'--shape'
,
type
=
int
,
nargs
=
'+'
,
default
=
[
2048
,
1024
],
help
=
'input image size'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
if
len
(
args
.
shape
)
==
1
:
input_shape
=
(
3
,
args
.
shape
[
0
],
args
.
shape
[
0
])
elif
len
(
args
.
shape
)
==
2
:
input_shape
=
(
3
,
)
+
tuple
(
args
.
shape
)
else
:
raise
ValueError
(
'invalid input shape'
)
cfg
=
Config
.
fromfile
(
args
.
config
)
cfg
.
model
.
pretrained
=
None
model
=
build_segmentor
(
cfg
.
model
,
train_cfg
=
cfg
.
get
(
'train_cfg'
),
test_cfg
=
cfg
.
get
(
'test_cfg'
)).
cuda
()
model
.
eval
()
if
hasattr
(
model
,
'forward_dummy'
):
model
.
forward
=
model
.
forward_dummy
else
:
raise
NotImplementedError
(
'FLOPs counter is currently not currently supported with {}'
.
format
(
model
.
__class__
.
__name__
))
flops
,
params
=
get_model_complexity_info
(
model
,
input_shape
)
split_line
=
'='
*
30
print
(
'{0}
\n
Input shape: {1}
\n
Flops: {2}
\n
Params: {3}
\n
{0}'
.
format
(
split_line
,
input_shape
,
flops
,
params
))
print
(
'!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.'
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/model_converters/mit2mmseg.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os.path
as
osp
from
collections
import
OrderedDict
import
mmcv
import
torch
from
mmcv.runner
import
CheckpointLoader
def
convert_mit
(
ckpt
):
new_ckpt
=
OrderedDict
()
# Process the concat between q linear weights and kv linear weights
for
k
,
v
in
ckpt
.
items
():
if
k
.
startswith
(
'head'
):
continue
# patch embedding conversion
elif
k
.
startswith
(
'patch_embed'
):
stage_i
=
int
(
k
.
split
(
'.'
)[
0
].
replace
(
'patch_embed'
,
''
))
new_k
=
k
.
replace
(
f
'patch_embed
{
stage_i
}
'
,
f
'layers.
{
stage_i
-
1
}
.0'
)
new_v
=
v
if
'proj.'
in
new_k
:
new_k
=
new_k
.
replace
(
'proj.'
,
'projection.'
)
# transformer encoder layer conversion
elif
k
.
startswith
(
'block'
):
stage_i
=
int
(
k
.
split
(
'.'
)[
0
].
replace
(
'block'
,
''
))
new_k
=
k
.
replace
(
f
'block
{
stage_i
}
'
,
f
'layers.
{
stage_i
-
1
}
.1'
)
new_v
=
v
if
'attn.q.'
in
new_k
:
sub_item_k
=
k
.
replace
(
'q.'
,
'kv.'
)
new_k
=
new_k
.
replace
(
'q.'
,
'attn.in_proj_'
)
new_v
=
torch
.
cat
([
v
,
ckpt
[
sub_item_k
]],
dim
=
0
)
elif
'attn.kv.'
in
new_k
:
continue
elif
'attn.proj.'
in
new_k
:
new_k
=
new_k
.
replace
(
'proj.'
,
'attn.out_proj.'
)
elif
'attn.sr.'
in
new_k
:
new_k
=
new_k
.
replace
(
'sr.'
,
'sr.'
)
elif
'mlp.'
in
new_k
:
string
=
f
'
{
new_k
}
-'
new_k
=
new_k
.
replace
(
'mlp.'
,
'ffn.layers.'
)
if
'fc1.weight'
in
new_k
or
'fc2.weight'
in
new_k
:
new_v
=
v
.
reshape
((
*
v
.
shape
,
1
,
1
))
new_k
=
new_k
.
replace
(
'fc1.'
,
'0.'
)
new_k
=
new_k
.
replace
(
'dwconv.dwconv.'
,
'1.'
)
new_k
=
new_k
.
replace
(
'fc2.'
,
'4.'
)
string
+=
f
'
{
new_k
}
{
v
.
shape
}
-
{
new_v
.
shape
}
'
# norm layer conversion
elif
k
.
startswith
(
'norm'
):
stage_i
=
int
(
k
.
split
(
'.'
)[
0
].
replace
(
'norm'
,
''
))
new_k
=
k
.
replace
(
f
'norm
{
stage_i
}
'
,
f
'layers.
{
stage_i
-
1
}
.2'
)
new_v
=
v
else
:
new_k
=
k
new_v
=
v
new_ckpt
[
new_k
]
=
new_v
return
new_ckpt
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert keys in official pretrained segformer to '
'MMSegmentation style.'
)
parser
.
add_argument
(
'src'
,
help
=
'src model path or url'
)
# The dst path must be a full path of the new checkpoint.
parser
.
add_argument
(
'dst'
,
help
=
'save path'
)
args
=
parser
.
parse_args
()
checkpoint
=
CheckpointLoader
.
load_checkpoint
(
args
.
src
,
map_location
=
'cpu'
)
if
'state_dict'
in
checkpoint
:
state_dict
=
checkpoint
[
'state_dict'
]
elif
'model'
in
checkpoint
:
state_dict
=
checkpoint
[
'model'
]
else
:
state_dict
=
checkpoint
weight
=
convert_mit
(
state_dict
)
mmcv
.
mkdir_or_exist
(
osp
.
dirname
(
args
.
dst
))
torch
.
save
(
weight
,
args
.
dst
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/model_converters/swin2mmseg.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os.path
as
osp
from
collections
import
OrderedDict
import
mmcv
import
torch
from
mmcv.runner
import
CheckpointLoader
def
convert_swin
(
ckpt
):
new_ckpt
=
OrderedDict
()
def
correct_unfold_reduction_order
(
x
):
out_channel
,
in_channel
=
x
.
shape
x
=
x
.
reshape
(
out_channel
,
4
,
in_channel
//
4
)
x
=
x
[:,
[
0
,
2
,
1
,
3
],
:].
transpose
(
1
,
2
).
reshape
(
out_channel
,
in_channel
)
return
x
def
correct_unfold_norm_order
(
x
):
in_channel
=
x
.
shape
[
0
]
x
=
x
.
reshape
(
4
,
in_channel
//
4
)
x
=
x
[[
0
,
2
,
1
,
3
],
:].
transpose
(
0
,
1
).
reshape
(
in_channel
)
return
x
for
k
,
v
in
ckpt
.
items
():
if
k
.
startswith
(
'head'
):
continue
elif
k
.
startswith
(
'layers'
):
new_v
=
v
if
'attn.'
in
k
:
new_k
=
k
.
replace
(
'attn.'
,
'attn.w_msa.'
)
elif
'mlp.'
in
k
:
if
'mlp.fc1.'
in
k
:
new_k
=
k
.
replace
(
'mlp.fc1.'
,
'ffn.layers.0.0.'
)
elif
'mlp.fc2.'
in
k
:
new_k
=
k
.
replace
(
'mlp.fc2.'
,
'ffn.layers.1.'
)
else
:
new_k
=
k
.
replace
(
'mlp.'
,
'ffn.'
)
elif
'downsample'
in
k
:
new_k
=
k
if
'reduction.'
in
k
:
new_v
=
correct_unfold_reduction_order
(
v
)
elif
'norm.'
in
k
:
new_v
=
correct_unfold_norm_order
(
v
)
else
:
new_k
=
k
new_k
=
new_k
.
replace
(
'layers'
,
'stages'
,
1
)
elif
k
.
startswith
(
'patch_embed'
):
new_v
=
v
if
'proj'
in
k
:
new_k
=
k
.
replace
(
'proj'
,
'projection'
)
else
:
new_k
=
k
else
:
new_v
=
v
new_k
=
k
new_ckpt
[
new_k
]
=
new_v
return
new_ckpt
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert keys in official pretrained swin models to'
'MMSegmentation style.'
)
parser
.
add_argument
(
'src'
,
help
=
'src model path or url'
)
# The dst path must be a full path of the new checkpoint.
parser
.
add_argument
(
'dst'
,
help
=
'save path'
)
args
=
parser
.
parse_args
()
checkpoint
=
CheckpointLoader
.
load_checkpoint
(
args
.
src
,
map_location
=
'cpu'
)
if
'state_dict'
in
checkpoint
:
state_dict
=
checkpoint
[
'state_dict'
]
elif
'model'
in
checkpoint
:
state_dict
=
checkpoint
[
'model'
]
else
:
state_dict
=
checkpoint
weight
=
convert_swin
(
state_dict
)
mmcv
.
mkdir_or_exist
(
osp
.
dirname
(
args
.
dst
))
torch
.
save
(
weight
,
args
.
dst
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/model_converters/vit2mmseg.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os.path
as
osp
from
collections
import
OrderedDict
import
mmcv
import
torch
from
mmcv.runner
import
CheckpointLoader
def
convert_vit
(
ckpt
):
new_ckpt
=
OrderedDict
()
for
k
,
v
in
ckpt
.
items
():
if
k
.
startswith
(
'head'
):
continue
if
k
.
startswith
(
'norm'
):
new_k
=
k
.
replace
(
'norm.'
,
'ln1.'
)
elif
k
.
startswith
(
'patch_embed'
):
if
'proj'
in
k
:
new_k
=
k
.
replace
(
'proj'
,
'projection'
)
else
:
new_k
=
k
elif
k
.
startswith
(
'blocks'
):
if
'norm'
in
k
:
new_k
=
k
.
replace
(
'norm'
,
'ln'
)
elif
'mlp.fc1'
in
k
:
new_k
=
k
.
replace
(
'mlp.fc1'
,
'ffn.layers.0.0'
)
elif
'mlp.fc2'
in
k
:
new_k
=
k
.
replace
(
'mlp.fc2'
,
'ffn.layers.1'
)
elif
'attn.qkv'
in
k
:
new_k
=
k
.
replace
(
'attn.qkv.'
,
'attn.attn.in_proj_'
)
elif
'attn.proj'
in
k
:
new_k
=
k
.
replace
(
'attn.proj'
,
'attn.attn.out_proj'
)
else
:
new_k
=
k
new_k
=
new_k
.
replace
(
'blocks.'
,
'layers.'
)
else
:
new_k
=
k
new_ckpt
[
new_k
]
=
v
return
new_ckpt
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert keys in timm pretrained vit models to '
'MMSegmentation style.'
)
parser
.
add_argument
(
'src'
,
help
=
'src model path or url'
)
# The dst path must be a full path of the new checkpoint.
parser
.
add_argument
(
'dst'
,
help
=
'save path'
)
args
=
parser
.
parse_args
()
checkpoint
=
CheckpointLoader
.
load_checkpoint
(
args
.
src
,
map_location
=
'cpu'
)
if
'state_dict'
in
checkpoint
:
# timm checkpoint
state_dict
=
checkpoint
[
'state_dict'
]
elif
'model'
in
checkpoint
:
# deit checkpoint
state_dict
=
checkpoint
[
'model'
]
else
:
state_dict
=
checkpoint
weight
=
convert_vit
(
state_dict
)
mmcv
.
mkdir_or_exist
(
osp
.
dirname
(
args
.
dst
))
torch
.
save
(
weight
,
args
.
dst
)
if
__name__
==
'__main__'
:
main
()
segmentation/tools/onnx2tensorrt.py
0 → 100644
View file @
c218d1c5
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os
import
os.path
as
osp
from
typing
import
Iterable
,
Optional
,
Union
import
matplotlib.pyplot
as
plt
import
mmcv
import
numpy
as
np
import
onnxruntime
as
ort
import
torch
from
mmcv.ops
import
get_onnxruntime_op_path
from
mmcv.tensorrt
import
(
TRTWraper
,
is_tensorrt_plugin_loaded
,
onnx2trt
,
save_trt_engine
)
from
mmseg.apis.inference
import
LoadImage
from
mmseg.datasets
import
DATASETS
from
mmseg.datasets.pipelines
import
Compose
def
get_GiB
(
x
:
int
):
"""return x GiB."""
return
x
*
(
1
<<
30
)
def
_prepare_input_img
(
img_path
:
str
,
test_pipeline
:
Iterable
[
dict
],
shape
:
Optional
[
Iterable
]
=
None
,
rescale_shape
:
Optional
[
Iterable
]
=
None
)
->
dict
:
# build the data pipeline
if
shape
is
not
None
:
test_pipeline
[
1
][
'img_scale'
]
=
(
shape
[
1
],
shape
[
0
])
test_pipeline
[
1
][
'transforms'
][
0
][
'keep_ratio'
]
=
False
test_pipeline
=
[
LoadImage
()]
+
test_pipeline
[
1
:]
test_pipeline
=
Compose
(
test_pipeline
)
# prepare data
data
=
dict
(
img
=
img_path
)
data
=
test_pipeline
(
data
)
imgs
=
data
[
'img'
]
img_metas
=
[
i
.
data
for
i
in
data
[
'img_metas'
]]
if
rescale_shape
is
not
None
:
for
img_meta
in
img_metas
:
img_meta
[
'ori_shape'
]
=
tuple
(
rescale_shape
)
+
(
3
,
)
mm_inputs
=
{
'imgs'
:
imgs
,
'img_metas'
:
img_metas
}
return
mm_inputs
def
_update_input_img
(
img_list
:
Iterable
,
img_meta_list
:
Iterable
):
# update img and its meta list
N
=
img_list
[
0
].
size
(
0
)
img_meta
=
img_meta_list
[
0
][
0
]
img_shape
=
img_meta
[
'img_shape'
]
ori_shape
=
img_meta
[
'ori_shape'
]
pad_shape
=
img_meta
[
'pad_shape'
]
new_img_meta_list
=
[[{
'img_shape'
:
img_shape
,
'ori_shape'
:
ori_shape
,
'pad_shape'
:
pad_shape
,
'filename'
:
img_meta
[
'filename'
],
'scale_factor'
:
(
img_shape
[
1
]
/
ori_shape
[
1
],
img_shape
[
0
]
/
ori_shape
[
0
])
*
2
,
'flip'
:
False
,
}
for
_
in
range
(
N
)]]
return
img_list
,
new_img_meta_list
def
show_result_pyplot
(
img
:
Union
[
str
,
np
.
ndarray
],
result
:
np
.
ndarray
,
palette
:
Optional
[
Iterable
]
=
None
,
fig_size
:
Iterable
[
int
]
=
(
15
,
10
),
opacity
:
float
=
0.5
,
title
:
str
=
''
,
block
:
bool
=
True
):
img
=
mmcv
.
imread
(
img
)
img
=
img
.
copy
()
seg
=
result
[
0
]
seg
=
mmcv
.
imresize
(
seg
,
img
.
shape
[:
2
][::
-
1
])
palette
=
np
.
array
(
palette
)
assert
palette
.
shape
[
1
]
==
3
assert
len
(
palette
.
shape
)
==
2
assert
0
<
opacity
<=
1.0
color_seg
=
np
.
zeros
((
seg
.
shape
[
0
],
seg
.
shape
[
1
],
3
),
dtype
=
np
.
uint8
)
for
label
,
color
in
enumerate
(
palette
):
color_seg
[
seg
==
label
,
:]
=
color
# convert to BGR
color_seg
=
color_seg
[...,
::
-
1
]
img
=
img
*
(
1
-
opacity
)
+
color_seg
*
opacity
img
=
img
.
astype
(
np
.
uint8
)
plt
.
figure
(
figsize
=
fig_size
)
plt
.
imshow
(
mmcv
.
bgr2rgb
(
img
))
plt
.
title
(
title
)
plt
.
tight_layout
()
plt
.
show
(
block
=
block
)
def
onnx2tensorrt
(
onnx_file
:
str
,
trt_file
:
str
,
config
:
dict
,
input_config
:
dict
,
fp16
:
bool
=
False
,
verify
:
bool
=
False
,
show
:
bool
=
False
,
dataset
:
str
=
'CityscapesDataset'
,
workspace_size
:
int
=
1
,
verbose
:
bool
=
False
):
import
tensorrt
as
trt
min_shape
=
input_config
[
'min_shape'
]
max_shape
=
input_config
[
'max_shape'
]
# create trt engine and wrapper
opt_shape_dict
=
{
'input'
:
[
min_shape
,
min_shape
,
max_shape
]}
max_workspace_size
=
get_GiB
(
workspace_size
)
trt_engine
=
onnx2trt
(
onnx_file
,
opt_shape_dict
,
log_level
=
trt
.
Logger
.
VERBOSE
if
verbose
else
trt
.
Logger
.
ERROR
,
fp16_mode
=
fp16
,
max_workspace_size
=
max_workspace_size
)
save_dir
,
_
=
osp
.
split
(
trt_file
)
if
save_dir
:
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
save_trt_engine
(
trt_engine
,
trt_file
)
print
(
f
'Successfully created TensorRT engine:
{
trt_file
}
'
)
if
verify
:
inputs
=
_prepare_input_img
(
input_config
[
'input_path'
],
config
.
data
.
test
.
pipeline
,
shape
=
min_shape
[
2
:])
imgs
=
inputs
[
'imgs'
]
img_metas
=
inputs
[
'img_metas'
]
img_list
=
[
img
[
None
,
:]
for
img
in
imgs
]
img_meta_list
=
[[
img_meta
]
for
img_meta
in
img_metas
]
# update img_meta
img_list
,
img_meta_list
=
_update_input_img
(
img_list
,
img_meta_list
)
if
max_shape
[
0
]
>
1
:
# concate flip image for batch test
flip_img_list
=
[
_
.
flip
(
-
1
)
for
_
in
img_list
]
img_list
=
[
torch
.
cat
((
ori_img
,
flip_img
),
0
)
for
ori_img
,
flip_img
in
zip
(
img_list
,
flip_img_list
)
]
# Get results from ONNXRuntime
ort_custom_op_path
=
get_onnxruntime_op_path
()
session_options
=
ort
.
SessionOptions
()
if
osp
.
exists
(
ort_custom_op_path
):
session_options
.
register_custom_ops_library
(
ort_custom_op_path
)
sess
=
ort
.
InferenceSession
(
onnx_file
,
session_options
)
sess
.
set_providers
([
'CPUExecutionProvider'
],
[{}])
# use cpu mode
onnx_output
=
sess
.
run
([
'output'
],
{
'input'
:
img_list
[
0
].
detach
().
numpy
()})[
0
][
0
]
# Get results from TensorRT
trt_model
=
TRTWraper
(
trt_file
,
[
'input'
],
[
'output'
])
with
torch
.
no_grad
():
trt_outputs
=
trt_model
({
'input'
:
img_list
[
0
].
contiguous
().
cuda
()})
trt_output
=
trt_outputs
[
'output'
][
0
].
cpu
().
detach
().
numpy
()
if
show
:
dataset
=
DATASETS
.
get
(
dataset
)
assert
dataset
is
not
None
palette
=
dataset
.
PALETTE
show_result_pyplot
(
input_config
[
'input_path'
],
(
onnx_output
[
0
].
astype
(
np
.
uint8
),
),
palette
=
palette
,
title
=
'ONNXRuntime'
,
block
=
False
)
show_result_pyplot
(
input_config
[
'input_path'
],
(
trt_output
[
0
].
astype
(
np
.
uint8
),
),
palette
=
palette
,
title
=
'TensorRT'
)
np
.
testing
.
assert_allclose
(
onnx_output
,
trt_output
,
rtol
=
1e-03
,
atol
=
1e-05
)
print
(
'TensorRT and ONNXRuntime output all close.'
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Convert MMSegmentation models from ONNX to TensorRT'
)
parser
.
add_argument
(
'config'
,
help
=
'Config file of the model'
)
parser
.
add_argument
(
'model'
,
help
=
'Path to the input ONNX model'
)
parser
.
add_argument
(
'--trt-file'
,
type
=
str
,
help
=
'Path to the output TensorRT engine'
)
parser
.
add_argument
(
'--max-shape'
,
type
=
int
,
nargs
=
4
,
default
=
[
1
,
3
,
400
,
600
],
help
=
'Maximum shape of model input.'
)
parser
.
add_argument
(
'--min-shape'
,
type
=
int
,
nargs
=
4
,
default
=
[
1
,
3
,
400
,
600
],
help
=
'Minimum shape of model input.'
)
parser
.
add_argument
(
'--fp16'
,
action
=
'store_true'
,
help
=
'Enable fp16 mode'
)
parser
.
add_argument
(
'--workspace-size'
,
type
=
int
,
default
=
1
,
help
=
'Max workspace size in GiB'
)
parser
.
add_argument
(
'--input-img'
,
type
=
str
,
default
=
''
,
help
=
'Image for test'
)
parser
.
add_argument
(
'--show'
,
action
=
'store_true'
,
help
=
'Whether to show output results'
)
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
default
=
'CityscapesDataset'
,
help
=
'Dataset name'
)
parser
.
add_argument
(
'--verify'
,
action
=
'store_true'
,
help
=
'Verify the outputs of ONNXRuntime and TensorRT'
)
parser
.
add_argument
(
'--verbose'
,
action
=
'store_true'
,
help
=
'Whether to verbose logging messages while creating
\
TensorRT engine.'
)
args
=
parser
.
parse_args
()
return
args
if
__name__
==
'__main__'
:
assert
is_tensorrt_plugin_loaded
(),
'TensorRT plugin should be compiled.'
args
=
parse_args
()
if
not
args
.
input_img
:
args
.
input_img
=
osp
.
join
(
osp
.
dirname
(
__file__
),
'../demo/demo.png'
)
# check arguments
assert
osp
.
exists
(
args
.
config
),
'Config {} not found.'
.
format
(
args
.
config
)
assert
osp
.
exists
(
args
.
model
),
\
'ONNX model {} not found.'
.
format
(
args
.
model
)
assert
args
.
workspace_size
>=
0
,
'Workspace size less than 0.'
assert
DATASETS
.
get
(
args
.
dataset
)
is
not
None
,
\
'Dataset {} does not found.'
.
format
(
args
.
dataset
)
for
max_value
,
min_value
in
zip
(
args
.
max_shape
,
args
.
min_shape
):
assert
max_value
>=
min_value
,
\
'max_shape should be larger than min shape'
input_config
=
{
'min_shape'
:
args
.
min_shape
,
'max_shape'
:
args
.
max_shape
,
'input_path'
:
args
.
input_img
}
cfg
=
mmcv
.
Config
.
fromfile
(
args
.
config
)
onnx2tensorrt
(
args
.
model
,
args
.
trt_file
,
cfg
,
input_config
,
fp16
=
args
.
fp16
,
verify
=
args
.
verify
,
show
=
args
.
show
,
dataset
=
args
.
dataset
,
workspace_size
=
args
.
workspace_size
,
verbose
=
args
.
verbose
)
Prev
1
…
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment