Unverified Commit db39fd4a authored by Sun Jiahao, committed by GitHub

[Enhance] Use Inferencer to implement Demo (#2763)

parent f4c032e4
@@ -2,117 +2,83 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 25,
"source": [
"from mmdet3d.apis import inference_detector, init_model\n",
"from mmdet3d.registry import VISUALIZERS\n",
"from mmdet3d.utils import register_all_modules"
],
"outputs": [],
"metadata": { "metadata": {
"pycharm": { "pycharm": {
"is_executing": false "is_executing": false
} }
} },
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# register all modules in mmdet3d into the registries\n",
"register_all_modules()"
],
"outputs": [], "outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"source": [ "source": [
"config_file = '../configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py'\n", "from mmdet3d.apis import LidarDet3DInferencer"
"# download the checkpoint from model zoo and put it in `checkpoints/`\n", ]
"checkpoint_file = '../work_dirs/second/epoch_40.pth'"
],
"outputs": [],
"metadata": {
"pycharm": {
"is_executing": false
}
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"source": [ "metadata": {},
"# build the model from a config file and a checkpoint file\n",
"model = init_model(config_file, checkpoint_file, device='cuda:0')"
],
"outputs": [], "outputs": [],
"metadata": {} "source": [
"# initialize inferencer\n",
"inferencer = LidarDet3DInferencer('pointpillars_kitti-3class')"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"source": [
"# init visualizer\n",
"visualizer = VISUALIZERS.build(model.cfg.visualizer)\n",
"visualizer.dataset_meta = {\n",
" 'CLASSES': model.CLASSES,\n",
" 'PALETTE': model.PALETTE\n",
"}"
],
"outputs": [],
"metadata": { "metadata": {
"pycharm": { "pycharm": {
"is_executing": false "is_executing": false
} }
} },
"outputs": [],
"source": [
"# inference\n",
"inputs = dict(points='./data/kitti/000008.bin')\n",
"inferencer(inputs)"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": null,
"source": [ "metadata": {},
"# test a single sample\n",
"pcd = './data/kitti/000008.bin'\n",
"result, data = inference_detector(model, pcd)\n",
"points = data['inputs']['points']\n",
"data_input = dict(points=points)"
],
"outputs": [], "outputs": [],
"metadata": { "source": [
"pycharm": { "# inference and visualize\n",
"is_executing": false "# NOTE: use the `Esc` key to exit Open3D window in Jupyter Notebook Environment\n",
} "inferencer(inputs, show=True)"
} ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"source": [ "metadata": {},
"# show the results\n",
"out_dir = './'\n",
"visualizer.add_datasample(\n",
" 'result',\n",
" data_input,\n",
" data_sample=result,\n",
" draw_gt=False,\n",
" show=True,\n",
" wait_time=0,\n",
" out_file=out_dir,\n",
" vis_task='det')"
],
"outputs": [], "outputs": [],
"metadata": { "source": [
"pycharm": { "# If your operating environment does not have a display device,\n",
"is_executing": false "# (e.g. a remote server), you can save the predictions and visualize\n",
} "# them in local devices.\n",
} "inferencer(inputs, show=False, out_dir='./remote_outputs')\n",
"\n",
"# Simulate the migration process\n",
"%mv ./remote_outputs ./local_outputs\n",
"\n",
"# Visualize the predictions from the saved files\n",
"# NOTE: use the `Esc` key to exit Open3D window in Jupyter Notebook Environment\n",
"local_inferencer = LidarDet3DInferencer('pointpillars_kitti-3class')\n",
"inputs = local_inferencer._inputs_to_list(inputs)\n",
"local_inferencer.visualize_preds_fromfile(inputs, ['local_outputs/preds/000008.json'], show=True)"
]
} }
], ],
"metadata": { "metadata": {
"interpreter": {
"hash": "a0c343fece975dd89087e8c2194dd4d3db28d7000f1b32ed9ed9d584dd54dbbe"
},
"kernelspec": { "kernelspec": {
"name": "python3", "display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.7.6 64-bit ('torch1.7-cu10.1': conda)" "language": "python",
"name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@@ -124,19 +90,16 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.6" "version": "3.9.16"
}, },
"pycharm": { "pycharm": {
"stem_cell": { "stem_cell": {
"cell_type": "raw", "cell_type": "raw",
"source": [],
"metadata": { "metadata": {
"collapsed": false "collapsed": false
} },
"source": []
} }
},
"interpreter": {
"hash": "a0c343fece975dd89087e8c2194dd4d3db28d7000f1b32ed9ed9d584dd54dbbe"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
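Taken together, the new notebook cells above reduce the demo to the Inferencer workflow below. This is a consolidated sketch of what the notebook runs, using the same model alias and data path; it assumes the KITTI sample file (and, for the window, a display device) is available.

```python
from mmdet3d.apis import LidarDet3DInferencer

# Initialize from a model alias; the matching config and checkpoint
# are resolved and downloaded automatically.
inferencer = LidarDet3DInferencer('pointpillars_kitti-3class')

# Run inference on a single point cloud file.
inputs = dict(points='./data/kitti/000008.bin')
inferencer(inputs)

# Inference plus an Open3D window (press `Esc` to close it in Jupyter).
inferencer(inputs, show=True)

# Headless machines can skip the window and keep the saved predictions,
# e.g. ./remote_outputs/preds/000008.json, for later visualization with
# `visualize_preds_fromfile` as in the notebook's last cell.
inferencer(inputs, show=False, out_dir='./remote_outputs')
```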
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
from argparse import ArgumentParser from argparse import ArgumentParser
import mmcv from mmengine.logging import print_log
from mmdet3d.apis import inference_mono_3d_detector, init_model from mmdet3d.apis import MonoDet3DInferencer
from mmdet3d.registry import VISUALIZERS
def parse_args(): def parse_args():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('img', help='image file') parser.add_argument('img', help='Image file')
parser.add_argument('ann', help='ann file') parser.add_argument('infos', help='Infos file with annotations')
parser.add_argument('config', help='Config file') parser.add_argument('model', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file') parser.add_argument('weights', help='Checkpoint file')
parser.add_argument( parser.add_argument(
'--device', default='cuda:0', help='Device used for inference') '--device', default='cuda:0', help='Device used for inference')
parser.add_argument( parser.add_argument(
@@ -21,50 +22,77 @@ def parse_args():
default='CAM_BACK', default='CAM_BACK',
help='choose camera type to inference') help='choose camera type to inference')
parser.add_argument( parser.add_argument(
'--score-thr', type=float, default=0.30, help='bbox score threshold') '--pred-score-thr',
type=float,
default=0.3,
help='bbox score threshold')
parser.add_argument( parser.add_argument(
'--out-dir', type=str, default='demo', help='dir to save results') '--out-dir',
type=str,
default='outputs',
help='Output directory of prediction and visualization results.')
parser.add_argument( parser.add_argument(
'--show', '--show',
action='store_true', action='store_true',
help='show online visualization results') help='Show online visualization results')
parser.add_argument(
'--wait-time',
type=float,
default=-1,
help='The interval of show (s). Demo will be blocked in showing '
'results, if wait_time is -1. Defaults to -1.')
parser.add_argument( parser.add_argument(
'--snapshot', '--no-save-vis',
action='store_true', action='store_true',
help='whether to save online visualization results') help='Do not save detection visualization results')
args = parser.parse_args() parser.add_argument(
return args '--no-save-pred',
action='store_true',
help='Do not save detection prediction results')
parser.add_argument(
'--print-result',
action='store_true',
help='Whether to print the results.')
call_args = vars(parser.parse_args())
call_args['inputs'] = dict(
img=call_args.pop('img'), infos=call_args.pop('infos'))
call_args.pop('cam_type')
if call_args['no_save_vis'] and call_args['no_save_pred']:
call_args['out_dir'] = ''
init_kws = ['model', 'weights', 'device']
init_args = {}
for init_kw in init_kws:
init_args[init_kw] = call_args.pop(init_kw)
# NOTE: If your operating environment does not have a display device,
# (e.g. a remote server), you can save the predictions and visualize
# them in local devices.
if os.environ.get('DISPLAY') is None and call_args['show']:
print_log(
'Display device not found. `--show` is forced to False',
logger='current',
level=logging.WARNING)
call_args['show'] = False
def main(args): return init_args, call_args
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta
# test a single image def main():
result = inference_mono_3d_detector(model, args.img, args.ann, # TODO: Support inference of point cloud numpy file.
args.cam_type) init_args, call_args = parse_args()
img = mmcv.imread(args.img) inferencer = MonoDet3DInferencer(**init_args)
img = mmcv.imconvert(img, 'bgr', 'rgb') inferencer(**call_args)
data_input = dict(img=img) if call_args['out_dir'] != '' and not (call_args['no_save_vis']
# show the results and call_args['no_save_pred']):
visualizer.add_datasample( print_log(
'result', f'results have been saved at {call_args["out_dir"]}',
data_input, logger='current')
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=-1,
out_file=args.out_dir,
pred_score_thr=args.score_thr,
vis_task='mono_det')
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() main()
main(args)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
from argparse import ArgumentParser from argparse import ArgumentParser
import mmcv from mmengine.logging import print_log
from mmdet3d.apis import inference_multi_modality_detector, init_model from mmdet3d.apis import MultiModalityDet3DInferencer
from mmdet3d.registry import VISUALIZERS
def parse_args(): def parse_args():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('pcd', help='Point cloud file') parser.add_argument('pcd', help='Point cloud file')
parser.add_argument('img', help='image file') parser.add_argument('img', help='Image file')
parser.add_argument('ann', help='ann file') parser.add_argument('infos', help='Infos file with annotations')
parser.add_argument('config', help='Config file') parser.add_argument('model', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file') parser.add_argument('weights', help='Checkpoint file')
parser.add_argument( parser.add_argument(
'--device', default='cuda:0', help='Device used for inference') '--device', default='cuda:0', help='Device used for inference')
parser.add_argument( parser.add_argument(
@@ -22,57 +23,79 @@ def parse_args():
default='CAM_FRONT', default='CAM_FRONT',
help='choose camera type to inference') help='choose camera type to inference')
parser.add_argument( parser.add_argument(
'--score-thr', type=float, default=0.0, help='bbox score threshold') '--pred-score-thr',
type=float,
default=0.3,
help='bbox score threshold')
parser.add_argument( parser.add_argument(
'--out-dir', type=str, default='demo', help='dir to save results') '--out-dir',
type=str,
default='outputs',
help='Output directory of prediction and visualization results.')
parser.add_argument( parser.add_argument(
'--show', '--show',
action='store_true', action='store_true',
help='show online visualization results') help='Show online visualization results')
parser.add_argument(
'--wait-time',
type=float,
default=-1,
help='The interval of show (s). Demo will be blocked in showing '
'results, if wait_time is -1. Defaults to -1.')
parser.add_argument(
'--no-save-vis',
action='store_true',
help='Do not save detection visualization results')
parser.add_argument(
'--no-save-pred',
action='store_true',
help='Do not save detection prediction results')
parser.add_argument( parser.add_argument(
'--snapshot', '--print-result',
action='store_true', action='store_true',
help='whether to save online visualization results') help='Whether to print the results.')
args = parser.parse_args() call_args = vars(parser.parse_args())
return args
call_args['inputs'] = dict(
points=call_args.pop('pcd'),
img=call_args.pop('img'),
infos=call_args.pop('infos'))
call_args.pop('cam_type')
if call_args['no_save_vis'] and call_args['no_save_pred']:
call_args['out_dir'] = ''
init_kws = ['model', 'weights', 'device']
init_args = {}
for init_kw in init_kws:
init_args[init_kw] = call_args.pop(init_kw)
# NOTE: If your operating environment does not have a display device,
# (e.g. a remote server), you can save the predictions and visualize
# them in local devices.
if os.environ.get('DISPLAY') is None and call_args['show']:
print_log(
'Display device not found. `--show` is forced to False',
logger='current',
level=logging.WARNING)
call_args['show'] = False
return init_args, call_args
def main(args):
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
# init visualizer def main():
visualizer = VISUALIZERS.build(model.cfg.visualizer) # TODO: Support inference of point cloud numpy file.
visualizer.dataset_meta = model.dataset_meta init_args, call_args = parse_args()
# test a single image and point cloud sample inferencer = MultiModalityDet3DInferencer(**init_args)
result, data = inference_multi_modality_detector(model, args.pcd, args.img, inferencer(**call_args)
args.ann, args.cam_type)
points = data['inputs']['points']
if isinstance(result.img_path, list):
img = []
for img_path in result.img_path:
single_img = mmcv.imread(img_path)
single_img = mmcv.imconvert(single_img, 'bgr', 'rgb')
img.append(single_img)
else:
img = mmcv.imread(result.img_path)
img = mmcv.imconvert(img, 'bgr', 'rgb')
data_input = dict(points=points, img=img)
# show the results if call_args['out_dir'] != '' and not (call_args['no_save_vis']
visualizer.add_datasample( and call_args['no_save_pred']):
'result', print_log(
data_input, f'results have been saved at {call_args["out_dir"]}',
data_sample=result, logger='current')
draw_gt=False,
show=args.show,
wait_time=-1,
out_file=args.out_dir,
pred_score_thr=args.score_thr,
vis_task='multi-modality_det')
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() main()
main(args)
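For reference, the rewritten script above is roughly equivalent to driving the inferencer from Python directly. A minimal sketch with placeholder config, checkpoint and data paths (all placeholders, not paths from this diff):

```python
from mmdet3d.apis import MultiModalityDet3DInferencer

# Placeholder paths; substitute your own config, checkpoint and KITTI-style
# sample (point cloud, image and the infos .pkl describing that sample).
inferencer = MultiModalityDet3DInferencer(
    model='path/to/multi-modality_config.py',
    weights='path/to/checkpoint.pth',
    device='cuda:0')
inferencer(
    inputs=dict(
        points='demo/data/kitti/000008.bin',
        img='demo/data/kitti/000008.png',
        infos='demo/data/kitti/000008.pkl'),
    out_dir='outputs')
```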
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
from argparse import ArgumentParser from argparse import ArgumentParser
from mmdet3d.apis import inference_detector, init_model from mmengine.logging import print_log
from mmdet3d.registry import VISUALIZERS
from mmdet3d.apis import LidarDet3DInferencer
def parse_args(): def parse_args():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('pcd', help='Point cloud file') parser.add_argument('pcd', help='Point cloud file')
parser.add_argument('config', help='Config file') parser.add_argument('model', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file') parser.add_argument('weights', help='Checkpoint file')
parser.add_argument( parser.add_argument(
'--device', default='cuda:0', help='Device used for inference') '--device', default='cuda:0', help='Device used for inference')
parser.add_argument( parser.add_argument(
'--score-thr', type=float, default=0.0, help='bbox score threshold') '--pred-score-thr',
type=float,
default=0.3,
help='bbox score threshold')
parser.add_argument( parser.add_argument(
'--out-dir', type=str, default='demo', help='dir to save results') '--out-dir',
type=str,
default='outputs',
help='Output directory of prediction and visualization results.')
parser.add_argument( parser.add_argument(
'--show', '--show',
action='store_true', action='store_true',
help='show online visualization results') help='Show online visualization results')
parser.add_argument(
'--wait-time',
type=float,
default=-1,
help='The interval of show (s). Demo will be blocked in showing '
'results, if wait_time is -1. Defaults to -1.')
parser.add_argument(
'--no-save-vis',
action='store_true',
help='Do not save detection visualization results')
parser.add_argument( parser.add_argument(
'--snapshot', '--no-save-pred',
action='store_true', action='store_true',
help='whether to save online visualization results') help='Do not save detection prediction results')
args = parser.parse_args() parser.add_argument(
return args '--print-result',
action='store_true',
help='Whether to print the results.')
call_args = vars(parser.parse_args())
call_args['inputs'] = dict(points=call_args.pop('pcd'))
if call_args['no_save_vis'] and call_args['no_save_pred']:
call_args['out_dir'] = ''
def main(args): init_kws = ['model', 'weights', 'device']
init_args = {}
for init_kw in init_kws:
init_args[init_kw] = call_args.pop(init_kw)
# NOTE: If your operating environment does not have a display device,
# (e.g. a remote server), you can save the predictions and visualize
# them in local devices.
if os.environ.get('DISPLAY') is None and call_args['show']:
print_log(
'Display device not found. `--show` is forced to False',
logger='current',
level=logging.WARNING)
call_args['show'] = False
return init_args, call_args
def main():
# TODO: Support inference of point cloud numpy file. # TODO: Support inference of point cloud numpy file.
# build the model from a config file and a checkpoint file init_args, call_args = parse_args()
model = init_model(args.config, args.checkpoint, device=args.device)
inferencer = LidarDet3DInferencer(**init_args)
# init visualizer inferencer(**call_args)
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta if call_args['out_dir'] != '' and not (call_args['no_save_vis']
and call_args['no_save_pred']):
# test a single point cloud sample print_log(
result, data = inference_detector(model, args.pcd) f'results have been saved at {call_args["out_dir"]}',
points = data['inputs']['points'] logger='current')
data_input = dict(points=points)
# show the results
visualizer.add_datasample(
'result',
data_input,
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=-1,
out_file=args.out_dir,
pred_score_thr=args.score_thr,
vis_task='lidar_det')
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() main()
main(args)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
from argparse import ArgumentParser from argparse import ArgumentParser
from mmdet3d.apis import inference_segmentor, init_model from mmengine.logging import print_log
from mmdet3d.registry import VISUALIZERS
from mmdet3d.apis import LidarSeg3DInferencer
def parse_args(): def parse_args():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('pcd', help='Point cloud file') parser.add_argument('pcd', help='Point cloud file')
parser.add_argument('config', help='Config file') parser.add_argument('model', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file') parser.add_argument('weights', help='Checkpoint file')
parser.add_argument( parser.add_argument(
'--device', default='cuda:0', help='Device used for inference') '--device', default='cuda:0', help='Device used for inference')
parser.add_argument( parser.add_argument(
'--out-dir', type=str, default='demo', help='dir to save results') '--out-dir',
type=str,
default='outputs',
help='Output directory of prediction and visualization results.')
parser.add_argument( parser.add_argument(
'--show', '--show',
action='store_true', action='store_true',
help='show online visualization results') help='Show online visualization results')
parser.add_argument(
'--wait-time',
type=float,
default=-1,
help='The interval of show (s). Demo will be blocked in showing '
'results, if wait_time is -1. Defaults to -1.')
parser.add_argument(
'--no-save-vis',
action='store_true',
help='Do not save detection visualization results')
parser.add_argument( parser.add_argument(
'--snapshot', '--no-save-pred',
action='store_true', action='store_true',
help='whether to save online visualization results') help='Do not save detection prediction results')
args = parser.parse_args() parser.add_argument(
return args '--print-result',
action='store_true',
help='Whether to print the results.')
def main(args): call_args = vars(parser.parse_args())
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device) call_args['inputs'] = dict(points=call_args.pop('pcd'))
# init visualizer if call_args['no_save_vis'] and call_args['no_save_pred']:
visualizer = VISUALIZERS.build(model.cfg.visualizer) call_args['out_dir'] = ''
visualizer.dataset_meta = model.dataset_meta
init_kws = ['model', 'weights', 'device']
# test a single point cloud sample init_args = {}
result, data = inference_segmentor(model, args.pcd) for init_kw in init_kws:
points = data['inputs']['points'] init_args[init_kw] = call_args.pop(init_kw)
data_input = dict(points=points)
# show the results # NOTE: If your operating environment does not have a display device,
visualizer.add_datasample( # (e.g. a remote server), you can save the predictions and visualize
'result', # them in local devices.
data_input, if os.environ.get('DISPLAY') is None and call_args['show']:
data_sample=result, print_log(
draw_gt=False, 'Display device not found. `--show` is forced to False',
show=args.show, logger='current',
wait_time=-1, level=logging.WARNING)
out_file=args.out_dir, call_args['show'] = False
vis_task='lidar_seg')
return init_args, call_args
def main():
# TODO: Support inference of point cloud numpy file.
init_args, call_args = parse_args()
inferencer = LidarSeg3DInferencer(**init_args)
inferencer(**call_args)
if call_args['out_dir'] != '' and not (call_args['no_save_vis']
and call_args['no_save_pred']):
print_log(
f'results have been saved at {call_args["out_dir"]}',
logger='current')
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() main()
main(args)
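The segmentation demo follows the same pattern; programmatically it boils down to something like the sketch below (config, checkpoint and point cloud paths are placeholders):

```python
from mmdet3d.apis import LidarSeg3DInferencer

# Placeholder config/checkpoint; any LiDAR segmentation model will do.
inferencer = LidarSeg3DInferencer(
    model='path/to/seg_config.py',
    weights='path/to/seg_checkpoint.pth',
    device='cuda:0')

# Save the per-point semantic predictions under outputs/preds/ and print them.
inferencer(
    inputs=dict(points='path/to/scene.bin'),
    out_dir='outputs',
    print_result=True)
```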
@@ -141,6 +141,10 @@ You will see a visualizer interface with point cloud, where bounding boxes are p…

**Note**:

If you install MMDetection3D on a remote server without a display device, you can leave out the `--show` argument. The demo will still save the predictions to the `outputs/preds/000008.json` file.

**Note**:

If you want to input a `.ply` file, you can use the following function to convert it to `.bin` format, and then run the demo with the converted `.bin` file.

Note that you need to install `pandas` and `plyfile` before using this script. This function can also be used for data preprocessing when training on `.ply` data.
......
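The conversion helper mentioned in the note is collapsed in this view. A minimal sketch of such a `.ply`-to-`.bin` converter, assuming the `pandas` and `plyfile` dependencies the note lists (the function name and column handling here are illustrative, not necessarily the exact helper from the docs):

```python
import numpy as np
import pandas as pd
from plyfile import PlyData


def convert_ply(input_path, output_path):
    """Convert a .ply point cloud into the flat float32 .bin layout."""
    plydata = PlyData.read(input_path)         # read the .ply file
    data = plydata.elements[0].data            # structured numpy array
    data_pd = pd.DataFrame(data)               # DataFrame for easy column access
    data_np = np.zeros(data_pd.shape, dtype=np.float64)
    property_names = data[0].dtype.names       # e.g. ('x', 'y', 'z', 'intensity')
    for i, name in enumerate(property_names):  # copy column by column
        data_np[:, i] = data_pd[name]
    data_np.astype(np.float32).tofile(output_path)


# convert_ply('./test.ply', './test.bin')
```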
@@ -139,6 +139,10 @@ python demo/pcd_demo.py demo/data/kitti/000008.bin pointpillars_hv_secfpn_8xb6-1…

**Note**

If you install MMDetection3D on a server without a display device, you can omit the `--show` argument. The demo will still save the prediction results to the `outputs/preds/000008.json` file.

**Note**

If you want to input a `.ply` file, you can use the following function to convert it to `.bin` format and then run the demo with the converted `.bin` file. Note that you need to install `pandas` and `plyfile` before using this script. This function can also be used as data preprocessing when training with `.ply` data.

```python
......
@@ -7,7 +7,7 @@ from mmengine.utils import digit_version
 from .version import __version__, version_info
 mmcv_minimum_version = '2.0.0rc4'
-mmcv_maximum_version = '2.1.0'
+mmcv_maximum_version = '2.2.0'
 mmcv_version = digit_version(mmcv.__version__)
 mmengine_minimum_version = '0.8.0'
......
@@ -392,7 +392,8 @@ def inference_segmentor(model: nn.Module, pcds: PointsType):
     new_test_pipeline = []
     for pipeline in test_pipeline:
-        if pipeline['type'] != 'LoadAnnotations3D':
+        if pipeline['type'] != 'LoadAnnotations3D' and pipeline[
+                'type'] != 'PointSegClassMapping':
             new_test_pipeline.append(pipeline)
     test_pipeline = Compose(new_test_pipeline)
......
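The change above additionally drops `PointSegClassMapping` (not only `LoadAnnotations3D`) when rebuilding the test pipeline, since neither transform can run without ground-truth annotations. A standalone sketch of the same filtering idiom on a hypothetical pipeline config:

```python
# Hypothetical test pipeline; only the 'type' keys matter for the filtering.
test_pipeline = [
    dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
    dict(type='LoadAnnotations3D', with_seg_3d=True),
    dict(type='PointSegClassMapping'),
    dict(type='Pack3DDetInputs', keys=['points']),
]

# Keep every transform except the two that need ground-truth annotations.
skipped = ('LoadAnnotations3D', 'PointSegClassMapping')
new_test_pipeline = [p for p in test_pipeline if p['type'] not in skipped]
```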
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from copy import deepcopy
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Dict, List, Optional, Sequence, Tuple, Union
import mmengine
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from mmengine.fileio import (get_file_backend, isdir, join_path, from mmengine import dump, print_log
list_dir_or_file)
from mmengine.infer.infer import BaseInferencer, ModelType from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint from mmengine.runner import load_checkpoint
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from mmengine.visualization import Visualizer from mmengine.visualization import Visualizer
from rich.progress import track
from mmdet3d.registry import MODELS from mmdet3d.registry import DATASETS, MODELS
from mmdet3d.structures import Box3DMode, Det3DDataSample
from mmdet3d.utils import ConfigType from mmdet3d.utils import ConfigType
InstanceList = List[InstanceData] InstanceList = List[InstanceData]
@@ -44,14 +48,14 @@ class Base3DInferencer(BaseInferencer):
priority is palette -> config -> checkpoint. Defaults to 'none'. priority is palette -> config -> checkpoint. Defaults to 'none'.
""" """
preprocess_kwargs: set = set() preprocess_kwargs: set = {'cam_type'}
forward_kwargs: set = set() forward_kwargs: set = set()
visualize_kwargs: set = { visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr', 'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir' 'img_out_dir', 'no_save_vis', 'cam_type_dir'
} }
postprocess_kwargs: set = { postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample' 'print_result', 'pred_out_dir', 'return_datasample', 'no_save_pred'
} }
def __init__(self, def __init__(self,
@@ -60,10 +64,14 @@ class Base3DInferencer(BaseInferencer):
device: Optional[str] = None, device: Optional[str] = None,
scope: str = 'mmdet3d', scope: str = 'mmdet3d',
palette: str = 'none') -> None: palette: str = 'none') -> None:
# A global counter tracking the number of frames processed, for
# naming of the output results
self.num_predicted_frames = 0
self.palette = palette self.palette = palette
init_default_scope(scope) init_default_scope(scope)
super().__init__( super().__init__(
model=model, weights=weights, device=device, scope=scope) model=model, weights=weights, device=device, scope=scope)
self.model = revert_sync_batchnorm(self.model)
def _convert_syncbn(self, cfg: ConfigType): def _convert_syncbn(self, cfg: ConfigType):
"""Convert config's naiveSyncBN to BN. """Convert config's naiveSyncBN to BN.
@@ -108,56 +116,19 @@ class Base3DInferencer(BaseInferencer):
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE'] model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
test_dataset_cfg = deepcopy(cfg.test_dataloader.dataset)
# lazy init. We only need the metainfo.
test_dataset_cfg['lazy_init'] = True
metainfo = DATASETS.build(test_dataset_cfg).metainfo
cfg_palette = metainfo.get('palette', None)
if cfg_palette is not None:
model.dataset_meta['palette'] = cfg_palette
model.cfg = cfg # save the config in the model for convenience model.cfg = cfg # save the config in the model for convenience
model.to(device) model.to(device)
model.eval() model.eval()
return model return model
def _inputs_to_list(
self,
inputs: Union[dict, list],
modality_key: Union[str, List[str]] = 'points') -> list:
"""Preprocess the inputs to a list.
Preprocess inputs to a list according to its type:
- list or tuple: return inputs
- dict: the value of key 'points'/`img` is
- Directory path: return all files in the directory
- other cases: return a list containing the string. The string
could be a path to file, a url or other types of string according
to the task.
Args:
inputs (Union[dict, list]): Inputs for the inferencer.
modality_key (Union[str, List[str]]): The key of the modality.
Defaults to 'points'.
Returns:
list: List of input for the :meth:`preprocess`.
"""
if isinstance(modality_key, str):
modality_key = [modality_key]
assert set(modality_key).issubset({'points', 'img'})
for key in modality_key:
if isinstance(inputs, dict) and isinstance(inputs[key], str):
img = inputs[key]
backend = get_file_backend(img)
if hasattr(backend, 'isdir') and isdir(img):
# Backends like HttpsBackend do not implement `isdir`, so
# only those backends that implement `isdir` could accept
# the inputs as a directory
filename_list = list_dir_or_file(img, list_dir=False)
inputs = [{
f'{key}': join_path(img, filename)
} for filename in filename_list]
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
return list(inputs)
def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int: def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
"""Returns the index of the transform in a pipeline. """Returns the index of the transform in a pipeline.
@@ -173,64 +144,81 @@ class Base3DInferencer(BaseInferencer):
visualizer.dataset_meta = self.model.dataset_meta visualizer.dataset_meta = self.model.dataset_meta
return visualizer return visualizer
def _dispatch_kwargs(self,
out_dir: str = '',
cam_type: str = '',
**kwargs) -> Tuple[Dict, Dict, Dict, Dict]:
"""Dispatch kwargs to preprocess(), forward(), visualize() and
postprocess() according to the actual demands.
Args:
out_dir (str): Dir to save the inference results.
cam_type (str): Camera type. Defaults to ''.
**kwargs (dict): Key words arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.
Returns:
Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess,
forward, visualize and postprocess respectively.
"""
kwargs['img_out_dir'] = out_dir
kwargs['pred_out_dir'] = out_dir
if cam_type != '':
kwargs['cam_type_dir'] = cam_type
return super()._dispatch_kwargs(**kwargs)
def __call__(self, def __call__(self,
inputs: InputsType, inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1, batch_size: int = 1,
return_vis: bool = False, return_datasamples: bool = False,
show: bool = False, **kwargs) -> Optional[dict]:
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '',
print_result: bool = False,
pred_out_file: str = '',
**kwargs) -> dict:
"""Call the inferencer. """Call the inferencer.
Args: Args:
inputs (InputsType): Inputs for the inferencer. inputs (InputsType): Inputs for the inferencer.
batch_size (int): Batch size. Defaults to 1.
return_datasamples (bool): Whether to return results as return_datasamples (bool): Whether to return results as
:obj:`BaseDataElement`. Defaults to False. :obj:`BaseDataElement`. Defaults to False.
batch_size (int): Inference batch size. Defaults to 1. **kwargs: Key words arguments passed to :meth:`preprocess`,
return_vis (bool): Whether to return the visualization result.
Defaults to False.
show (bool): Whether to display the visualization results in a
popup window. Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`. :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``. and ``postprocess_kwargs``.
Returns: Returns:
dict: Inference and visualization results. dict: Inference and visualization results.
""" """
return super().__call__(
inputs, (
return_datasamples, preprocess_kwargs,
batch_size, forward_kwargs,
return_vis=return_vis, visualize_kwargs,
show=show, postprocess_kwargs,
wait_time=wait_time, ) = self._dispatch_kwargs(**kwargs)
draw_pred=draw_pred,
pred_score_thr=pred_score_thr, cam_type = preprocess_kwargs.pop('cam_type', 'CAM2')
img_out_dir=img_out_dir, ori_inputs = self._inputs_to_list(inputs, cam_type=cam_type)
print_result=print_result, inputs = self.preprocess(
pred_out_file=pred_out_file, ori_inputs, batch_size=batch_size, **preprocess_kwargs)
**kwargs) preds = []
results_dict = {'predictions': [], 'visualization': []}
for data in (track(inputs, description='Inference')
if self.show_progress else inputs):
preds.extend(self.forward(data, **forward_kwargs))
visualization = self.visualize(ori_inputs, preds,
**visualize_kwargs)
results = self.postprocess(preds, visualization,
return_datasamples,
**postprocess_kwargs)
results_dict['predictions'].extend(results['predictions'])
if results['visualization'] is not None:
results_dict['visualization'].extend(results['visualization'])
return results_dict
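With the dispatching above, a single call can carry arguments for every stage. A sketch of a mixed call (assuming `LidarDet3DInferencer` from `mmdet3d.apis` and a local KITTI sample; not a snippet from this diff):

```python
from mmdet3d.apis import LidarDet3DInferencer

inferencer = LidarDet3DInferencer('pointpillars_kitti-3class')
results = inferencer(
    dict(points='demo/data/kitti/000008.bin'),
    show=False,           # visualize kwarg
    pred_score_thr=0.5,   # visualize kwarg
    out_dir='outputs',    # split into img_out_dir and pred_out_dir
    print_result=True)    # postprocess kwarg

# Predictions come back as plain dicts unless return_datasamples=True.
print(results['predictions'][0]['labels_3d'])
```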
def postprocess( def postprocess(
self, self,
@@ -238,7 +226,8 @@ class Base3DInferencer(BaseInferencer):
visualization: Optional[List[np.ndarray]] = None, visualization: Optional[List[np.ndarray]] = None,
return_datasample: bool = False, return_datasample: bool = False,
print_result: bool = False, print_result: bool = False,
pred_out_file: str = '', no_save_pred: bool = False,
pred_out_dir: str = '',
) -> Union[ResType, Tuple[ResType, np.ndarray]]: ) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Process the predictions and visualization results from ``forward`` """Process the predictions and visualization results from ``forward``
and ``visualize``. and ``visualize``.
@@ -258,7 +247,7 @@ class Base3DInferencer(BaseInferencer):
Defaults to False. Defaults to False.
print_result (bool): Whether to print the inference result w/o print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False. visualization to the console. Defaults to False.
pred_out_file (str): File to save the inference results w/o pred_out_dir (str): Directory to save the inference results w/o
visualization. If left as empty, no file will be saved. visualization. If left as empty, no file will be saved.
Defaults to ''. Defaults to ''.
@@ -273,35 +262,56 @@ class Base3DInferencer(BaseInferencer):
json-serializable dict containing only basic data elements such json-serializable dict containing only basic data elements such
as strings and numbers. as strings and numbers.
""" """
if no_save_pred is True:
pred_out_dir = ''
result_dict = {} result_dict = {}
results = preds results = preds
if not return_datasample: if not return_datasample:
results = [] results = []
for pred in preds: for pred in preds:
result = self.pred2dict(pred) result = self.pred2dict(pred, pred_out_dir)
results.append(result) results.append(result)
elif pred_out_dir != '':
print_log(
'Currently does not support saving datasample '
'when return_datasample is set to True. '
'Prediction results are not saved!',
level=logging.WARNING)
# Add img to the results after printing and dumping
result_dict['predictions'] = results result_dict['predictions'] = results
if print_result: if print_result:
print(result_dict) print(result_dict)
if pred_out_file != '':
mmengine.dump(result_dict, pred_out_file)
result_dict['visualization'] = visualization result_dict['visualization'] = visualization
return result_dict return result_dict
def pred2dict(self, data_sample: InstanceData) -> Dict: # TODO: The data format and fields saved in json need further discussion.
# Maybe should include model name, timestamp, filename, image info etc.
def pred2dict(self,
data_sample: Det3DDataSample,
pred_out_dir: str = '') -> Dict:
"""Extract elements necessary to represent a prediction into a """Extract elements necessary to represent a prediction into a
dictionary. dictionary.
It's better to contain only basic data elements such as strings and It's better to contain only basic data elements such as strings and
numbers in order to guarantee it's json-serializable. numbers in order to guarantee it's json-serializable.
Args:
data_sample (:obj:`Det3DDataSample`): Predictions of the model.
pred_out_dir (str): Directory to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
Returns:
dict: Prediction results.
""" """
result = {} result = {}
if 'pred_instances_3d' in data_sample: if 'pred_instances_3d' in data_sample:
pred_instances_3d = data_sample.pred_instances_3d.numpy() pred_instances_3d = data_sample.pred_instances_3d.numpy()
result = { result = {
'bboxes_3d': pred_instances_3d.bboxes_3d.tensor.cpu().tolist(),
'labels_3d': pred_instances_3d.labels_3d.tolist(), 'labels_3d': pred_instances_3d.labels_3d.tolist(),
'scores_3d': pred_instances_3d.scores_3d.tolist() 'scores_3d': pred_instances_3d.scores_3d.tolist(),
'bboxes_3d': pred_instances_3d.bboxes_3d.tensor.cpu().tolist()
} }
if 'pred_pts_seg' in data_sample: if 'pred_pts_seg' in data_sample:
@@ -309,4 +319,28 @@ class Base3DInferencer(BaseInferencer):
result['pts_semantic_mask'] = \ result['pts_semantic_mask'] = \
pred_pts_seg.pts_semantic_mask.tolist() pred_pts_seg.pts_semantic_mask.tolist()
if data_sample.box_mode_3d == Box3DMode.LIDAR:
result['box_type_3d'] = 'LiDAR'
elif data_sample.box_mode_3d == Box3DMode.CAM:
result['box_type_3d'] = 'Camera'
elif data_sample.box_mode_3d == Box3DMode.DEPTH:
result['box_type_3d'] = 'Depth'
if pred_out_dir != '':
if 'lidar_path' in data_sample:
lidar_path = osp.basename(data_sample.lidar_path)
lidar_path = osp.splitext(lidar_path)[0]
out_json_path = osp.join(pred_out_dir, 'preds',
lidar_path + '.json')
elif 'img_path' in data_sample:
img_path = osp.basename(data_sample.img_path)
img_path = osp.splitext(img_path)[0]
out_json_path = osp.join(pred_out_dir, 'preds',
img_path + '.json')
else:
out_json_path = osp.join(
pred_out_dir, 'preds',
f'{str(self.num_visualized_imgs).zfill(8)}.json')
dump(result, out_json_path)
return result return result
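For a LiDAR detector, the dictionary written to `<out_dir>/preds/<frame>.json` therefore looks roughly like the example below (the numbers are made up; LiDAR boxes are stored as 7 values per box):

```python
# Illustrative only: the shape of the dict pred2dict emits for a LiDAR sample.
example_result = {
    'labels_3d': [0, 2],
    'scores_3d': [0.91, 0.47],
    'bboxes_3d': [
        # x, y, z, dx, dy, dz, yaw
        [6.3, -3.9, -1.6, 3.9, 1.6, 1.5, -1.6],
        [25.0, 1.2, -1.4, 4.1, 1.7, 1.5, 1.5],
    ],
    'box_type_3d': 'LiDAR',
}
```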
@@ -4,11 +4,16 @@ from typing import Dict, List, Optional, Sequence, Union
import mmengine import mmengine
import numpy as np import numpy as np
import torch
from mmengine.dataset import Compose from mmengine.dataset import Compose
from mmengine.fileio import (get_file_backend, isdir, join_path,
list_dir_or_file)
from mmengine.infer.infer import ModelType from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from mmdet3d.registry import INFERENCERS from mmdet3d.registry import INFERENCERS
from mmdet3d.structures import (CameraInstance3DBoxes, DepthInstance3DBoxes,
Det3DDataSample, LiDARInstance3DBoxes)
from mmdet3d.utils import ConfigType from mmdet3d.utils import ConfigType
from .base_3d_inferencer import Base3DInferencer from .base_3d_inferencer import Base3DInferencer
@@ -43,16 +48,6 @@ class LidarDet3DInferencer(Base3DInferencer):
priority is palette -> config -> checkpoint. Defaults to 'none'. priority is palette -> config -> checkpoint. Defaults to 'none'.
""" """
preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir'
}
postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample'
}
def __init__(self, def __init__(self,
model: Union[ModelType, str, None] = None, model: Union[ModelType, str, None] = None,
weights: Optional[str] = None, weights: Optional[str] = None,
@@ -69,7 +64,7 @@ class LidarDet3DInferencer(Base3DInferencer):
scope=scope, scope=scope,
palette=palette) palette=palette)
def _inputs_to_list(self, inputs: Union[dict, list]) -> list: def _inputs_to_list(self, inputs: Union[dict, list], **kwargs) -> list:
"""Preprocess the inputs to a list. """Preprocess the inputs to a list.
Preprocess inputs to a list according to its type: Preprocess inputs to a list according to its type:
@@ -87,7 +82,22 @@ class LidarDet3DInferencer(Base3DInferencer):
Returns: Returns:
list: List of input for the :meth:`preprocess`. list: List of input for the :meth:`preprocess`.
""" """
return super()._inputs_to_list(inputs, modality_key='points') if isinstance(inputs, dict) and isinstance(inputs['points'], str):
pcd = inputs['points']
backend = get_file_backend(pcd)
if hasattr(backend, 'isdir') and isdir(pcd):
# Backends like HttpsBackend do not implement `isdir`, so
# only those backends that implement `isdir` could accept
# the inputs as a directory
filename_list = list_dir_or_file(pcd, list_dir=False)
inputs = [{
'points': join_path(pcd, filename)
} for filename in filename_list]
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
return list(inputs)
def _init_pipeline(self, cfg: ConfigType) -> Compose: def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline.""" """Initialize the test pipeline."""
@@ -113,9 +123,10 @@ class LidarDet3DInferencer(Base3DInferencer):
preds: PredType, preds: PredType,
return_vis: bool = False, return_vis: bool = False,
show: bool = False, show: bool = False,
wait_time: int = 0, wait_time: int = -1,
draw_pred: bool = True, draw_pred: bool = True,
pred_score_thr: float = 0.3, pred_score_thr: float = 0.3,
no_save_vis: bool = False,
img_out_dir: str = '') -> Union[List[np.ndarray], None]: img_out_dir: str = '') -> Union[List[np.ndarray], None]:
"""Visualize predictions. """Visualize predictions.
@@ -126,11 +137,13 @@ class LidarDet3DInferencer(Base3DInferencer):
Defaults to False. Defaults to False.
show (bool): Whether to display the image in a popup window. show (bool): Whether to display the image in a popup window.
Defaults to False. Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0. wait_time (float): The interval of show (s). Defaults to -1.
draw_pred (bool): Whether to draw predicted bounding boxes. draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True. Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw. pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3. Defaults to 0.3.
no_save_vis (bool): Whether to force not to save prediction
vis results. Defaults to False.
img_out_dir (str): Output directory of visualization results. img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''. If left as empty, no file will be saved. Defaults to ''.
@@ -138,8 +151,10 @@ class LidarDet3DInferencer(Base3DInferencer):
List[np.ndarray] or None: Returns visualization results only if List[np.ndarray] or None: Returns visualization results only if
applicable. applicable.
""" """
if self.visualizer is None or (not show and img_out_dir == '' if no_save_vis is True:
and not return_vis): img_out_dir = ''
if not show and img_out_dir == '' and not return_vis:
return None return None
if getattr(self, 'visualizer') is None: if getattr(self, 'visualizer') is None:
@@ -160,13 +175,16 @@ class LidarDet3DInferencer(Base3DInferencer):
elif isinstance(single_input, np.ndarray): elif isinstance(single_input, np.ndarray):
points = single_input.copy() points = single_input.copy()
pc_num = str(self.num_visualized_frames).zfill(8) pc_num = str(self.num_visualized_frames).zfill(8)
pc_name = f'pc_{pc_num}.png' pc_name = f'{pc_num}.png'
else: else:
raise ValueError('Unsupported input type: ' raise ValueError('Unsupported input type: '
f'{type(single_input)}') f'{type(single_input)}')
o3d_save_path = osp.join(img_out_dir, pc_name) \ if img_out_dir != '' and show:
if img_out_dir != '' else None o3d_save_path = osp.join(img_out_dir, 'vis_lidar', pc_name)
mmengine.mkdir_or_exist(osp.dirname(o3d_save_path))
else:
o3d_save_path = None
data_input = dict(points=points) data_input = dict(points=points)
self.visualizer.add_datasample( self.visualizer.add_datasample(
@@ -185,3 +203,40 @@ class LidarDet3DInferencer(Base3DInferencer):
self.num_visualized_frames += 1 self.num_visualized_frames += 1
return results return results
def visualize_preds_fromfile(self, inputs: InputsType, preds: PredType,
**kwargs) -> Union[List[np.ndarray], None]:
"""Visualize predictions from `*.json` files.
Args:
inputs (InputsType): Inputs for the inferencer.
preds (PredType): Predictions of the model.
Returns:
List[np.ndarray] or None: Returns visualization results only if
applicable.
"""
data_samples = []
for pred in preds:
pred = mmengine.load(pred)
data_sample = Det3DDataSample()
data_sample.pred_instances_3d = InstanceData()
data_sample.pred_instances_3d.labels_3d = torch.tensor(
pred['labels_3d'])
data_sample.pred_instances_3d.scores_3d = torch.tensor(
pred['scores_3d'])
if pred['box_type_3d'] == 'LiDAR':
data_sample.pred_instances_3d.bboxes_3d = \
LiDARInstance3DBoxes(pred['bboxes_3d'])
elif pred['box_type_3d'] == 'Camera':
data_sample.pred_instances_3d.bboxes_3d = \
CameraInstance3DBoxes(pred['bboxes_3d'])
elif pred['box_type_3d'] == 'Depth':
data_sample.pred_instances_3d.bboxes_3d = \
DepthInstance3DBoxes(pred['bboxes_3d'])
else:
raise ValueError('Unsupported box type: '
f'{pred["box_type_3d"]}')
data_samples.append(data_sample)
return self.visualize(inputs=inputs, preds=data_samples, **kwargs)
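A usage sketch matching the notebook's last cell, re-drawing predictions that were saved on another machine (paths follow the notebook's `local_outputs` layout):

```python
from mmdet3d.apis import LidarDet3DInferencer

inferencer = LidarDet3DInferencer('pointpillars_kitti-3class')
inputs = inferencer._inputs_to_list(dict(points='./data/kitti/000008.bin'))
inferencer.visualize_preds_fromfile(
    inputs, ['local_outputs/preds/000008.json'], show=True)
```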
@@ -5,6 +5,8 @@ from typing import Dict, List, Optional, Sequence, Union
import mmengine import mmengine
import numpy as np import numpy as np
from mmengine.dataset import Compose from mmengine.dataset import Compose
from mmengine.fileio import (get_file_backend, isdir, join_path,
list_dir_or_file)
from mmengine.infer.infer import ModelType from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
@@ -43,16 +45,6 @@ class LidarSeg3DInferencer(Base3DInferencer):
priority is palette -> config -> checkpoint. Defaults to 'none'. priority is palette -> config -> checkpoint. Defaults to 'none'.
""" """
preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir'
}
postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample'
}
def __init__(self, def __init__(self,
model: Union[ModelType, str, None] = None, model: Union[ModelType, str, None] = None,
weights: Optional[str] = None, weights: Optional[str] = None,
@@ -69,7 +61,7 @@ class LidarSeg3DInferencer(Base3DInferencer):
scope=scope, scope=scope,
palette=palette) palette=palette)
def _inputs_to_list(self, inputs: Union[dict, list]) -> list: def _inputs_to_list(self, inputs: Union[dict, list], **kwargs) -> list:
"""Preprocess the inputs to a list. """Preprocess the inputs to a list.
Preprocess inputs to a list according to its type: Preprocess inputs to a list according to its type:
@@ -87,7 +79,22 @@ class LidarSeg3DInferencer(Base3DInferencer):
Returns: Returns:
list: List of input for the :meth:`preprocess`. list: List of input for the :meth:`preprocess`.
""" """
return super()._inputs_to_list(inputs, modality_key='points') if isinstance(inputs, dict) and isinstance(inputs['points'], str):
pcd = inputs['points']
backend = get_file_backend(pcd)
if hasattr(backend, 'isdir') and isdir(pcd):
# Backends like HttpsBackend do not implement `isdir`, so
# only those backends that implement `isdir` could accept
# the inputs as a directory
filename_list = list_dir_or_file(pcd, list_dir=False)
inputs = [{
'points': join_path(pcd, filename)
} for filename in filename_list]
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
return list(inputs)
def _init_pipeline(self, cfg: ConfigType) -> Compose: def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline.""" """Initialize the test pipeline."""
@@ -124,6 +131,7 @@ class LidarSeg3DInferencer(Base3DInferencer):
wait_time: int = 0, wait_time: int = 0,
draw_pred: bool = True, draw_pred: bool = True,
pred_score_thr: float = 0.3, pred_score_thr: float = 0.3,
no_save_vis: bool = False,
img_out_dir: str = '') -> Union[List[np.ndarray], None]: img_out_dir: str = '') -> Union[List[np.ndarray], None]:
"""Visualize predictions. """Visualize predictions.
@@ -139,6 +147,7 @@ class LidarSeg3DInferencer(Base3DInferencer):
Defaults to True. Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw. pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3. Defaults to 0.3.
no_save_vis (bool): Whether to force not to save visualization results. Defaults to False.
img_out_dir (str): Output directory of visualization results. img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''. If left as empty, no file will be saved. Defaults to ''.
@@ -146,8 +155,10 @@ class LidarSeg3DInferencer(Base3DInferencer):
List[np.ndarray] or None: Returns visualization results only if List[np.ndarray] or None: Returns visualization results only if
applicable. applicable.
""" """
if self.visualizer is None or (not show and img_out_dir == '' if no_save_vis is True:
and not return_vis): img_out_dir = ''
if not show and img_out_dir == '' and not return_vis:
return None return None
if getattr(self, 'visualizer') is None: if getattr(self, 'visualizer') is None:
@@ -168,13 +179,16 @@ class LidarSeg3DInferencer(Base3DInferencer):
elif isinstance(single_input, np.ndarray): elif isinstance(single_input, np.ndarray):
points = single_input.copy() points = single_input.copy()
pc_num = str(self.num_visualized_frames).zfill(8) pc_num = str(self.num_visualized_frames).zfill(8)
pc_name = f'pc_{pc_num}.png' pc_name = f'{pc_num}.png'
else: else:
raise ValueError('Unsupported input type: ' raise ValueError('Unsupported input type: '
f'{type(single_input)}') f'{type(single_input)}')
o3d_save_path = osp.join(img_out_dir, pc_name) \ if img_out_dir != '' and show:
if img_out_dir != '' else None o3d_save_path = osp.join(img_out_dir, 'vis_lidar', pc_name)
mmengine.mkdir_or_exist(osp.dirname(o3d_save_path))
else:
o3d_save_path = None
data_input = dict(points=points) data_input = dict(points=points)
self.visualizer.add_datasample( self.visualizer.add_datasample(
......
@@ -6,6 +6,8 @@ import mmcv
import mmengine import mmengine
import numpy as np import numpy as np
from mmengine.dataset import Compose from mmengine.dataset import Compose
from mmengine.fileio import (get_file_backend, isdir, join_path,
list_dir_or_file)
from mmengine.infer.infer import ModelType from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
@@ -44,16 +46,6 @@ class MonoDet3DInferencer(Base3DInferencer):
priority is palette -> config -> checkpoint. Defaults to 'none'. priority is palette -> config -> checkpoint. Defaults to 'none'.
""" """
preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir'
}
postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample'
}
def __init__(self, def __init__(self,
model: Union[ModelType, str, None] = None, model: Union[ModelType, str, None] = None,
weights: Optional[str] = None, weights: Optional[str] = None,
@@ -70,7 +62,10 @@ class MonoDet3DInferencer(Base3DInferencer):
scope=scope, scope=scope,
palette=palette) palette=palette)
def _inputs_to_list(self, inputs: Union[dict, list]) -> list: def _inputs_to_list(self,
inputs: Union[dict, list],
cam_type='CAM2',
**kwargs) -> list:
"""Preprocess the inputs to a list. """Preprocess the inputs to a list.
Preprocess inputs to a list according to its type: Preprocess inputs to a list according to its type:
@@ -88,7 +83,79 @@ class MonoDet3DInferencer(Base3DInferencer):
Returns: Returns:
list: List of input for the :meth:`preprocess`. list: List of input for the :meth:`preprocess`.
""" """
return super()._inputs_to_list(inputs, modality_key='img') if isinstance(inputs, dict):
assert 'infos' in inputs
infos = inputs.pop('infos')
if isinstance(inputs['img'], str):
img = inputs['img']
backend = get_file_backend(img)
if hasattr(backend, 'isdir') and isdir(img):
# Backends like HttpsBackend do not implement `isdir`, so
# only those backends that implement `isdir` could accept
# the inputs as a directory
filename_list = list_dir_or_file(img, list_dir=False)
inputs = [{
'img': join_path(img, filename)
} for filename in filename_list]
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
# get cam2img, lidar2cam and lidar2img from infos
info_list = mmengine.load(infos)['data_list']
assert len(info_list) == len(inputs)
for index, input in enumerate(inputs):
data_info = info_list[index]
img_path = data_info['images'][cam_type]['img_path']
if isinstance(input['img'], str) and \
osp.basename(img_path) != osp.basename(input['img']):
raise ValueError(
f'the info file of {img_path} is not provided.')
cam2img = np.asarray(
data_info['images'][cam_type]['cam2img'], dtype=np.float32)
lidar2cam = np.asarray(
data_info['images'][cam_type]['lidar2cam'],
dtype=np.float32)
if 'lidar2img' in data_info['images'][cam_type]:
lidar2img = np.asarray(
data_info['images'][cam_type]['lidar2img'],
dtype=np.float32)
else:
lidar2img = cam2img @ lidar2cam
input['cam2img'] = cam2img
input['lidar2cam'] = lidar2cam
input['lidar2img'] = lidar2img
elif isinstance(inputs, (list, tuple)):
# get cam2img, lidar2cam and lidar2img from infos
for input in inputs:
assert 'infos' in input
infos = input.pop('infos')
info_list = mmengine.load(infos)['data_list']
assert len(info_list) == 1, 'Only support single sample info ' \
'in `.pkl` when inputs is a list.'
data_info = info_list[0]
img_path = data_info['images'][cam_type]['img_path']
if isinstance(input['img'], str) and \
osp.basename(img_path) != osp.basename(input['img']):
raise ValueError(
f'the info file for {img_path} is not provided.')
cam2img = np.asarray(
data_info['images'][cam_type]['cam2img'], dtype=np.float32)
lidar2cam = np.asarray(
data_info['images'][cam_type]['lidar2cam'],
dtype=np.float32)
if 'lidar2img' in data_info['images'][cam_type]:
lidar2img = np.asarray(
data_info['images'][cam_type]['lidar2img'],
dtype=np.float32)
else:
lidar2img = cam2img @ lidar2cam
input['cam2img'] = cam2img
input['lidar2cam'] = lidar2cam
input['lidar2img'] = lidar2img
return list(inputs)
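With the reworked `_inputs_to_list`, a monocular demo is driven by an image (or image directory) plus the dataset info `.pkl`; the projection matrices come from the info file rather than a separate `calib` input. A minimal usage sketch, assuming a KITTI-style info file; the model alias and file paths below are illustrative, not part of this change:

```python
from mmdet3d.apis import MonoDet3DInferencer

# Hypothetical model alias; any mono-3D config/checkpoint pair works the same way.
inferencer = MonoDet3DInferencer('pgd_kitti')

# 'img' may be a single file or a directory; 'infos' is the dataset info .pkl
# that provides cam2img / lidar2cam (lidar2img is derived as cam2img @ lidar2cam
# when it is not stored).
inputs = dict(
    img='demo/data/kitti/000008.png',
    infos='demo/data/kitti/000008_infos.pkl')

inferencer(inputs, show=False, out_dir='./mono_outputs')
```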
def _init_pipeline(self, cfg: ConfigType) -> Compose: def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline.""" """Initialize the test pipeline."""
...@@ -110,7 +177,9 @@ class MonoDet3DInferencer(Base3DInferencer): ...@@ -110,7 +177,9 @@ class MonoDet3DInferencer(Base3DInferencer):
wait_time: int = 0, wait_time: int = 0,
draw_pred: bool = True, draw_pred: bool = True,
pred_score_thr: float = 0.3, pred_score_thr: float = 0.3,
img_out_dir: str = '') -> Union[List[np.ndarray], None]: no_save_vis: bool = False,
img_out_dir: str = '',
cam_type_dir: str = 'CAM2') -> Union[List[np.ndarray], None]:
"""Visualize predictions. """Visualize predictions.
Args: Args:
...@@ -125,15 +194,19 @@ class MonoDet3DInferencer(Base3DInferencer): ...@@ -125,15 +194,19 @@ class MonoDet3DInferencer(Base3DInferencer):
Defaults to True. Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw. pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3. Defaults to 0.3.
no_save_vis (bool): Whether not to save visualization results.
Defaults to False.
img_out_dir (str): Output directory of visualization results. img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''. If left as empty, no file will be saved. Defaults to ''.
cam_type_dir (str): Camera type directory. Defaults to 'CAM2'.
Returns: Returns:
List[np.ndarray] or None: Returns visualization results only if List[np.ndarray] or None: Returns visualization results only if
applicable. applicable.
""" """
if self.visualizer is None or (not show and img_out_dir == '' if no_save_vis is True:
and not return_vis): img_out_dir = ''
if not show and img_out_dir == '' and not return_vis:
return None return None
if getattr(self, 'visualizer') is None: if getattr(self, 'visualizer') is None:
...@@ -156,8 +229,8 @@ class MonoDet3DInferencer(Base3DInferencer): ...@@ -156,8 +229,8 @@ class MonoDet3DInferencer(Base3DInferencer):
raise ValueError('Unsupported input type: ' raise ValueError('Unsupported input type: '
f"{type(single_input['img'])}") f"{type(single_input['img'])}")
out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \ out_file = osp.join(img_out_dir, 'vis_camera', cam_type_dir,
else None img_name) if img_out_dir != '' else None
data_input = dict(img=img) data_input = dict(img=img)
self.visualizer.add_datasample( self.visualizer.add_datasample(
......
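For completeness, a hedged sketch of the new visualization switches (`no_save_vis`, `cam_type_dir`), assuming they are forwarded unchanged as call-time keyword arguments by the base inferencer:

```python
# Camera renderings are written under <out_dir>/vis_camera/<cam_type_dir>/;
# no_save_vis=True suppresses saving even when out_dir is given.
inferencer(inputs, out_dir='./mono_outputs', no_save_vis=False,
           cam_type_dir='CAM2', pred_score_thr=0.3)
```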
...@@ -7,6 +7,8 @@ import mmcv ...@@ -7,6 +7,8 @@ import mmcv
import mmengine import mmengine
import numpy as np import numpy as np
from mmengine.dataset import Compose from mmengine.dataset import Compose
from mmengine.fileio import (get_file_backend, isdir, join_path,
list_dir_or_file)
from mmengine.infer.infer import ModelType from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
...@@ -44,16 +46,6 @@ class MultiModalityDet3DInferencer(Base3DInferencer): ...@@ -44,16 +46,6 @@ class MultiModalityDet3DInferencer(Base3DInferencer):
palette (str): The palette of visualization. Defaults to 'none'. palette (str): The palette of visualization. Defaults to 'none'.
""" """
preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir'
}
postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample'
}
def __init__(self, def __init__(self,
model: Union[ModelType, str, None] = None, model: Union[ModelType, str, None] = None,
weights: Optional[str] = None, weights: Optional[str] = None,
...@@ -70,7 +62,10 @@ class MultiModalityDet3DInferencer(Base3DInferencer): ...@@ -70,7 +62,10 @@ class MultiModalityDet3DInferencer(Base3DInferencer):
scope=scope, scope=scope,
palette=palette) palette=palette)
def _inputs_to_list(self, inputs: Union[dict, list]) -> list: def _inputs_to_list(self,
inputs: Union[dict, list],
cam_type: str = 'CAM2',
**kwargs) -> list:
"""Preprocess the inputs to a list. """Preprocess the inputs to a list.
Preprocess inputs to a list according to its type: Preprocess inputs to a list according to its type:
...@@ -88,7 +83,86 @@ class MultiModalityDet3DInferencer(Base3DInferencer): ...@@ -88,7 +83,86 @@ class MultiModalityDet3DInferencer(Base3DInferencer):
Returns: Returns:
list: List of input for the :meth:`preprocess`. list: List of input for the :meth:`preprocess`.
""" """
return super()._inputs_to_list(inputs, modality_key=['points', 'img']) if isinstance(inputs, dict):
assert 'infos' in inputs
infos = inputs.pop('infos')
if isinstance(inputs['img'], str):
img, pcd = inputs['img'], inputs['points']
backend = get_file_backend(img)
if hasattr(backend, 'isdir') and isdir(img) and isdir(pcd):
# Backends like HttpsBackend do not implement `isdir`, so
# only those backends that implement `isdir` could accept
# the inputs as a directory
img_filename_list = list_dir_or_file(
img, list_dir=False, suffix=['.png', '.jpg'])
pcd_filename_list = list_dir_or_file(
pcd, list_dir=False, suffix='.bin')
assert len(img_filename_list) == len(pcd_filename_list)
inputs = [{
'img': join_path(img, img_filename),
'points': join_path(pcd, pcd_filename)
} for pcd_filename, img_filename in zip(
pcd_filename_list, img_filename_list)]
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
# get cam2img, lidar2cam and lidar2img from infos
info_list = mmengine.load(infos)['data_list']
assert len(info_list) == len(inputs)
for index, input in enumerate(inputs):
data_info = info_list[index]
img_path = data_info['images'][cam_type]['img_path']
if isinstance(input['img'], str) and \
osp.basename(img_path) != osp.basename(input['img']):
raise ValueError(
f'the info file for {img_path} is not provided.')
cam2img = np.asarray(
data_info['images'][cam_type]['cam2img'], dtype=np.float32)
lidar2cam = np.asarray(
data_info['images'][cam_type]['lidar2cam'],
dtype=np.float32)
if 'lidar2img' in data_info['images'][cam_type]:
lidar2img = np.asarray(
data_info['images'][cam_type]['lidar2img'],
dtype=np.float32)
else:
lidar2img = cam2img @ lidar2cam
input['cam2img'] = cam2img
input['lidar2cam'] = lidar2cam
input['lidar2img'] = lidar2img
elif isinstance(inputs, (list, tuple)):
# get cam2img, lidar2cam and lidar2img from infos
for input in inputs:
assert 'infos' in input
infos = input.pop('infos')
info_list = mmengine.load(infos)['data_list']
assert len(info_list) == 1, 'Only support single sample ' \
'info in `.pkl` when input is a list.'
data_info = info_list[0]
img_path = data_info['images'][cam_type]['img_path']
if isinstance(input['img'], str) and \
osp.basename(img_path) != osp.basename(input['img']):
raise ValueError(
f'the info file for {img_path} is not provided.')
cam2img = np.asarray(
data_info['images'][cam_type]['cam2img'], dtype=np.float32)
lidar2cam = np.asarray(
data_info['images'][cam_type]['lidar2cam'],
dtype=np.float32)
if 'lidar2img' in data_info['images'][cam_type]:
lidar2img = np.asarray(
data_info['images'][cam_type]['lidar2img'],
dtype=np.float32)
else:
lidar2img = cam2img @ lidar2cam
input['cam2img'] = cam2img
input['lidar2cam'] = lidar2cam
input['lidar2img'] = lidar2img
return list(inputs)
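The multi-modality inferencer follows the same contract: a paired point cloud and image (files or matching directories) plus the info `.pkl`. A minimal sketch under the same assumptions as above (placeholder alias and paths):

```python
from mmdet3d.apis import MultiModalityDet3DInferencer

# Hypothetical alias; substitute any LiDAR+camera config/checkpoint pair.
inferencer = MultiModalityDet3DInferencer('mvxnet_kitti-3class')

inputs = dict(
    points='demo/data/kitti/000008.bin',
    img='demo/data/kitti/000008.png',
    infos='demo/data/kitti/000008_infos.pkl')

inferencer(inputs, show=False, out_dir='./mm_outputs')
```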
def _init_pipeline(self, cfg: ConfigType) -> Compose: def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline.""" """Initialize the test pipeline."""
...@@ -144,7 +218,9 @@ class MultiModalityDet3DInferencer(Base3DInferencer): ...@@ -144,7 +218,9 @@ class MultiModalityDet3DInferencer(Base3DInferencer):
wait_time: int = 0, wait_time: int = 0,
draw_pred: bool = True, draw_pred: bool = True,
pred_score_thr: float = 0.3, pred_score_thr: float = 0.3,
img_out_dir: str = '') -> Union[List[np.ndarray], None]: no_save_vis: bool = False,
img_out_dir: str = '',
cam_type_dir: str = 'CAM2') -> Union[List[np.ndarray], None]:
"""Visualize predictions. """Visualize predictions.
Args: Args:
...@@ -157,6 +233,7 @@ class MultiModalityDet3DInferencer(Base3DInferencer): ...@@ -157,6 +233,7 @@ class MultiModalityDet3DInferencer(Base3DInferencer):
wait_time (float): The interval of show (s). Defaults to 0. wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes. draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True. Defaults to True.
no_save_vis (bool): Whether not to save visualization results.
Defaults to False.
pred_score_thr (float): Minimum score of bboxes to draw. pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3. Defaults to 0.3.
img_out_dir (str): Output directory of visualization results. img_out_dir (str): Output directory of visualization results.
...@@ -166,8 +243,10 @@ class MultiModalityDet3DInferencer(Base3DInferencer): ...@@ -166,8 +243,10 @@ class MultiModalityDet3DInferencer(Base3DInferencer):
List[np.ndarray] or None: Returns visualization results only if List[np.ndarray] or None: Returns visualization results only if
applicable. applicable.
""" """
if self.visualizer is None or (not show and img_out_dir == '' if no_save_vis is True:
and not return_vis): img_out_dir = ''
if not show and img_out_dir == '' and not return_vis:
return None return None
if getattr(self, 'visualizer') is None: if getattr(self, 'visualizer') is None:
...@@ -188,13 +267,16 @@ class MultiModalityDet3DInferencer(Base3DInferencer): ...@@ -188,13 +267,16 @@ class MultiModalityDet3DInferencer(Base3DInferencer):
elif isinstance(points_input, np.ndarray): elif isinstance(points_input, np.ndarray):
points = points_input.copy() points = points_input.copy()
pc_num = str(self.num_visualized_frames).zfill(8) pc_num = str(self.num_visualized_frames).zfill(8)
pc_name = f'pc_{pc_num}.png' pc_name = f'{pc_num}.png'
else: else:
raise ValueError('Unsupported input type: ' raise ValueError('Unsupported input type: '
f'{type(points_input)}') f'{type(points_input)}')
o3d_save_path = osp.join(img_out_dir, pc_name) \ if img_out_dir != '' and show:
if img_out_dir != '' else None o3d_save_path = osp.join(img_out_dir, 'vis_lidar', pc_name)
mmengine.mkdir_or_exist(osp.dirname(o3d_save_path))
else:
o3d_save_path = None
img_input = single_input['img'] img_input = single_input['img']
if isinstance(single_input['img'], str): if isinstance(single_input['img'], str):
...@@ -210,8 +292,8 @@ class MultiModalityDet3DInferencer(Base3DInferencer): ...@@ -210,8 +292,8 @@ class MultiModalityDet3DInferencer(Base3DInferencer):
raise ValueError('Unsupported input type: ' raise ValueError('Unsupported input type: '
f'{type(img_input)}') f'{type(img_input)}')
out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \ out_file = osp.join(img_out_dir, 'vis_camera', cam_type_dir,
else None img_name) if img_out_dir != '' else None
data_input = dict(points=points, img=img) data_input = dict(points=points, img=img)
self.visualizer.add_datasample( self.visualizer.add_datasample(
......
...@@ -1153,7 +1153,6 @@ class MonoDet3DInferencerLoader(BaseTransform): ...@@ -1153,7 +1153,6 @@ class MonoDet3DInferencerLoader(BaseTransform):
Added keys: Added keys:
- img - img
- cam2img
- box_type_3d - box_type_3d
- box_mode_3d - box_mode_3d
...@@ -1176,32 +1175,19 @@ class MonoDet3DInferencerLoader(BaseTransform): ...@@ -1176,32 +1175,19 @@ class MonoDet3DInferencerLoader(BaseTransform):
dict: The dict contains loaded image and meta information. dict: The dict contains loaded image and meta information.
""" """
box_type_3d, box_mode_3d = get_box_type('camera') box_type_3d, box_mode_3d = get_box_type('camera')
assert 'calib' in single_input and 'img' in single_input, \
"key 'calib' and 'img' must be in input dict"
if isinstance(single_input['calib'], str):
calib_path = single_input['calib']
with open(calib_path, 'r') as f:
lines = f.readlines()
cam2img = np.array([
float(info) for info in lines[0].split(' ')[0:16]
]).reshape([4, 4])
elif isinstance(single_input['calib'], np.ndarray):
cam2img = single_input['calib']
else:
raise ValueError('Unsupported input calib type: '
f"{type(single_input['calib'])}")
if isinstance(single_input['img'], str): if isinstance(single_input['img'], str):
inputs = dict( inputs = dict(
images=dict( images=dict(
CAM_FRONT=dict( CAM_FRONT=dict(
img_path=single_input['img'], cam2img=cam2img)), img_path=single_input['img'],
cam2img=single_input['cam2img'])),
box_mode_3d=box_mode_3d, box_mode_3d=box_mode_3d,
box_type_3d=box_type_3d) box_type_3d=box_type_3d)
elif isinstance(single_input['img'], np.ndarray): elif isinstance(single_input['img'], np.ndarray):
inputs = dict( inputs = dict(
img=single_input['img'], img=single_input['img'],
cam2img=cam2img, cam2img=single_input['cam2img'],
box_type_3d=box_type_3d, box_type_3d=box_type_3d,
box_mode_3d=box_mode_3d) box_mode_3d=box_mode_3d)
else: else:
...@@ -1252,9 +1238,9 @@ class MultiModalityDet3DInferencerLoader(BaseTransform): ...@@ -1252,9 +1238,9 @@ class MultiModalityDet3DInferencerLoader(BaseTransform):
dict: The dict contains loaded image, point cloud and meta dict: The dict contains loaded image, point cloud and meta
information. information.
""" """
assert 'points' in single_input and 'img' in single_input and \ assert 'points' in single_input and 'img' in single_input, \
'calib' in single_input, "key 'points', 'img' and 'calib' must be " "key 'points' and 'img' must be in input dict, " \
f'in input dict, but got {single_input}' f'but got {single_input}'
if isinstance(single_input['points'], str): if isinstance(single_input['points'], str):
inputs = dict( inputs = dict(
lidar_points=dict(lidar_path=single_input['points']), lidar_points=dict(lidar_path=single_input['points']),
...@@ -1283,36 +1269,21 @@ class MultiModalityDet3DInferencerLoader(BaseTransform): ...@@ -1283,36 +1269,21 @@ class MultiModalityDet3DInferencerLoader(BaseTransform):
multi_modality_inputs = points_inputs multi_modality_inputs = points_inputs
box_type_3d, box_mode_3d = get_box_type('lidar') box_type_3d, box_mode_3d = get_box_type('lidar')
if isinstance(single_input['calib'], str):
calib = mmengine.load(single_input['calib'])
elif isinstance(single_input['calib'], dict):
calib = single_input['calib']
else:
raise ValueError('Unsupported input calib type: '
f"{type(single_input['calib'])}")
cam2img = np.asarray(calib['cam2img'], dtype=np.float32)
lidar2cam = np.asarray(calib['lidar2cam'], dtype=np.float32)
if 'lidar2cam' in calib:
lidar2img = np.asarray(calib['lidar2img'], dtype=np.float32)
else:
lidar2img = cam2img @ lidar2cam
if isinstance(single_input['img'], str): if isinstance(single_input['img'], str):
inputs = dict( inputs = dict(
img_path=single_input['img'], img_path=single_input['img'],
cam2img=cam2img, cam2img=single_input['cam2img'],
lidar2img=lidar2img, lidar2img=single_input['lidar2img'],
lidar2cam=lidar2cam, lidar2cam=single_input['lidar2cam'],
box_mode_3d=box_mode_3d, box_mode_3d=box_mode_3d,
box_type_3d=box_type_3d) box_type_3d=box_type_3d)
elif isinstance(single_input['img'], np.ndarray): elif isinstance(single_input['img'], np.ndarray):
inputs = dict( inputs = dict(
img=single_input['img'], img=single_input['img'],
cam2img=cam2img, cam2img=single_input['cam2img'],
lidar2img=lidar2img, lidar2img=single_input['lidar2img'],
lidar2cam=lidar2cam, lidar2cam=single_input['lidar2cam'],
box_type_3d=box_type_3d, box_type_3d=box_type_3d,
box_mode_3d=box_mode_3d) box_mode_3d=box_mode_3d)
else: else:
......
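With the `calib` key gone, each sample that reaches these loaders already carries the projection matrices attached by `_inputs_to_list`. A shape-only sketch of one such sample (values are placeholders, not real calibration):

```python
import numpy as np

cam2img = np.eye(4, dtype=np.float32)    # 4x4 camera intrinsics (homogeneous)
lidar2cam = np.eye(4, dtype=np.float32)  # 4x4 LiDAR-to-camera extrinsics

single_input = dict(
    points='demo/data/kitti/000008.bin',
    img='demo/data/kitti/000008.png',
    cam2img=cam2img,
    lidar2cam=lidar2cam,
    # derived on the fly when the info file does not store lidar2img
    lidar2img=cam2img @ lidar2cam,
)
```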
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import math import math
import os
import sys import sys
import time import time
from typing import List, Optional, Sequence, Tuple, Union from typing import List, Optional, Sequence, Tuple, Union
...@@ -155,7 +156,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -155,7 +156,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if hasattr(self, 'pcd'): if hasattr(self, 'pcd'):
del self.pcd del self.pcd
def _initialize_o3d_vis(self) -> Visualizer: def _initialize_o3d_vis(self, show=True) -> Visualizer:
"""Initialize open3d vis according to frame_cfg. """Initialize open3d vis according to frame_cfg.
Args: Args:
...@@ -176,8 +177,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -176,8 +177,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
o3d_vis.register_key_action_callback(glfw_key_space, o3d_vis.register_key_action_callback(glfw_key_space,
self.space_action_callback) self.space_action_callback)
o3d_vis.register_key_callback(glfw_key_right, self.right_callback) o3d_vis.register_key_callback(glfw_key_right, self.right_callback)
o3d_vis.create_window() if os.environ.get('DISPLAY', None) is not None and show:
self.view_control = o3d_vis.get_view_control() o3d_vis.create_window()
self.view_control = o3d_vis.get_view_control()
return o3d_vis return o3d_vis
@master_only @master_only
...@@ -859,6 +861,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -859,6 +861,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.view_port) self.view_port)
self.flag_exit = not self.o3d_vis.poll_events() self.flag_exit = not self.o3d_vis.poll_events()
self.o3d_vis.update_renderer() self.o3d_vis.update_renderer()
# if not hasattr(self, 'view_control'):
# self.o3d_vis.create_window()
# self.view_control = self.o3d_vis.get_view_control()
self.view_port = \ self.view_port = \
self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501 self.view_control.convert_to_pinhole_camera_parameters() # noqa: E501
if wait_time != -1: if wait_time != -1:
...@@ -976,7 +981,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -976,7 +981,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
# For object detection datasets, no palette is saved # For object detection datasets, no palette is saved
palette = self.dataset_meta.get('palette', None) palette = self.dataset_meta.get('palette', None)
ignore_index = self.dataset_meta.get('ignore_index', None) ignore_index = self.dataset_meta.get('ignore_index', None)
if ignore_index is not None and 'gt_pts_seg' in data_sample and vis_task == 'lidar_seg': # noqa: E501 if vis_task == 'lidar_seg' and ignore_index is not None and 'pts_semantic_mask' in data_sample.gt_pts_seg: # noqa: E501
keep_index = data_sample.gt_pts_seg.pts_semantic_mask != ignore_index # noqa: E501 keep_index = data_sample.gt_pts_seg.pts_semantic_mask != ignore_index # noqa: E501
else: else:
keep_index = None keep_index = None
...@@ -986,6 +991,12 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -986,6 +991,12 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
gt_img_data = None gt_img_data = None
pred_img_data = None pred_img_data = None
if not hasattr(self, 'o3d_vis') and vis_task in [
'multi-view_det', 'lidar_det', 'lidar_seg',
'multi-modality_det'
]:
self.o3d_vis = self._initialize_o3d_vis(show=show)
if draw_gt and data_sample is not None: if draw_gt and data_sample is not None:
if 'gt_instances_3d' in data_sample: if 'gt_instances_3d' in data_sample:
gt_data_3d = self._draw_instances_3d( gt_data_3d = self._draw_instances_3d(
...@@ -1083,6 +1094,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -1083,6 +1094,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if drawn_img_3d is not None: if drawn_img_3d is not None:
mmcv.imwrite(drawn_img_3d[..., ::-1], out_file) mmcv.imwrite(drawn_img_3d[..., ::-1], out_file)
if drawn_img is not None: if drawn_img is not None:
mmcv.imwrite(drawn_img[..., ::-1], out_file) mmcv.imwrite(drawn_img[..., ::-1],
out_file[:-4] + '_2d' + out_file[-4:])
else: else:
self.add_image(name, drawn_img_3d, step) self.add_image(name, drawn_img_3d, step)
...@@ -34,7 +34,7 @@ python projects/BEVFusion/setup.py develop ...@@ -34,7 +34,7 @@ python projects/BEVFusion/setup.py develop
Run a demo on NuScenes data using [BEVFusion model](https://drive.google.com/file/d/1QkvbYDk4G2d6SZoeJqish13qSyXA4lp3/view?usp=share_link): Run a demo on NuScenes data using [BEVFusion model](https://drive.google.com/file/d/1QkvbYDk4G2d6SZoeJqish13qSyXA4lp3/view?usp=share_link):
```shell ```shell
python demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show python projects/BEVFusion/demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show
``` ```
### Training commands ### Training commands
......
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
import mmcv
from mmdet3d.apis import inference_multi_modality_detector, init_model
from mmdet3d.registry import VISUALIZERS
def parse_args():
parser = ArgumentParser()
parser.add_argument('pcd', help='Point cloud file')
parser.add_argument('img', help='image file')
parser.add_argument('ann', help='ann file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--cam-type',
type=str,
default='CAM_FRONT',
help='choose camera type to inference')
parser.add_argument(
'--score-thr', type=float, default=0.0, help='bbox score threshold')
parser.add_argument(
'--out-dir', type=str, default='demo', help='dir to save results')
parser.add_argument(
'--show',
action='store_true',
help='show online visualization results')
parser.add_argument(
'--snapshot',
action='store_true',
help='whether to save online visualization results')
args = parser.parse_args()
return args
def main(args):
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta
# test a single image and point cloud sample
result, data = inference_multi_modality_detector(model, args.pcd, args.img,
args.ann, args.cam_type)
points = data['inputs']['points']
if isinstance(result.img_path, list):
img = []
for img_path in result.img_path:
single_img = mmcv.imread(img_path)
single_img = mmcv.imconvert(single_img, 'bgr', 'rgb')
img.append(single_img)
else:
img = mmcv.imread(result.img_path)
img = mmcv.imconvert(img, 'bgr', 'rgb')
data_input = dict(points=points, img=img)
# show the results
visualizer.add_datasample(
'result',
data_input,
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=-1,
out_file=args.out_dir,
pred_score_thr=args.score_thr,
vis_task='multi-modality_det')
if __name__ == '__main__':
args = parse_args()
main(args)
...@@ -89,7 +89,7 @@ class TestLidarDet3DInferencer(TestCase): ...@@ -89,7 +89,7 @@ class TestLidarDet3DInferencer(TestCase):
inputs = dict(points='tests/data/kitti/training/velodyne/000000.bin'), inputs = dict(points='tests/data/kitti/training/velodyne/000000.bin'),
# img_out_dir # img_out_dir
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
self.inferencer(inputs, img_out_dir=tmp_dir) self.inferencer(inputs, out_dir=tmp_dir)
# TODO: For LiDAR-based detection, the saved image only exists when # TODO: For LiDAR-based detection, the saved image only exists when
# show=True. # show=True.
# self.assertTrue(osp.exists(osp.join(tmp_dir, '000000.png'))) # self.assertTrue(osp.exists(osp.join(tmp_dir, '000000.png')))
...@@ -102,11 +102,9 @@ class TestLidarDet3DInferencer(TestCase): ...@@ -102,11 +102,9 @@ class TestLidarDet3DInferencer(TestCase):
res = self.inferencer(inputs, return_datasamples=True) res = self.inferencer(inputs, return_datasamples=True)
self.assertTrue(is_list_of(res['predictions'], Det3DDataSample)) self.assertTrue(is_list_of(res['predictions'], Det3DDataSample))
# pred_out_file # pred_out_dir
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
pred_out_file = osp.join(tmp_dir, 'tmp.json') res = self.inferencer(inputs, print_result=True, out_dir=tmp_dir)
res = self.inferencer( dumped_res = mmengine.load(
inputs, print_result=True, pred_out_file=pred_out_file) osp.join(tmp_dir, 'preds', '000000.json'))
dumped_res = mmengine.load(pred_out_file) self.assertEqual(res['predictions'][0], dumped_res)
self.assert_predictions_equal(res['predictions'],
dumped_res['predictions'])
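The updated test reflects the new output layout: instead of one user-supplied JSON, predictions are dumped per sample under `<out_dir>/preds/`. A hedged sketch of reading one back, following the file names used in the test above:

```python
import os.path as osp

import mmengine

out_dir = './outputs'  # whatever was passed as out_dir to the inferencer
pred = mmengine.load(osp.join(out_dir, 'preds', '000000.json'))
# The dumped dict mirrors res['predictions'][0] returned by the call.
print(sorted(pred.keys()))
```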
...@@ -91,7 +91,7 @@ class TestLiDARSeg3DInferencer(TestCase): ...@@ -91,7 +91,7 @@ class TestLiDARSeg3DInferencer(TestCase):
inputs = dict(points='tests/data/s3dis/points/Area_1_office_2.bin') inputs = dict(points='tests/data/s3dis/points/Area_1_office_2.bin')
# img_out_dir # img_out_dir
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
self.inferencer(inputs, img_out_dir=tmp_dir) self.inferencer(inputs, out_dir=tmp_dir)
def test_post_processor(self): def test_post_processor(self):
if not torch.cuda.is_available(): if not torch.cuda.is_available():
...@@ -101,11 +101,9 @@ class TestLiDARSeg3DInferencer(TestCase): ...@@ -101,11 +101,9 @@ class TestLiDARSeg3DInferencer(TestCase):
res = self.inferencer(inputs, return_datasamples=True) res = self.inferencer(inputs, return_datasamples=True)
self.assertTrue(is_list_of(res['predictions'], Det3DDataSample)) self.assertTrue(is_list_of(res['predictions'], Det3DDataSample))
# pred_out_file # pred_out_dir
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
pred_out_file = osp.join(tmp_dir, 'tmp.json') res = self.inferencer(inputs, print_result=True, out_dir=tmp_dir)
res = self.inferencer( dumped_res = mmengine.load(
inputs, print_result=True, pred_out_file=pred_out_file) osp.join(tmp_dir, 'preds', 'Area_1_office_2.json'))
dumped_res = mmengine.load(pred_out_file) self.assertEqual(res['predictions'][0], dumped_res)
self.assert_predictions_equal(res['predictions'],
dumped_res['predictions'])