Commit c218d1c5 authored by chenzk

v1.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/open-
mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
import argparse
import json
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
def plot_curve(log_dicts, args):
if args.backend is not None:
plt.switch_backend(args.backend)
sns.set_style(args.style)
# if legend is None, use {filename}_{key} as legend
legend = args.legend
if legend is None:
legend = []
for json_log in args.json_logs:
for metric in args.keys:
legend.append(f'{json_log}_{metric}')
assert len(legend) == (len(args.json_logs) * len(args.keys))
metrics = args.keys
num_metrics = len(metrics)
for i, log_dict in enumerate(log_dicts):
epochs = list(log_dict.keys())
for j, metric in enumerate(metrics):
print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
plot_epochs = []
plot_iters = []
plot_values = []
# In some log files, the iter number is not correct, so `pre_iter` is
# used to prevent generating wrong lines.
pre_iter = -1
for epoch in epochs:
epoch_logs = log_dict[epoch]
if metric not in epoch_logs.keys():
continue
if metric in ['mIoU', 'mAcc', 'aAcc']:
plot_epochs.append(epoch)
plot_values.append(epoch_logs[metric][0])
else:
for idx in range(len(epoch_logs[metric])):
if pre_iter > epoch_logs['iter'][idx]:
continue
pre_iter = epoch_logs['iter'][idx]
plot_iters.append(epoch_logs['iter'][idx])
plot_values.append(epoch_logs[metric][idx])
ax = plt.gca()
label = legend[i * num_metrics + j]
if metric in ['mIoU', 'mAcc', 'aAcc']:
ax.set_xticks(plot_epochs)
plt.xlabel('epoch')
plt.plot(plot_epochs, plot_values, label=label, marker='o')
else:
plt.xlabel('iter')
plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)
plt.legend()
if args.title is not None:
plt.title(args.title)
if args.out is None:
plt.show()
else:
print(f'save curve to: {args.out}')
plt.savefig(args.out)
plt.cla()
def parse_args():
parser = argparse.ArgumentParser(description='Analyze Json Log')
parser.add_argument(
'json_logs',
type=str,
nargs='+',
help='path of train log in json format')
parser.add_argument(
'--keys',
type=str,
nargs='+',
default=['mIoU'],
help='the metric that you want to plot')
parser.add_argument('--title', type=str, help='title of figure')
parser.add_argument(
'--legend',
type=str,
nargs='+',
default=None,
help='legend of each plot')
parser.add_argument(
'--backend', type=str, default=None, help='backend of plt')
parser.add_argument(
'--style', type=str, default='dark', help='style of plt')
parser.add_argument('--out', type=str, default=None)
args = parser.parse_args()
return args
def load_json_logs(json_logs):
# Load and convert json_logs to log_dict; the key is the epoch and the value
# is a sub dict. The keys of the sub dict are the different metrics, and each
# value is a list of the corresponding values over all iterations.
log_dicts = [dict() for _ in json_logs]
for json_log, log_dict in zip(json_logs, log_dicts):
with open(json_log, 'r') as log_file:
for line in log_file:
log = json.loads(line.strip())
# skip lines without `epoch` field
if 'epoch' not in log:
continue
epoch = log.pop('epoch')
if epoch not in log_dict:
log_dict[epoch] = defaultdict(list)
for k, v in log.items():
log_dict[epoch][k].append(v)
return log_dicts
def main():
args = parse_args()
json_logs = args.json_logs
for json_log in json_logs:
assert json_log.endswith('.json')
log_dicts = load_json_logs(json_logs)
plot_curve(log_dicts, args)
if __name__ == '__main__':
main()
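# A hypothetical invocation of this log-analysis tool (the script path is an
# assumption; the flags come from parse_args above):
#   python tools/analyze_logs.py work_dirs/run/20210101.log.json --keys mIoU mAcc --out curves.png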
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import time
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint, wrap_fp16_model
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor
def parse_args():
parser = argparse.ArgumentParser(description='MMSeg benchmark a model')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--log-interval', type=int, default=50, help='interval of logging')
args = parser.parse_args()
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
torch.backends.cudnn.benchmark = False
cfg.model.pretrained = None
cfg.data.test.test_mode = True
# build the dataloader
# TODO: support multiple images per gpu (only minor changes are needed)
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=False,
shuffle=False)
# build the model and load checkpoint
cfg.model.train_cfg = None
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
load_checkpoint(model, args.checkpoint, map_location='cpu')
model = MMDataParallel(model, device_ids=[0])
model.eval()
# the first several iterations may be very slow so skip them
num_warmup = 5
pure_inf_time = 0
total_iters = 200
# benchmark with 200 images and take the average
for i, data in enumerate(data_loader):
torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
model(return_loss=False, rescale=True, **data)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start_time
if i >= num_warmup:
pure_inf_time += elapsed
if (i + 1) % args.log_interval == 0:
fps = (i + 1 - num_warmup) / pure_inf_time
print(f'Done image [{i + 1:<3}/ {total_iters}], '
f'fps: {fps:.2f} img / s')
if (i + 1) == total_iters:
fps = (i + 1 - num_warmup) / pure_inf_time
print(f'Overall fps: {fps:.2f} img / s')
break
if __name__ == '__main__':
main()
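# Example run of this benchmark script (paths are placeholders):
#   python tools/benchmark.py configs/my_config.py work_dirs/latest.pth --log-interval 50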
import argparse
import os
import warnings
from pathlib import Path
import mmcv
import numpy as np
from mmcv import Config
from mmseg.datasets.builder import build_dataset
def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--show-origin',
default=False,
action='store_true',
help='if True, omit all augmentations in the pipeline and '
'show the original image and seg map')
parser.add_argument(
'--skip-type',
type=str,
nargs='+',
default=['DefaultFormatBundle', 'Normalize', 'Collect'],
help='skip some useless pipeline steps; if `show-origin` is true, '
'all pipeline steps except `Load` will be skipped')
parser.add_argument(
'--output-dir',
default='./output',
type=str,
help='directory to save images if there is no display interface')
parser.add_argument('--show', default=False, action='store_true')
parser.add_argument(
'--show-interval',
type=int,
default=999,
help='the interval of show (ms)')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='the opacity of semantic map')
args = parser.parse_args()
return args
def imshow_semantic(img,
seg,
class_names,
palette=None,
win_name='',
show=False,
wait_time=0,
out_file=None,
opacity=0.5):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
seg (Tensor): The semantic segmentation results to draw over
`img`.
class_names (list[str]): Names of each class.
palette (list[list[int]] | np.ndarray | None): The palette of
segmentation map. If None is given, random palette will be
generated. Default: None
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
opacity (float): Opacity of the painted segmentation map.
Must be in (0, 1] range. Default: 0.5.
Returns:
img (Tensor): The drawn image, returned only if neither `show`
nor `out_file` is set.
"""
img = mmcv.imread(img)
img = img.copy()
if palette is None:
palette = np.random.randint(0, 255, size=(len(class_names), 3))
palette = np.array(palette)
assert palette.shape[0] == len(class_names)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]
img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
if not (show or out_file):
warnings.warn('show==False and out_file is not specified; only the '
'result image will be returned')
return img
def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
if show_origin is True:
# only keep pipeline of Loading data and ann
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if 'Load' in x['type']
]
else:
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if x['type'] not in skip_type
]
def retrieve_data_cfg(config_path, skip_type, show_origin=False):
cfg = Config.fromfile(config_path)
train_data_cfg = cfg.data.train
if isinstance(train_data_cfg, list):
for _data_cfg in train_data_cfg:
if 'pipeline' in _data_cfg:
_retrieve_data_cfg(_data_cfg, skip_type, show_origin)
elif 'dataset' in _data_cfg:
_retrieve_data_cfg(_data_cfg['dataset'], skip_type,
show_origin)
else:
raise ValueError('the data config must contain either `pipeline` or `dataset`')
elif 'dataset' in train_data_cfg:
_retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
else:
_retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
return cfg
def main():
args = parse_args()
cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
dataset = build_dataset(cfg.data.train)
progress_bar = mmcv.ProgressBar(len(dataset))
for item in dataset:
filename = os.path.join(args.output_dir,
Path(item['filename']).name
) if args.output_dir is not None else None
imshow_semantic(
item['img'],
item['gt_semantic_seg'],
dataset.CLASSES,
dataset.PALETTE,
show=args.show,
wait_time=args.show_interval,
out_file=filename,
opacity=args.opacity,
)
progress_bar.update()
if __name__ == '__main__':
main()
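# A possible invocation of this dataset browser (placeholder paths):
#   python tools/browse_dataset.py configs/my_config.py --show-origin --output-dir ./output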
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile
import mmcv
CHASE_DB1_LEN = 28 * 3
TRAINING_LEN = 60
def parse_args():
parser = argparse.ArgumentParser(
description='Convert CHASE_DB1 dataset to mmsegmentation format')
parser.add_argument('dataset_path', help='path of CHASEDB1.zip')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
dataset_path = args.dataset_path
if args.out_dir is None:
out_dir = osp.join('data', 'CHASE_DB1')
else:
out_dir = args.out_dir
print('Making directories...')
mmcv.mkdir_or_exist(out_dir)
mmcv.mkdir_or_exist(osp.join(out_dir, 'images'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
print('Extracting CHASEDB1.zip...')
zip_file = zipfile.ZipFile(dataset_path)
zip_file.extractall(tmp_dir)
print('Generating training dataset...')
assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \
'len(os.listdir(tmp_dir)) != {}'.format(CHASE_DB1_LEN)
for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(tmp_dir, img_name))
if osp.splitext(img_name)[1] == '.jpg':
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'training',
osp.splitext(img_name)[0] + '.png'))
else:
# The annotation images should be divided by 128 because some of
# them are not standard (values other than 0 and 255 appear). We
# therefore apply a threshold to convert the nonstandard annotation
# images; dividing by 128 is equivalent to '1 if value >= 128
# else 0'.
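# For example, an annotation pixel of 200 becomes 200 // 128 == 1
# (foreground) and a pixel of 60 becomes 60 // 128 == 0 (background).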
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'training',
osp.splitext(img_name)[0] + '.png'))
for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(tmp_dir, img_name))
if osp.splitext(img_name)[1] == '.jpg':
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'validation',
osp.splitext(img_name)[0] + '.png'))
else:
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(img_name)[0] + '.png'))
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()
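# Example run of this converter (paths are placeholders):
#   python tools/convert_datasets/chase_db1.py /path/to/CHASEDB1.zip -o data/CHASE_DB1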
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import mmcv
from cityscapesscripts.preparation.json2labelImg import json2labelImg
def convert_json_to_label(json_file):
label_file = json_file.replace('_polygons.json', '_labelTrainIds.png')
json2labelImg(json_file, label_file, 'trainIds')
def parse_args():
parser = argparse.ArgumentParser(
description='Convert Cityscapes annotations to TrainIds')
parser.add_argument('cityscapes_path', help='cityscapes data path')
parser.add_argument('--gt-dir', default='gtFine', type=str)
parser.add_argument('-o', '--out-dir', help='output path')
parser.add_argument(
'--nproc', default=1, type=int, help='number of processes')
args = parser.parse_args()
return args
def main():
args = parse_args()
cityscapes_path = args.cityscapes_path
out_dir = args.out_dir if args.out_dir else cityscapes_path
mmcv.mkdir_or_exist(out_dir)
gt_dir = osp.join(cityscapes_path, args.gt_dir)
poly_files = []
for poly in mmcv.scandir(gt_dir, '_polygons.json', recursive=True):
poly_file = osp.join(gt_dir, poly)
poly_files.append(poly_file)
if args.nproc > 1:
mmcv.track_parallel_progress(convert_json_to_label, poly_files,
args.nproc)
else:
mmcv.track_progress(convert_json_to_label, poly_files)
split_names = ['train', 'val', 'test']
for split in split_names:
filenames = []
for poly in mmcv.scandir(
osp.join(gt_dir, split), '_polygons.json', recursive=True):
filenames.append(poly.replace('_gtFine_polygons.json', ''))
with open(osp.join(out_dir, f'{split}.txt'), 'w') as f:
f.writelines(name + '\n' for name in filenames)
if __name__ == '__main__':
main()
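# Example run of this converter (paths are placeholders):
#   python tools/convert_datasets/cityscapes.py data/cityscapes --nproc 8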
import argparse
import os.path as osp
import shutil
from functools import partial
import mmcv
import numpy as np
from PIL import Image
from scipy.io import loadmat
COCO_LEN = 10000
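# The mapping below collapses the non-contiguous COCO-Stuff class ids
# (e.g. 12 and 26 are unused) into contiguous train ids starting from 0.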
clsID_to_trID = {
0: 0,
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
11: 11,
13: 12,
14: 13,
15: 14,
16: 15,
17: 16,
18: 17,
19: 18,
20: 19,
21: 20,
22: 21,
23: 22,
24: 23,
25: 24,
27: 25,
28: 26,
31: 27,
32: 28,
33: 29,
34: 30,
35: 31,
36: 32,
37: 33,
38: 34,
39: 35,
40: 36,
41: 37,
42: 38,
43: 39,
44: 40,
46: 41,
47: 42,
48: 43,
49: 44,
50: 45,
51: 46,
52: 47,
53: 48,
54: 49,
55: 50,
56: 51,
57: 52,
58: 53,
59: 54,
60: 55,
61: 56,
62: 57,
63: 58,
64: 59,
65: 60,
67: 61,
70: 62,
72: 63,
73: 64,
74: 65,
75: 66,
76: 67,
77: 68,
78: 69,
79: 70,
80: 71,
81: 72,
82: 73,
84: 74,
85: 75,
86: 76,
87: 77,
88: 78,
89: 79,
90: 80,
92: 81,
93: 82,
94: 83,
95: 84,
96: 85,
97: 86,
98: 87,
99: 88,
100: 89,
101: 90,
102: 91,
103: 92,
104: 93,
105: 94,
106: 95,
107: 96,
108: 97,
109: 98,
110: 99,
111: 100,
112: 101,
113: 102,
114: 103,
115: 104,
116: 105,
117: 106,
118: 107,
119: 108,
120: 109,
121: 110,
122: 111,
123: 112,
124: 113,
125: 114,
126: 115,
127: 116,
128: 117,
129: 118,
130: 119,
131: 120,
132: 121,
133: 122,
134: 123,
135: 124,
136: 125,
137: 126,
138: 127,
139: 128,
140: 129,
141: 130,
142: 131,
143: 132,
144: 133,
145: 134,
146: 135,
147: 136,
148: 137,
149: 138,
150: 139,
151: 140,
152: 141,
153: 142,
154: 143,
155: 144,
156: 145,
157: 146,
158: 147,
159: 148,
160: 149,
161: 150,
162: 151,
163: 152,
164: 153,
165: 154,
166: 155,
167: 156,
168: 157,
169: 158,
170: 159,
171: 160,
172: 161,
173: 162,
174: 163,
175: 164,
176: 165,
177: 166,
178: 167,
179: 168,
180: 169,
181: 170,
182: 171
}
def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir,
out_mask_dir, is_train):
imgpath, maskpath = tuple_path
shutil.copyfile(
osp.join(in_img_dir, imgpath),
osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join(
out_img_dir, 'test2014', imgpath))
annotate = loadmat(osp.join(in_ann_dir, maskpath))
mask = annotate['S'].astype(np.uint8)
mask_copy = mask.copy()
for clsID, trID in clsID_to_trID.items():
mask_copy[mask == clsID] = trID
seg_filename = osp.join(out_mask_dir, 'train2014',
maskpath.split('.')[0] +
'_labelTrainIds.png') if is_train else osp.join(
out_mask_dir, 'test2014',
maskpath.split('.')[0] + '_labelTrainIds.png')
Image.fromarray(mask_copy).save(seg_filename, 'PNG')
def generate_coco_list(folder):
train_list = osp.join(folder, 'imageLists', 'train.txt')
test_list = osp.join(folder, 'imageLists', 'test.txt')
train_paths = []
test_paths = []
with open(train_list) as f:
for filename in f:
basename = filename.strip()
imgpath = basename + '.jpg'
maskpath = basename + '.mat'
train_paths.append((imgpath, maskpath))
with open(test_list) as f:
for filename in f:
basename = filename.strip()
imgpath = basename + '.jpg'
maskpath = basename + '.mat'
test_paths.append((imgpath, maskpath))
return train_paths, test_paths
def parse_args():
parser = argparse.ArgumentParser(
description=\
'Convert COCO Stuff 10k annotations to mmsegmentation format') # noqa
parser.add_argument('coco_path', help='coco stuff path')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--nproc', default=16, type=int, help='number of processes')
args = parser.parse_args()
return args
def main():
args = parse_args()
coco_path = args.coco_path
nproc = args.nproc
out_dir = args.out_dir or coco_path
out_img_dir = osp.join(out_dir, 'images')
out_mask_dir = osp.join(out_dir, 'annotations')
mmcv.mkdir_or_exist(osp.join(out_img_dir, 'train2014'))
mmcv.mkdir_or_exist(osp.join(out_img_dir, 'test2014'))
mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2014'))
mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'test2014'))
train_list, test_list = generate_coco_list(coco_path)
assert (len(train_list) +
len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
len(train_list), len(test_list))
if args.nproc > 1:
mmcv.track_parallel_progress(
partial(
convert_to_trainID,
in_img_dir=osp.join(coco_path, 'images'),
in_ann_dir=osp.join(coco_path, 'annotations'),
out_img_dir=out_img_dir,
out_mask_dir=out_mask_dir,
is_train=True),
train_list,
nproc=nproc)
mmcv.track_parallel_progress(
partial(
convert_to_trainID,
in_img_dir=osp.join(coco_path, 'images'),
in_ann_dir=osp.join(coco_path, 'annotations'),
out_img_dir=out_img_dir,
out_mask_dir=out_mask_dir,
is_train=False),
test_list,
nproc=nproc)
else:
mmcv.track_progress(
partial(
convert_to_trainID,
in_img_dir=osp.join(coco_path, 'images'),
in_ann_dir=osp.join(coco_path, 'annotations'),
out_img_dir=out_img_dir,
out_mask_dir=out_mask_dir,
is_train=True), train_list)
mmcv.track_progress(
partial(
convert_to_trainID,
in_img_dir=osp.join(coco_path, 'images'),
in_ann_dir=osp.join(coco_path, 'annotations'),
out_img_dir=out_img_dir,
out_mask_dir=out_mask_dir,
is_train=False), test_list)
print('Done!')
if __name__ == '__main__':
main()
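# Example run of this converter (paths are placeholders):
#   python tools/convert_datasets/coco_stuff10k.py /path/to/coco_stuff10k --nproc 8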
import argparse
import os.path as osp
import shutil
from functools import partial
from glob import glob
import mmcv
import numpy as np
from PIL import Image
COCO_LEN = 123287
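# As in the 10k script above, this maps the non-contiguous class ids of
# COCO-Stuff 164k to contiguous train ids; 255 (the ignore label) is kept.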
clsID_to_trID = {
0: 0,
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
12: 11,
13: 12,
14: 13,
15: 14,
16: 15,
17: 16,
18: 17,
19: 18,
20: 19,
21: 20,
22: 21,
23: 22,
24: 23,
26: 24,
27: 25,
30: 26,
31: 27,
32: 28,
33: 29,
34: 30,
35: 31,
36: 32,
37: 33,
38: 34,
39: 35,
40: 36,
41: 37,
42: 38,
43: 39,
45: 40,
46: 41,
47: 42,
48: 43,
49: 44,
50: 45,
51: 46,
52: 47,
53: 48,
54: 49,
55: 50,
56: 51,
57: 52,
58: 53,
59: 54,
60: 55,
61: 56,
62: 57,
63: 58,
64: 59,
66: 60,
69: 61,
71: 62,
72: 63,
73: 64,
74: 65,
75: 66,
76: 67,
77: 68,
78: 69,
79: 70,
80: 71,
81: 72,
83: 73,
84: 74,
85: 75,
86: 76,
87: 77,
88: 78,
89: 79,
91: 80,
92: 81,
93: 82,
94: 83,
95: 84,
96: 85,
97: 86,
98: 87,
99: 88,
100: 89,
101: 90,
102: 91,
103: 92,
104: 93,
105: 94,
106: 95,
107: 96,
108: 97,
109: 98,
110: 99,
111: 100,
112: 101,
113: 102,
114: 103,
115: 104,
116: 105,
117: 106,
118: 107,
119: 108,
120: 109,
121: 110,
122: 111,
123: 112,
124: 113,
125: 114,
126: 115,
127: 116,
128: 117,
129: 118,
130: 119,
131: 120,
132: 121,
133: 122,
134: 123,
135: 124,
136: 125,
137: 126,
138: 127,
139: 128,
140: 129,
141: 130,
142: 131,
143: 132,
144: 133,
145: 134,
146: 135,
147: 136,
148: 137,
149: 138,
150: 139,
151: 140,
152: 141,
153: 142,
154: 143,
155: 144,
156: 145,
157: 146,
158: 147,
159: 148,
160: 149,
161: 150,
162: 151,
163: 152,
164: 153,
165: 154,
166: 155,
167: 156,
168: 157,
169: 158,
170: 159,
171: 160,
172: 161,
173: 162,
174: 163,
175: 164,
176: 165,
177: 166,
178: 167,
179: 168,
180: 169,
181: 170,
255: 255
}
def convert_to_trainID(maskpath, out_mask_dir, is_train):
mask = np.array(Image.open(maskpath))
mask_copy = mask.copy()
for clsID, trID in clsID_to_trID.items():
mask_copy[mask == clsID] = trID
seg_filename = osp.join(
out_mask_dir, 'train2017',
osp.basename(maskpath).split('.')[0] +
'_labelTrainIds.png') if is_train else osp.join(
out_mask_dir, 'val2017',
osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png')
Image.fromarray(mask_copy).save(seg_filename, 'PNG')
def parse_args():
parser = argparse.ArgumentParser(
description=\
'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa
parser.add_argument('coco_path', help='coco stuff path')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--nproc', default=16, type=int, help='number of processes')
args = parser.parse_args()
return args
def main():
args = parse_args()
coco_path = args.coco_path
nproc = args.nproc
out_dir = args.out_dir or coco_path
out_img_dir = osp.join(out_dir, 'images')
out_mask_dir = osp.join(out_dir, 'annotations')
mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))
if out_dir != coco_path:
shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)
train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
train_list = [file for file in train_list if '_labelTrainIds' not in file]
test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
test_list = [file for file in test_list if '_labelTrainIds' not in file]
assert (len(train_list) +
len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
len(train_list), len(test_list))
if args.nproc > 1:
mmcv.track_parallel_progress(
partial(
convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
train_list,
nproc=nproc)
mmcv.track_parallel_progress(
partial(
convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
test_list,
nproc=nproc)
else:
mmcv.track_progress(
partial(
convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
train_list)
mmcv.track_progress(
partial(
convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
test_list)
print('Done!')
if __name__ == '__main__':
main()
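# Example run of this converter (paths are placeholders):
#   python tools/convert_datasets/coco_stuff164k.py /path/to/coco_stuff164k --nproc 8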
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile
import cv2
import mmcv
def parse_args():
parser = argparse.ArgumentParser(
description='Convert DRIVE dataset to mmsegmentation format')
parser.add_argument(
'training_path', help='the training part of DRIVE dataset')
parser.add_argument(
'testing_path', help='the testing part of DRIVE dataset')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
training_path = args.training_path
testing_path = args.testing_path
if args.out_dir is None:
out_dir = osp.join('data', 'DRIVE')
else:
out_dir = args.out_dir
print('Making directories...')
mmcv.mkdir_or_exist(out_dir)
mmcv.mkdir_or_exist(osp.join(out_dir, 'images'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
print('Extracting training.zip...')
zip_file = zipfile.ZipFile(training_path)
zip_file.extractall(tmp_dir)
print('Generating training dataset...')
now_dir = osp.join(tmp_dir, 'training', 'images')
for img_name in os.listdir(now_dir):
img = mmcv.imread(osp.join(now_dir, img_name))
mmcv.imwrite(
img,
osp.join(
out_dir, 'images', 'training',
osp.splitext(img_name)[0].replace('_training', '') +
'.png'))
now_dir = osp.join(tmp_dir, 'training', '1st_manual')
for img_name in os.listdir(now_dir):
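# The DRIVE manual annotations are GIF files, which mmcv.imread cannot
# decode; reading the first frame through cv2.VideoCapture serves as a
# workaround (our reading of the code, not documented in the source).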
cap = cv2.VideoCapture(osp.join(now_dir, img_name))
ret, img = cap.read()
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'training',
osp.splitext(img_name)[0] + '.png'))
print('Extracting test.zip...')
zip_file = zipfile.ZipFile(testing_path)
zip_file.extractall(tmp_dir)
print('Generating validation dataset...')
now_dir = osp.join(tmp_dir, 'test', 'images')
for img_name in os.listdir(now_dir):
img = mmcv.imread(osp.join(now_dir, img_name))
mmcv.imwrite(
img,
osp.join(
out_dir, 'images', 'validation',
osp.splitext(img_name)[0].replace('_test', '') + '.png'))
now_dir = osp.join(tmp_dir, 'test', '1st_manual')
if osp.exists(now_dir):
for img_name in os.listdir(now_dir):
cap = cv2.VideoCapture(osp.join(now_dir, img_name))
ret, img = cap.read()
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a
# threshold to convert the nonstandard annotation imgs. The
# value divided by 128 is equivalent to '1 if value >= 128
# else 0'
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(img_name)[0] + '.png'))
now_dir = osp.join(tmp_dir, 'test', '2nd_manual')
if osp.exists(now_dir):
for img_name in os.listdir(now_dir):
cap = cv2.VideoCapture(osp.join(now_dir, img_name))
ret, img = cap.read()
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(img_name)[0] + '.png'))
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()
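# Example run of this converter (paths are placeholders):
#   python tools/convert_datasets/drive.py training.zip test.zip -o data/DRIVE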
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile
import mmcv
HRF_LEN = 15
TRAINING_LEN = 5
def parse_args():
parser = argparse.ArgumentParser(
description='Convert HRF dataset to mmsegmentation format')
parser.add_argument('healthy_path', help='the path of healthy.zip')
parser.add_argument(
'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip')
parser.add_argument('glaucoma_path', help='the path of glaucoma.zip')
parser.add_argument(
'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip')
parser.add_argument(
'diabetic_retinopathy_path',
help='the path of diabetic_retinopathy.zip')
parser.add_argument(
'diabetic_retinopathy_manualsegm_path',
help='the path of diabetic_retinopathy_manualsegm.zip')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
images_path = [
args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path
]
annotations_path = [
args.healthy_manualsegm_path, args.glaucoma_manualsegm_path,
args.diabetic_retinopathy_manualsegm_path
]
if args.out_dir is None:
out_dir = osp.join('data', 'HRF')
else:
out_dir = args.out_dir
print('Making directories...')
mmcv.mkdir_or_exist(out_dir)
mmcv.mkdir_or_exist(osp.join(out_dir, 'images'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
print('Generating images...')
for now_path in images_path:
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
zip_file = zipfile.ZipFile(now_path)
zip_file.extractall(tmp_dir)
assert len(os.listdir(tmp_dir)) == HRF_LEN, \
'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN)
for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(tmp_dir, filename))
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'training',
osp.splitext(filename)[0] + '.png'))
for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(tmp_dir, filename))
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'validation',
osp.splitext(filename)[0] + '.png'))
print('Generating annotations...')
for now_path in annotations_path:
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
zip_file = zipfile.ZipFile(now_path)
zip_file.extractall(tmp_dir)
assert len(os.listdir(tmp_dir)) == HRF_LEN, \
'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN)
for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(tmp_dir, filename))
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a
# threshold to convert the nonstandard annotation imgs. The
# value divided by 128 is equivalent to '1 if value >= 128
# else 0'
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'training',
osp.splitext(filename)[0] + '.png'))
for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(tmp_dir, filename))
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(filename)[0] + '.png'))
print('Done!')
if __name__ == '__main__':
main()
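# Example run of this converter (paths are placeholders):
#   python tools/convert_datasets/hrf.py healthy.zip healthy_manualsegm.zip \
#     glaucoma.zip glaucoma_manualsegm.zip diabetic_retinopathy.zip \
#     diabetic_retinopathy_manualsegm.zip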
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
from detail import Detail
from PIL import Image
_mapping = np.sort(
np.array([
0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115
]))
_key = np.array(range(len(_mapping))).astype('uint8')
def generate_labels(img_id, detail, out_dir):
def _class_to_index(mask, _mapping, _key):
# assert that every value in the mask is present in the mapping
values = np.unique(mask)
for i in range(len(values)):
assert (values[i] in _mapping)
index = np.digitize(mask.ravel(), _mapping, right=True)
return _key[index].reshape(mask.shape)
mask = Image.fromarray(
_class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key))
filename = img_id['file_name']
mask.save(osp.join(out_dir, filename.replace('jpg', 'png')))
return osp.splitext(osp.basename(filename))[0]
def parse_args():
parser = argparse.ArgumentParser(
description='Convert PASCAL VOC annotations to mmsegmentation format')
parser.add_argument('devkit_path', help='pascal voc devkit path')
parser.add_argument('json_path', help='annotation json filepath')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
devkit_path = args.devkit_path
if args.out_dir is None:
out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext')
else:
out_dir = args.out_dir
json_path = args.json_path
mmcv.mkdir_or_exist(out_dir)
img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages')
train_detail = Detail(json_path, img_dir, 'train')
train_ids = train_detail.getImgs()
val_detail = Detail(json_path, img_dir, 'val')
val_ids = val_detail.getImgs()
mmcv.mkdir_or_exist(
osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext'))
train_list = mmcv.track_progress(
partial(generate_labels, detail=train_detail, out_dir=out_dir),
train_ids)
with open(
osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
'train.txt'), 'w') as f:
f.writelines(line + '\n' for line in sorted(train_list))
val_list = mmcv.track_progress(
partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids)
with open(
osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
'val.txt'), 'w') as f:
f.writelines(line + '\n' for line in sorted(val_list))
print('Done!')
if __name__ == '__main__':
main()
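# Example run of this converter (paths are placeholders):
#   python tools/convert_datasets/pascal_context.py data/VOCdevkit /path/to/trainval_merged.json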
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import gzip
import os
import os.path as osp
import tarfile
import tempfile
import mmcv
STARE_LEN = 20
TRAINING_LEN = 10
def un_gz(src, dst):
g_file = gzip.GzipFile(src)
with open(dst, 'wb+') as f:
f.write(g_file.read())
g_file.close()
def parse_args():
parser = argparse.ArgumentParser(
description='Convert STARE dataset to mmsegmentation format')
parser.add_argument('image_path', help='the path of stare-images.tar')
parser.add_argument('labels_ah', help='the path of labels-ah.tar')
parser.add_argument('labels_vk', help='the path of labels-vk.tar')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
args = parser.parse_args()
return args
def main():
args = parse_args()
image_path = args.image_path
labels_ah = args.labels_ah
labels_vk = args.labels_vk
if args.out_dir is None:
out_dir = osp.join('data', 'STARE')
else:
out_dir = args.out_dir
print('Making directories...')
mmcv.mkdir_or_exist(out_dir)
mmcv.mkdir_or_exist(osp.join(out_dir, 'images'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz'))
mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files'))
print('Extracting stare-images.tar...')
with tarfile.open(image_path) as f:
f.extractall(osp.join(tmp_dir, 'gz'))
for filename in os.listdir(osp.join(tmp_dir, 'gz')):
un_gz(
osp.join(tmp_dir, 'gz', filename),
osp.join(tmp_dir, 'files',
osp.splitext(filename)[0]))
now_dir = osp.join(tmp_dir, 'files')
assert len(os.listdir(now_dir)) == STARE_LEN, \
'len(os.listdir(now_dir)) != {}'.format(STARE_LEN)
for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(now_dir, filename))
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'training',
osp.splitext(filename)[0] + '.png'))
for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(now_dir, filename))
mmcv.imwrite(
img,
osp.join(out_dir, 'images', 'validation',
osp.splitext(filename)[0] + '.png'))
print('Removing the temporary files...')
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz'))
mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files'))
print('Extracting labels-ah.tar...')
with tarfile.open(labels_ah) as f:
f.extractall(osp.join(tmp_dir, 'gz'))
for filename in os.listdir(osp.join(tmp_dir, 'gz')):
un_gz(
osp.join(tmp_dir, 'gz', filename),
osp.join(tmp_dir, 'files',
osp.splitext(filename)[0]))
now_dir = osp.join(tmp_dir, 'files')
assert len(os.listdir(now_dir)) == STARE_LEN, \
'len(os.listdir(now_dir)) != {}'.format(STARE_LEN)
for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(now_dir, filename))
# The annotation img should be divided by 128, because some of
# the annotation imgs are not standard. We should set a threshold
# to convert the nonstandard annotation imgs. The value divided by
# 128 is equivalent to '1 if value >= 128 else 0'
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'training',
osp.splitext(filename)[0] + '.png'))
for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(now_dir, filename))
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(filename)[0] + '.png'))
print('Removing the temporary files...')
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz'))
mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files'))
print('Extracting labels-vk.tar...')
with tarfile.open(labels_vk) as f:
f.extractall(osp.join(tmp_dir, 'gz'))
for filename in os.listdir(osp.join(tmp_dir, 'gz')):
un_gz(
osp.join(tmp_dir, 'gz', filename),
osp.join(tmp_dir, 'files',
osp.splitext(filename)[0]))
now_dir = osp.join(tmp_dir, 'files')
assert len(os.listdir(now_dir)) == STARE_LEN, \
'len(os.listdir(now_dir)) != {}'.format(STARE_LEN)
for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
img = mmcv.imread(osp.join(now_dir, filename))
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'training',
osp.splitext(filename)[0] + '.png'))
for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
img = mmcv.imread(osp.join(now_dir, filename))
mmcv.imwrite(
img[:, :, 0] // 128,
osp.join(out_dir, 'annotations', 'validation',
osp.splitext(filename)[0] + '.png'))
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()
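# Example run of this converter (paths are placeholders):
#   python tools/convert_datasets/stare.py stare-images.tar labels-ah.tar labels-vk.tar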
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
from PIL import Image
from scipy.io import loadmat
AUG_LEN = 10582
def convert_mat(mat_file, in_dir, out_dir):
data = loadmat(osp.join(in_dir, mat_file))
mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png'))
Image.fromarray(mask).save(seg_filename, 'PNG')
def generate_aug_list(merged_list, excluded_list):
return list(set(merged_list) - set(excluded_list))
def parse_args():
parser = argparse.ArgumentParser(
description='Convert PASCAL VOC annotations to mmsegmentation format')
parser.add_argument('devkit_path', help='pascal voc devkit path')
parser.add_argument('aug_path', help='pascal voc aug path')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--nproc', default=1, type=int, help='number of processes')
args = parser.parse_args()
return args
def main():
args = parse_args()
devkit_path = args.devkit_path
aug_path = args.aug_path
nproc = args.nproc
if args.out_dir is None:
out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
else:
out_dir = args.out_dir
mmcv.mkdir_or_exist(out_dir)
in_dir = osp.join(aug_path, 'dataset', 'cls')
mmcv.track_parallel_progress(
partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
list(mmcv.scandir(in_dir, suffix='.mat')),
nproc=nproc)
full_aug_list = []
with open(osp.join(aug_path, 'dataset', 'train.txt')) as f:
full_aug_list += [line.strip() for line in f]
with open(osp.join(aug_path, 'dataset', 'val.txt')) as f:
full_aug_list += [line.strip() for line in f]
with open(
osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
'train.txt')) as f:
ori_train_list = [line.strip() for line in f]
with open(
osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
'val.txt')) as f:
val_list = [line.strip() for line in f]
aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
val_list)
assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format(
AUG_LEN)
with open(
osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
'trainaug.txt'), 'w') as f:
f.writelines(line + '\n' for line in aug_train_list)
aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
assert len(aug_list) == AUG_LEN - len(
ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN -
len(ori_train_list))
with open(
osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'),
'w') as f:
f.writelines(line + '\n' for line in aug_list)
print('Done!')
if __name__ == '__main__':
main()
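# Example run of this converter (paths are placeholders):
#   python tools/convert_datasets/voc_aug.py data/VOCdevkit data/VOCaug --nproc 8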
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import shutil
import warnings
from typing import Any, Iterable
import mmcv
import numpy as np
import torch
from mmcv.parallel import MMDataParallel
from mmcv.runner import get_dist_info
from mmcv.utils import DictAction
from mmseg.apis import single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models.segmentors.base import BaseSegmentor
from mmseg.ops import resize
class ONNXRuntimeSegmentor(BaseSegmentor):
def __init__(self, onnx_file: str, cfg: Any, device_id: int):
super(ONNXRuntimeSegmentor, self).__init__()
import onnxruntime as ort
# get the custom op path
ort_custom_op_path = ''
try:
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
warnings.warn('If input model has custom op from mmcv, \
you may have to build mmcv with ONNXRuntime from source.')
session_options = ort.SessionOptions()
# register custom op for onnxruntime
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
sess = ort.InferenceSession(onnx_file, session_options)
providers = ['CPUExecutionProvider']
options = [{}]
is_cuda_available = ort.get_device() == 'GPU'
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
options.insert(0, {'device_id': device_id})
sess.set_providers(providers, options)
self.sess = sess
self.device_id = device_id
self.io_binding = sess.io_binding()
self.output_names = [_.name for _ in sess.get_outputs()]
for name in self.output_names:
self.io_binding.bind_output(name)
self.cfg = cfg
self.test_mode = cfg.model.test_cfg.mode
self.is_cuda_available = is_cuda_available
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def encode_decode(self, img, img_metas):
raise NotImplementedError('This method is not implemented.')
def forward_train(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
**kwargs) -> list:
if not self.is_cuda_available:
img = img.detach().cpu()
elif self.device_id >= 0:
img = img.cuda(self.device_id)
device_type = img.device.type
self.io_binding.bind_input(
name='input',
device_type=device_type,
device_id=self.device_id,
element_type=np.float32,
shape=img.shape,
buffer_ptr=img.data_ptr())
self.sess.run_with_iobinding(self.io_binding)
seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
# the 'whole' test mode might support dynamic reshape
ori_shape = img_meta[0]['ori_shape']
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0]
seg_pred = list(seg_pred)
return seg_pred
def aug_test(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
class TensorRTSegmentor(BaseSegmentor):
def __init__(self, trt_file: str, cfg: Any, device_id: int):
super(TensorRTSegmentor, self).__init__()
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
try:
load_tensorrt_plugin()
except (ImportError, ModuleNotFoundError):
warnings.warn('If input model has custom op from mmcv, \
you may have to build mmcv with TensorRT from source.')
model = TRTWraper(
trt_file, input_names=['input'], output_names=['output'])
self.model = model
self.device_id = device_id
self.cfg = cfg
self.test_mode = cfg.model.test_cfg.mode
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def encode_decode(self, img, img_metas):
raise NotImplementedError('This method is not implemented.')
def forward_train(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
**kwargs) -> list:
with torch.cuda.device(self.device_id), torch.no_grad():
seg_pred = self.model({'input': img})['output']
seg_pred = seg_pred.detach().cpu().numpy()
# the 'whole' test mode might support dynamic reshape
ori_shape = img_meta[0]['ori_shape']
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0]
seg_pred = list(seg_pred)
return seg_pred
def aug_test(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description='mmseg backend test (and eval)')
parser.add_argument('config', help='test config file path')
parser.add_argument('model', help='Input model file')
parser.add_argument(
'--backend',
help='Backend of the model.',
choices=['onnxruntime', 'tensorrt'])
parser.add_argument('--out', help='output result file in pickle format')
parser.add_argument(
'--format-only',
action='store_true',
help='Format the output results without performing evaluation. It is '
'useful when you want to format the result to a specific format and '
'submit it to the test server')
parser.add_argument(
'--eval',
type=str,
nargs='+',
help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
' for generic datasets, and "cityscapes" for Cityscapes')
parser.add_argument('--show', action='store_true', help='show results')
parser.add_argument(
'--show-dir', help='directory where painted images will be saved')
parser.add_argument(
'--options', nargs='+', action=DictAction, help='custom options')
parser.add_argument(
'--eval-options',
nargs='+',
action=DictAction,
help='custom options for evaluation')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def main():
args = parse_args()
assert args.out or args.eval or args.format_only or args.show \
or args.show_dir, \
('Please specify at least one operation (save/eval/format/show the '
'results / save the results) with the argument "--out", "--eval"'
', "--format-only", "--show" or "--show-dir"')
if args.eval and args.format_only:
raise ValueError('--eval and --format_only cannot be both specified')
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
raise ValueError('The output file must be a pkl file.')
cfg = mmcv.Config.fromfile(args.config)
if args.options is not None:
cfg.merge_from_dict(args.options)
cfg.model.pretrained = None
cfg.data.test.test_mode = True
# init distributed env first, since logger depends on the dist info.
distributed = False
# build the dataloader
# TODO: support multiple images per gpu (only minor changes are needed)
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
# load onnx config and meta
cfg.model.train_cfg = None
if args.backend == 'onnxruntime':
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
elif args.backend == 'tensorrt':
model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0)
model.CLASSES = dataset.CLASSES
model.PALETTE = dataset.PALETTE
# clean gpu memory when starting a new evaluation.
torch.cuda.empty_cache()
eval_kwargs = {} if args.eval_options is None else args.eval_options
# Deprecated
efficient_test = eval_kwargs.get('efficient_test', False)
if efficient_test:
warnings.warn(
'``efficient_test=True`` does not have effect in tools/test.py, '
'the evaluation and format results are CPU memory efficient by '
'default')
eval_on_format_results = (
args.eval is not None and 'cityscapes' in args.eval)
if eval_on_format_results:
assert len(args.eval) == 1, 'eval on format results is not ' \
'applicable for metrics other than ' \
'cityscapes'
if args.format_only or eval_on_format_results:
if 'imgfile_prefix' in eval_kwargs:
tmpdir = eval_kwargs['imgfile_prefix']
else:
tmpdir = '.format_cityscapes'
eval_kwargs.setdefault('imgfile_prefix', tmpdir)
mmcv.mkdir_or_exist(tmpdir)
else:
tmpdir = None
model = MMDataParallel(model, device_ids=[0])
results = single_gpu_test(
model,
data_loader,
args.show,
args.show_dir,
False,
args.opacity,
pre_eval=args.eval is not None and not eval_on_format_results,
format_only=args.format_only or eval_on_format_results,
format_args=eval_kwargs)
rank, _ = get_dist_info()
if rank == 0:
if args.out:
warnings.warn(
'The behavior of ``args.out`` has been changed since MMSeg '
'v0.16, the pickled outputs could be seg map as type of '
'np.array, pre-eval results or file paths for '
'``dataset.format_results()``.')
print(f'\nwriting results to {args.out}')
mmcv.dump(results, args.out)
if args.eval:
dataset.evaluate(results, args.eval, **eval_kwargs)
if tmpdir is not None and eval_on_format_results:
# remove the tmp dir after cityscapes evaluation
shutil.rmtree(tmpdir)
if __name__ == '__main__':
main()
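# A hypothetical invocation of this backend test script (placeholder paths):
#   python tools/deploy_test.py configs/my_config.py model.onnx --backend onnxruntime --eval mIoU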
#!/usr/bin/env bash
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29500}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
NCCL_P2P_DISABLE=1 \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
$(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
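# Example (hypothetical paths): test on 4 GPUs on a custom port
#   PORT=29501 ./tools/dist_test.sh configs/my_config.py work_dirs/latest.pth 4 --eval mIoU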
#!/usr/bin/env bash
CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
NCCL_P2P_DISABLE=1 \
python -m torch.distributed.launch \
--nnodes=$NNODES \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--nproc_per_node=$GPUS \
--master_port=$PORT \
$(dirname "$0")/train.py \
$CONFIG \
--launcher pytorch ${@:3}
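# Example (hypothetical paths): train on 8 GPUs of a single node
#   ./tools/dist_train.sh configs/my_config.py 8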
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmcv import Config
from mmcv.cnn import get_model_complexity_info
from mmseg.models import build_segmentor
import sys
sys.path.append("..")
import xformer
import pvt
def parse_args():
parser = argparse.ArgumentParser(description='Train a segmentor')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[2048, 1024],
help='input image size')
args = parser.parse_args()
return args
def main():
args = parse_args()
if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape)
else:
raise ValueError('invalid input shape')
cfg = Config.fromfile(args.config)
cfg.model.pretrained = None
model = build_segmentor(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg')).cuda()
model.eval()
if hasattr(model, 'forward_dummy'):
model.forward = model.forward_dummy
else:
raise NotImplementedError(
'FLOPs counter is currently not supported with {}'.
format(model.__class__.__name__))
flops, params = get_model_complexity_info(model, input_shape)
split_line = '=' * 30
print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
split_line, input_shape, flops, params))
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.')
if __name__ == '__main__':
main()
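# Example run of this FLOPs counter (placeholder paths):
#   python tools/get_flops.py configs/my_config.py --shape 512 512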
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import torch
from mmcv.runner import CheckpointLoader
def convert_mit(ckpt):
new_ckpt = OrderedDict()
# Convert keys, concatenating the separate q and kv linear weights
# into a single in_proj weight along the way.
for k, v in ckpt.items():
if k.startswith('head'):
continue
# patch embedding conversion
elif k.startswith('patch_embed'):
stage_i = int(k.split('.')[0].replace('patch_embed', ''))
new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
new_v = v
if 'proj.' in new_k:
new_k = new_k.replace('proj.', 'projection.')
# transformer encoder layer conversion
elif k.startswith('block'):
stage_i = int(k.split('.')[0].replace('block', ''))
new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
new_v = v
if 'attn.q.' in new_k:
sub_item_k = k.replace('q.', 'kv.')
new_k = new_k.replace('q.', 'attn.in_proj_')
new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
elif 'attn.kv.' in new_k:
continue
elif 'attn.proj.' in new_k:
new_k = new_k.replace('proj.', 'attn.out_proj.')
elif 'attn.sr.' in new_k:
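# This replacement is a no-op: the `sr.` prefix already matches the
# MMSegmentation key layout, so the key is kept unchanged.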
new_k = new_k.replace('sr.', 'sr.')
elif 'mlp.' in new_k:
string = f'{new_k}-'
new_k = new_k.replace('mlp.', 'ffn.layers.')
if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
new_v = v.reshape((*v.shape, 1, 1))
new_k = new_k.replace('fc1.', '0.')
new_k = new_k.replace('dwconv.dwconv.', '1.')
new_k = new_k.replace('fc2.', '4.')
string += f'{new_k} {v.shape}-{new_v.shape}'
# norm layer conversion
elif k.startswith('norm'):
stage_i = int(k.split('.')[0].replace('norm', ''))
new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
new_v = v
else:
new_k = k
new_v = v
new_ckpt[new_k] = new_v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained segformer to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_mit(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()
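# Example run of this checkpoint converter (placeholder paths):
#   python tools/model_converters/mit2mmseg.py mit_b0.pth pretrain/mit_b0_mmseg.pth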
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import torch
from mmcv.runner import CheckpointLoader
def convert_swin(ckpt):
new_ckpt = OrderedDict()
def correct_unfold_reduction_order(x):
out_channel, in_channel = x.shape
x = x.reshape(out_channel, 4, in_channel // 4)
x = x[:, [0, 2, 1, 3], :].transpose(1,
2).reshape(out_channel, in_channel)
return x
def correct_unfold_norm_order(x):
in_channel = x.shape[0]
x = x.reshape(4, in_channel // 4)
x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
return x
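# The two helpers above reorder the patch-merging weights: the official Swin
# code concatenates the four shifted feature maps in
# [top-left, bottom-left, top-right, bottom-right] order, whereas the
# nn.Unfold-based PatchMerging in MMSegmentation expects row-major order,
# hence the [0, 2, 1, 3] permutation (this explanation is our reading of the
# code, not stated in the source).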
for k, v in ckpt.items():
if k.startswith('head'):
continue
elif k.startswith('layers'):
new_v = v
if 'attn.' in k:
new_k = k.replace('attn.', 'attn.w_msa.')
elif 'mlp.' in k:
if 'mlp.fc1.' in k:
new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
elif 'mlp.fc2.' in k:
new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
else:
new_k = k.replace('mlp.', 'ffn.')
elif 'downsample' in k:
new_k = k
if 'reduction.' in k:
new_v = correct_unfold_reduction_order(v)
elif 'norm.' in k:
new_v = correct_unfold_norm_order(v)
else:
new_k = k
new_k = new_k.replace('layers', 'stages', 1)
elif k.startswith('patch_embed'):
new_v = v
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
else:
new_v = v
new_k = k
new_ckpt[new_k] = new_v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained swin models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_swin(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()
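# Example run of this checkpoint converter (placeholder paths):
#   python tools/model_converters/swin2mmseg.py swin_tiny_patch4_window7_224.pth pretrain/swin_tiny_mmseg.pth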
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import torch
from mmcv.runner import CheckpointLoader
def convert_vit(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
if k.startswith('head'):
continue
if k.startswith('norm'):
new_k = k.replace('norm.', 'ln1.')
elif k.startswith('patch_embed'):
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
elif k.startswith('blocks'):
if 'norm' in k:
new_k = k.replace('norm', 'ln')
elif 'mlp.fc1' in k:
new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in k:
new_k = k.replace('mlp.fc2', 'ffn.layers.1')
elif 'attn.qkv' in k:
new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
elif 'attn.proj' in k:
new_k = k.replace('attn.proj', 'attn.attn.out_proj')
else:
new_k = k
new_k = new_k.replace('blocks.', 'layers.')
else:
new_k = k
new_ckpt[new_k] = v
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in timm pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
# deit checkpoint
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_vit(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()


# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
from typing import Iterable, Optional, Union
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import onnxruntime as ort
import torch
from mmcv.ops import get_onnxruntime_op_path
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
save_trt_engine)
from mmseg.apis.inference import LoadImage
from mmseg.datasets import DATASETS
from mmseg.datasets.pipelines import Compose
def get_GiB(x: int):
"""return x GiB."""
return x * (1 << 30)
def _prepare_input_img(img_path: str,
test_pipeline: Iterable[dict],
shape: Optional[Iterable] = None,
rescale_shape: Optional[Iterable] = None) -> dict:
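    """Build the test pipeline and run it on a single image.

    Assumes the second transform in ``test_pipeline`` is a
    MultiScaleFlipAug-style wrapper that exposes ``img_scale`` and a nested
    ``transforms`` list, as in standard MMSegmentation test configs.
    """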
# build the data pipeline
if shape is not None:
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
test_pipeline = [LoadImage()] + test_pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img_path)
data = test_pipeline(data)
imgs = data['img']
img_metas = [i.data for i in data['img_metas']]
if rescale_shape is not None:
for img_meta in img_metas:
img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
return mm_inputs
def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
    # rebuild the image meta list so that every one of the N batch entries
    # carries consistent shape, scale_factor and flip information
N = img_list[0].size(0)
img_meta = img_meta_list[0][0]
img_shape = img_meta['img_shape']
ori_shape = img_meta['ori_shape']
pad_shape = img_meta['pad_shape']
    # `* 2` repeats the (w_scale, h_scale) pair, producing the 4-element
    # scale_factor (w, h, w, h) convention used by mmcv/MMSegmentation.
    new_img_meta_list = [[{
        'img_shape': img_shape,
        'ori_shape': ori_shape,
        'pad_shape': pad_shape,
        'filename': img_meta['filename'],
        'scale_factor':
        (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
        'flip': False,
    } for _ in range(N)]]
return img_list, new_img_meta_list
def show_result_pyplot(img: Union[str, np.ndarray],
result: np.ndarray,
palette: Optional[Iterable] = None,
fig_size: Iterable[int] = (15, 10),
opacity: float = 0.5,
title: str = '',
block: bool = True):
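    """Overlay the segmentation ``result`` on ``img`` with the given palette
    and show it with matplotlib."""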
img = mmcv.imread(img)
img = img.copy()
seg = result[0]
seg = mmcv.imresize(seg, img.shape[:2][::-1])
palette = np.array(palette)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]
img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.title(title)
plt.tight_layout()
plt.show(block=block)
def onnx2tensorrt(onnx_file: str,
trt_file: str,
config: dict,
input_config: dict,
fp16: bool = False,
verify: bool = False,
show: bool = False,
dataset: str = 'CityscapesDataset',
workspace_size: int = 1,
verbose: bool = False):
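    """Build a TensorRT engine from an ONNX model and optionally verify its
    outputs against ONNXRuntime."""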
import tensorrt as trt
min_shape = input_config['min_shape']
max_shape = input_config['max_shape']
# create trt engine and wrapper
opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
max_workspace_size = get_GiB(workspace_size)
trt_engine = onnx2trt(
onnx_file,
opt_shape_dict,
log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
fp16_mode=fp16,
max_workspace_size=max_workspace_size)
save_dir, _ = osp.split(trt_file)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
save_trt_engine(trt_engine, trt_file)
print(f'Successfully created TensorRT engine: {trt_file}')
if verify:
inputs = _prepare_input_img(
input_config['input_path'],
config.data.test.pipeline,
shape=min_shape[2:])
imgs = inputs['imgs']
img_metas = inputs['img_metas']
img_list = [img[None, :] for img in imgs]
img_meta_list = [[img_meta] for img_meta in img_metas]
# update img_meta
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
if max_shape[0] > 1:
            # concatenate flipped images to fill the batch for testing
flip_img_list = [_.flip(-1) for _ in img_list]
img_list = [
torch.cat((ori_img, flip_img), 0)
for ori_img, flip_img in zip(img_list, flip_img_list)
]
# Get results from ONNXRuntime
ort_custom_op_path = get_onnxruntime_op_path()
session_options = ort.SessionOptions()
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
sess = ort.InferenceSession(onnx_file, session_options)
sess.set_providers(['CPUExecutionProvider'], [{}]) # use cpu mode
onnx_output = sess.run(['output'],
{'input': img_list[0].detach().numpy()})[0][0]
# Get results from TensorRT
trt_model = TRTWraper(trt_file, ['input'], ['output'])
with torch.no_grad():
trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
trt_output = trt_outputs['output'][0].cpu().detach().numpy()
if show:
dataset = DATASETS.get(dataset)
assert dataset is not None
palette = dataset.PALETTE
show_result_pyplot(
input_config['input_path'],
(onnx_output[0].astype(np.uint8), ),
palette=palette,
title='ONNXRuntime',
block=False)
show_result_pyplot(
input_config['input_path'], (trt_output[0].astype(np.uint8), ),
palette=palette,
title='TensorRT')
np.testing.assert_allclose(
onnx_output, trt_output, rtol=1e-03, atol=1e-05)
        print('TensorRT and ONNXRuntime outputs are all close.')
def parse_args():
parser = argparse.ArgumentParser(
description='Convert MMSegmentation models from ONNX to TensorRT')
parser.add_argument('config', help='Config file of the model')
parser.add_argument('model', help='Path to the input ONNX model')
parser.add_argument(
'--trt-file', type=str, help='Path to the output TensorRT engine')
parser.add_argument(
'--max-shape',
type=int,
nargs=4,
default=[1, 3, 400, 600],
help='Maximum shape of model input.')
parser.add_argument(
'--min-shape',
type=int,
nargs=4,
default=[1, 3, 400, 600],
help='Minimum shape of model input.')
parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
parser.add_argument(
'--workspace-size',
type=int,
default=1,
help='Max workspace size in GiB')
parser.add_argument(
'--input-img', type=str, default='', help='Image for test')
parser.add_argument(
'--show', action='store_true', help='Whether to show output results')
parser.add_argument(
'--dataset',
type=str,
default='CityscapesDataset',
help='Dataset name')
parser.add_argument(
'--verify',
action='store_true',
help='Verify the outputs of ONNXRuntime and TensorRT')
parser.add_argument(
'--verbose',
action='store_true',
        help='Whether to enable verbose logging while creating the '
        'TensorRT engine.')
args = parser.parse_args()
return args
if __name__ == '__main__':
assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
args = parse_args()
if not args.input_img:
args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')
# check arguments
assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
assert osp.exists(args.model), \
'ONNX model {} not found.'.format(args.model)
    assert args.workspace_size >= 0, 'Workspace size should be non-negative.'
    assert DATASETS.get(args.dataset) is not None, \
        'Dataset {} is not found.'.format(args.dataset)
for max_value, min_value in zip(args.max_shape, args.min_shape):
        assert max_value >= min_value, \
            'each value in max_shape should be no less than min_shape'
input_config = {
'min_shape': args.min_shape,
'max_shape': args.max_shape,
'input_path': args.input_img
}
cfg = mmcv.Config.fromfile(args.config)
onnx2tensorrt(
args.model,
args.trt_file,
cfg,
input_config,
fp16=args.fp16,
verify=args.verify,
show=args.show,
dataset=args.dataset,
workspace_size=args.workspace_size,
verbose=args.verbose)
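# Example invocation (illustrative; the script name, config and model paths
# below are placeholders, not names taken from this repository):
#   python onnx2tensorrt.py <config>.py <model>.onnx --trt-file <out>.trt \
#       --min-shape 1 3 512 1024 --max-shape 1 3 512 1024 --verify --show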