Commit ff2e15b0 authored by ChaimZhu

[Fix] Fix visualization bug in demo (#1668)

* fix vis

* add the KITTI and nuScenes judgement

* fix seg vis

* add docstring

* fix comments and ipynb bug
parent 2171463d
...@@ -2,6 +2,7 @@ voxel_size = [0.05, 0.05, 0.1] ...@@ -2,6 +2,7 @@ voxel_size = [0.05, 0.05, 0.1]
model = dict( model = dict(
type='VoxelNet', type='VoxelNet',
data_preprocessor=dict(type='Det3DDataPreprocessor'),
voxel_layer=dict( voxel_layer=dict(
max_num_points=5, max_num_points=5,
point_cloud_range=[0, -40, -3, 70.4, 40, 1], point_cloud_range=[0, -40, -3, 70.4, 40, 1],
...@@ -43,33 +44,35 @@ model = dict( ...@@ -43,33 +44,35 @@ model = dict(
diff_rad_by_sin=True, diff_rad_by_sin=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict( loss_cls=dict(
type='FocalLoss', type='mmdet.FocalLoss',
use_sigmoid=True, use_sigmoid=True,
gamma=2.0, gamma=2.0,
alpha=0.25, alpha=0.25,
loss_weight=1.0), loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), loss_bbox=dict(
type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict( loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)), type='mmdet.CrossEntropyLoss', use_sigmoid=False,
loss_weight=0.2)),
# model training and testing settings # model training and testing settings
train_cfg=dict( train_cfg=dict(
assigner=[ assigner=[
dict( # for Pedestrian dict( # for Pedestrian
type='MaxIoUAssigner', type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'), iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.35, pos_iou_thr=0.35,
neg_iou_thr=0.2, neg_iou_thr=0.2,
min_pos_iou=0.2, min_pos_iou=0.2,
ignore_iof_thr=-1), ignore_iof_thr=-1),
dict( # for Cyclist dict( # for Cyclist
type='MaxIoUAssigner', type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'), iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.35, pos_iou_thr=0.35,
neg_iou_thr=0.2, neg_iou_thr=0.2,
min_pos_iou=0.2, min_pos_iou=0.2,
ignore_iof_thr=-1), ignore_iof_thr=-1),
dict( # for Car dict( # for Car
type='MaxIoUAssigner', type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'), iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6, pos_iou_thr=0.6,
neg_iou_thr=0.45, neg_iou_thr=0.45,
......
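For reference, a minimal sketch (not a runnable config on its own) of how the renamed components are referenced after this change: 2D losses now carry the `mmdet.` scope prefix so they resolve against MMDetection's registry, while 3D components such as `Max3DIoUAssigner` keep resolving in mmdet3d's own scope. Values are taken from the Pedestrian entry shown above.

# sketch only: loss and assigner entries as they appear after the rename
loss_cfg = dict(
    loss_cls=dict(
        type='mmdet.FocalLoss',  # resolved from the MMDetection registry
        use_sigmoid=True,
        gamma=2.0,
        alpha=0.25,
        loss_weight=1.0),
    loss_bbox=dict(
        type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
    loss_dir=dict(
        type='mmdet.CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2))
assigner_cfg = dict(
    type='Max3DIoUAssigner',  # renamed from MaxIoUAssigner
    iou_calculator=dict(type='BboxOverlapsNearest3D'),
    pos_iou_thr=0.35,
    neg_iou_thr=0.2,
    min_pos_iou=0.2,
    ignore_iof_thr=-1)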
{"images": [{"file_name": "samples/CAM_BACK/n015-2018-07-24-11-22-45+0800__CAM_BACK__1532402927637525.jpg", "cam_intrinsic": [[809.2209905677063, 0.0, 829.2196003259838], [0.0, 809.2209905677063, 481.77842384512485], [0.0, 0.0, 1.0]]}]}
...@@ -2,17 +2,47 @@ ...@@ -2,17 +2,47 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 1,
"source": [ "source": [
"from mmdet3d.apis import init_model, inference_detector, show_result_meshlab" "from mmdet3d.apis import inference_detector, init_model\n",
"from mmdet3d.registry import VISUALIZERS\n",
"from mmdet3d.utils import register_all_modules"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/home/PJLAB/zhuchenming/mmdet3d_refactor/mmengine/mmengine/model/utils.py:800: UserWarning: Cannot import torch.fx, `merge_dict` is a simple function to merge multiple dicts\n",
" warnings.warn('Cannot import torch.fx, `merge_dict` is a simple function '\n"
]
}
], ],
"outputs": [],
"metadata": { "metadata": {
"pycharm": { "pycharm": {
"is_executing": false "is_executing": false
} }
} }
}, },
{
"cell_type": "code",
"execution_count": 2,
"source": [
"# register all modules in mmdet3d into the registries\n",
"register_all_modules()"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/home/PJLAB/zhuchenming/mmdet3d_refactor/mmdetection3d/mmdet3d/models/backbones/mink_resnet.py:10: UserWarning: Please follow `getting_started.md` to install MinkowskiEngine.`\n",
" 'Please follow `getting_started.md` to install MinkowskiEngine.`')\n"
]
}
],
"metadata": {}
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 8,
...@@ -35,7 +65,59 @@ ...@@ -35,7 +65,59 @@
"# build the model from a config file and a checkpoint file\n", "# build the model from a config file and a checkpoint file\n",
"model = init_model(config_file, checkpoint_file, device='cuda:0')" "model = init_model(config_file, checkpoint_file, device='cuda:0')"
], ],
"outputs": [], "outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/home/PJLAB/zhuchenming/mmdet3d_refactor/mmdetection3d/mmdet3d/models/dense_heads/anchor3d_head.py:93: UserWarning: dir_offset and dir_limit_offset will be depressed and be incorporated into box coder in the future\n",
" 'dir_offset and dir_limit_offset will be depressed and be '\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"local loads checkpoint from path: /home/PJLAB/zhuchenming/checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-3class_20210831_022017-ae782e87.pth\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"# init visualizer\n",
"visualizer = VISUALIZERS.build(model.cfg.visualizer)\n",
"visualizer.dataset_meta = {\n",
" 'CLASSES': model.CLASSES,\n",
" 'PALETTE': model.PALETTE\n",
"}"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/home/PJLAB/zhuchenming/mmdet3d_refactor/mmengine/mmengine/visualization/visualizer.py:167: UserWarning: `Visualizer` backend is not initialized because save_dir is None.\n",
" warnings.warn('`Visualizer` backend is not initialized '\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"446.4pt\" height=\"302.4pt\" viewBox=\"0 0 446.4 302.4\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2022-08-04T17:45:38.225868</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.5.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"axes_1\"/>\n </g>\n</svg>\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAb4AAAEuCAYAAADx63eqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAEiUlEQVR4nO3VMQEAIAzAMMC/5+ECjiYK+nXPzAKAivM7AABeMj4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIMT4AUowPgBTjAyDF+ABIueF8BVm9xhwpAAAAAElFTkSuQmCC"
},
"metadata": {
"needs_background": "light"
}
}
],
"metadata": { "metadata": {
"pycharm": { "pycharm": {
"is_executing": false "is_executing": false
...@@ -44,11 +126,13 @@ ...@@ -44,11 +126,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 11,
"source": [ "source": [
"# test a single sample\n", "# test a single sample\n",
"pcd = 'kitti_000008.bin'\n", "pcd = './data/kitti/000008.bin'\n",
"result, data = inference_detector(model, pcd)" "result, data = inference_detector(model, pcd)\n",
"points = data['inputs']['points']\n",
"data_input = dict(points=points)"
], ],
"outputs": [], "outputs": [],
"metadata": { "metadata": {
...@@ -59,13 +143,30 @@ ...@@ -59,13 +143,30 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 12,
"source": [ "source": [
"# show the results\n", "# show the results\n",
"out_dir = './'\n", "out_dir = './'\n",
"show_result_meshlab(data, result, out_dir)" "visualizer.add_datasample(\n",
" 'result',\n",
" data_input,\n",
" pred_sample=result,\n",
" show=True,\n",
" wait_time=0,\n",
" out_file=out_dir,\n",
" vis_task='det')"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1;33m[Open3D WARNING] invalid color in PaintUniformColor, clipping to [0, 1]\n",
"\u001b[0;m\u001b[1;33m[Open3D WARNING] invalid color in PaintUniformColor, clipping to [0, 1]\n",
"\u001b[0;m"
]
}
], ],
"outputs": [],
"metadata": { "metadata": {
"pycharm": { "pycharm": {
"is_executing": false "is_executing": false
...@@ -75,9 +176,8 @@ ...@@ -75,9 +176,8 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "name": "python3",
"language": "python", "display_name": "Python 3.7.6 64-bit ('torch1.7-cu10.1': conda)"
"name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
...@@ -89,7 +189,7 @@ ...@@ -89,7 +189,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.7" "version": "3.7.6"
}, },
"pycharm": { "pycharm": {
"stem_cell": { "stem_cell": {
...@@ -99,6 +199,9 @@ ...@@ -99,6 +199,9 @@
"collapsed": false "collapsed": false
} }
} }
},
"interpreter": {
"hash": "a0c343fece975dd89087e8c2194dd4d3db28d7000f1b32ed9ed9d584dd54dbbe"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
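Outside the notebook, the same refactored demo flow can be written as a short script. A sketch follows, with placeholder config, checkpoint and point-cloud paths (the notebook's own paths are local to the author's machine):

from mmdet3d.apis import inference_detector, init_model
from mmdet3d.registry import VISUALIZERS
from mmdet3d.utils import register_all_modules

# register all modules in mmdet3d into the registries
register_all_modules()

# placeholder paths: substitute your own config, checkpoint and point cloud
config_file = 'configs/second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py'
checkpoint_file = 'checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-3class.pth'
pcd = './data/kitti/000008.bin'

model = init_model(config_file, checkpoint_file, device='cuda:0')

# the refactored inference_detector returns (predictions, pipeline data)
result, data = inference_detector(model, pcd)
data_input = dict(points=data['inputs']['points'])

# build the visualizer from the model config and attach dataset meta info
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = {'CLASSES': model.CLASSES, 'PALETTE': model.PALETTE}
visualizer.add_datasample(
    'result',
    data_input,
    pred_sample=result,
    show=True,
    wait_time=0,
    out_file='./',
    vis_task='det')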
...@@ -10,14 +10,19 @@ from mmdet3d.utils import register_all_modules ...@@ -10,14 +10,19 @@ from mmdet3d.utils import register_all_modules
def parse_args(): def parse_args():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('image', help='image file') parser.add_argument('img', help='image file')
parser.add_argument('ann', help='ann file') parser.add_argument('ann', help='ann file')
parser.add_argument('config', help='Config file') parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file') parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument( parser.add_argument(
'--device', default='cuda:0', help='Device used for inference') '--device', default='cuda:0', help='Device used for inference')
parser.add_argument( parser.add_argument(
'--score-thr', type=float, default=0.15, help='bbox score threshold') '--cam-type',
type=str,
default='CAM_FRONT',
help='choose camera type to inference')
parser.add_argument(
'--score-thr', type=float, default=0.30, help='bbox score threshold')
parser.add_argument( parser.add_argument(
'--out-dir', type=str, default='demo', help='dir to save results') '--out-dir', type=str, default='demo', help='dir to save results')
parser.add_argument( parser.add_argument(
...@@ -33,7 +38,7 @@ def parse_args(): ...@@ -33,7 +38,7 @@ def parse_args():
def main(args): def main(args):
# register all modules in mmdet into the registries # register all modules in mmdet3d into the registries
register_all_modules() register_all_modules()
# build the model from a config file and a checkpoint file # build the model from a config file and a checkpoint file
...@@ -47,7 +52,8 @@ def main(args): ...@@ -47,7 +52,8 @@ def main(args):
} }
# test a single image # test a single image
result, data = inference_mono_3d_detector(model, args.image, args.ann) result = inference_mono_3d_detector(model, args.img, args.ann,
args.cam_type)
img = mmcv.imread(args.img) img = mmcv.imread(args.img)
img = mmcv.imconvert(img, 'bgr', 'rgb') img = mmcv.imconvert(img, 'bgr', 'rgb')
...@@ -60,9 +66,9 @@ def main(args): ...@@ -60,9 +66,9 @@ def main(args):
pred_sample=result, pred_sample=result,
show=True, show=True,
wait_time=0, wait_time=0,
out_file=args.out_file, out_file=args.out_dir,
pred_score_thr=args.score_thr, pred_score_thr=args.score_thr,
vis_task='multi_modality-det') vis_task='mono-det')
if __name__ == '__main__': if __name__ == '__main__':
......
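For the monocular demo, a hedged Python equivalent of what main() now does is sketched below with placeholder paths. The image basename must match the img_path recorded under the chosen camera in the annotation file, otherwise inference_mono_3d_detector raises a ValueError.

from mmdet3d.apis import inference_mono_3d_detector, init_model
from mmdet3d.utils import register_all_modules

register_all_modules()
# placeholder config/checkpoint for a monocular 3D detector
model = init_model('path/to/mono3d_config.py', 'path/to/checkpoint.pth',
                   device='cuda:0')
# pick the camera entry that matches the image; 'CAM_FRONT' is the default,
# e.g. 'CAM_BACK' for the bundled nuScenes sample image
result = inference_mono_3d_detector(
    model, 'path/to/image.jpg', 'path/to/infos.json', cam_type='CAM_BACK')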
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
from argparse import ArgumentParser from argparse import ArgumentParser
import mmcv import mmcv
import numpy as np
from mmdet3d.apis import inference_multi_modality_detector, init_model from mmdet3d.apis import inference_multi_modality_detector, init_model
from mmdet3d.registry import VISUALIZERS from mmdet3d.registry import VISUALIZERS
...@@ -12,12 +11,17 @@ from mmdet3d.utils import register_all_modules ...@@ -12,12 +11,17 @@ from mmdet3d.utils import register_all_modules
def parse_args(): def parse_args():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('pcd', help='Point cloud file') parser.add_argument('pcd', help='Point cloud file')
parser.add_argument('image', help='image file') parser.add_argument('img', help='image file')
parser.add_argument('ann', help='ann file') parser.add_argument('ann', help='ann file')
parser.add_argument('config', help='Config file') parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file') parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument( parser.add_argument(
'--device', default='cuda:0', help='Device used for inference') '--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--cam-type',
type=str,
default='CAM_FRONT',
help='choose camera type to inference')
parser.add_argument( parser.add_argument(
'--score-thr', type=float, default=0.0, help='bbox score threshold') '--score-thr', type=float, default=0.0, help='bbox score threshold')
parser.add_argument( parser.add_argument(
...@@ -35,7 +39,7 @@ def parse_args(): ...@@ -35,7 +39,7 @@ def parse_args():
def main(args): def main(args):
# register all modules in mmdet into the registries # register all modules in mmdet3d into the registries
register_all_modules() register_all_modules()
# build the model from a config file and a checkpoint file # build the model from a config file and a checkpoint file
...@@ -49,14 +53,13 @@ def main(args): ...@@ -49,14 +53,13 @@ def main(args):
} }
# test a single image and point cloud sample # test a single image and point cloud sample
result, data = inference_multi_modality_detector(model, args.pcd, result, data = inference_multi_modality_detector(model, args.pcd, args.img,
args.image, args.ann) args.ann, args.cam_type)
points = data['inputs']['points']
points = np.fromfile(args.pcd, dtype=np.float32)
img = mmcv.imread(args.img) img = mmcv.imread(args.img)
img = mmcv.imconvert(img, 'bgr', 'rgb') img = mmcv.imconvert(img, 'bgr', 'rgb')
data_input = dict(points=points, img=img) data_input = dict(points=points, img=img)
# show the results # show the results
visualizer.add_datasample( visualizer.add_datasample(
'result', 'result',
...@@ -64,7 +67,7 @@ def main(args): ...@@ -64,7 +67,7 @@ def main(args):
pred_sample=result, pred_sample=result,
show=True, show=True,
wait_time=0, wait_time=0,
out_file=args.out_file, out_file=args.out_dir,
pred_score_thr=args.score_thr, pred_score_thr=args.score_thr,
vis_task='multi_modality-det') vis_task='multi_modality-det')
......
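Likewise, a hedged sketch of the multi-modality path with the new cam_type argument (all paths are placeholders). cam_type selects which camera entry of the annotation file supplies the image path and the lidar2img/depth2img matrix.

from mmdet3d.apis import inference_multi_modality_detector, init_model
from mmdet3d.utils import register_all_modules

register_all_modules()
# placeholder paths for a multi-modality (points + image) detector
model = init_model('path/to/multi_modality_config.py', 'path/to/checkpoint.pth')
result, data = inference_multi_modality_detector(
    model,
    'path/to/points.bin',
    'path/to/image.png',
    'path/to/infos.pkl',   # annotation file with per-camera calibration
    cam_type='CAM_FRONT')  # for KITTI-style infos, use the camera key of that file
points = data['inputs']['points']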
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser from argparse import ArgumentParser
import numpy as np
from mmdet3d.apis import inference_segmentor, init_model from mmdet3d.apis import inference_segmentor, init_model
from mmdet3d.registry import VISUALIZERS from mmdet3d.registry import VISUALIZERS
from mmdet3d.utils import register_all_modules from mmdet3d.utils import register_all_modules
...@@ -30,7 +28,7 @@ def parse_args(): ...@@ -30,7 +28,7 @@ def parse_args():
def main(args): def main(args):
# register all modules in mmdet into the registries # register all modules in mmdet3d into the registries
register_all_modules() register_all_modules()
# build the model from a config file and a checkpoint file # build the model from a config file and a checkpoint file
...@@ -45,8 +43,7 @@ def main(args): ...@@ -45,8 +43,7 @@ def main(args):
# test a single point cloud sample # test a single point cloud sample
result, data = inference_segmentor(model, args.pcd) result, data = inference_segmentor(model, args.pcd)
points = data['inputs']['points']
points = np.fromfile(args.pcd, dtype=np.float32)
data_input = dict(points=points) data_input = dict(points=points)
# show the results # show the results
visualizer.add_datasample( visualizer.add_datasample(
...@@ -55,8 +52,7 @@ def main(args): ...@@ -55,8 +52,7 @@ def main(args):
pred_sample=result, pred_sample=result,
show=True, show=True,
wait_time=0, wait_time=0,
out_file=args.out_file, out_file=args.out_dir,
pred_score_thr=args.score_thr,
vis_task='seg') vis_task='seg')
......
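The segmentation demo follows the same pattern; a sketch with placeholder paths is shown below (dataset_meta attribute names taken from the notebook cell). Note that the seg task no longer passes pred_score_thr, matching the updated script.

from mmdet3d.apis import inference_segmentor, init_model
from mmdet3d.registry import VISUALIZERS
from mmdet3d.utils import register_all_modules

register_all_modules()
# placeholder config/checkpoint for a 3D segmentor
model = init_model('path/to/seg_config.py', 'path/to/seg_checkpoint.pth')
result, data = inference_segmentor(model, 'path/to/scene_points.bin')
data_input = dict(points=data['inputs']['points'])

visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = {'CLASSES': model.CLASSES, 'PALETTE': model.PALETTE}
visualizer.add_datasample(
    'result',
    data_input,
    pred_sample=result,
    show=True,
    wait_time=0,
    out_file='demo',
    vis_task='seg')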
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser from argparse import ArgumentParser
import numpy as np
from mmdet3d.apis import inference_detector, init_model from mmdet3d.apis import inference_detector, init_model
from mmdet3d.registry import VISUALIZERS from mmdet3d.registry import VISUALIZERS
from mmdet3d.utils import register_all_modules from mmdet3d.utils import register_all_modules
...@@ -32,7 +30,7 @@ def parse_args(): ...@@ -32,7 +30,7 @@ def parse_args():
def main(args): def main(args):
# register all modules in mmdet into the registries # register all modules in mmdet3d into the registries
register_all_modules() register_all_modules()
# TODO: Support inference of point cloud numpy file. # TODO: Support inference of point cloud numpy file.
...@@ -43,14 +41,13 @@ def main(args): ...@@ -43,14 +41,13 @@ def main(args):
visualizer = VISUALIZERS.build(model.cfg.visualizer) visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = { visualizer.dataset_meta = {
'CLASSES': model.CLASSES, 'CLASSES': model.CLASSES,
'PALETTE': model.palette
} }
# test a single point cloud sample # test a single point cloud sample
result, data = inference_detector(model, args.pcd) result, data = inference_detector(model, args.pcd)
points = data['inputs']['points']
points = np.fromfile(args.pcd, dtype=np.float32)
data_input = dict(points=points) data_input = dict(points=points)
# show the results # show the results
visualizer.add_datasample( visualizer.add_datasample(
'result', 'result',
...@@ -58,7 +55,7 @@ def main(args): ...@@ -58,7 +55,7 @@ def main(args):
pred_sample=result, pred_sample=result,
show=True, show=True,
wait_time=0, wait_time=0,
out_file=args.out_file, out_file=args.out_dir,
pred_score_thr=args.score_thr, pred_score_thr=args.score_thr,
vis_task='det') vis_task='det')
......
...@@ -5,10 +5,10 @@ from os import path as osp ...@@ -5,10 +5,10 @@ from os import path as osp
from typing import Sequence, Union from typing import Sequence, Union
import mmcv import mmcv
import mmengine
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmengine.config import Config
from mmengine.dataset import Compose from mmengine.dataset import Compose
from mmengine.runner import load_checkpoint from mmengine.runner import load_checkpoint
...@@ -48,11 +48,10 @@ def init_model(config, checkpoint=None, device='cuda:0'): ...@@ -48,11 +48,10 @@ def init_model(config, checkpoint=None, device='cuda:0'):
nn.Module: The constructed detector. nn.Module: The constructed detector.
""" """
if isinstance(config, str): if isinstance(config, str):
config = mmengine.Config.fromfile(config) config = Config.fromfile(config)
elif not isinstance(config, mmengine.Config): elif not isinstance(config, Config):
raise TypeError('config must be a filename or Config object, ' raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}') f'but got {type(config)}')
config.model.pretrained = None
convert_SyncBN(config.model) convert_SyncBN(config.model)
config.model.train_cfg = None config.model.train_cfg = None
model = MODELS.build(config.model) model = MODELS.build(config.model)
...@@ -110,8 +109,8 @@ def inference_detector(model: nn.Module, ...@@ -110,8 +109,8 @@ def inference_detector(model: nn.Module,
# build the data pipeline # build the data pipeline
test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline) test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
test_pipeline = Compose(test_pipeline) test_pipeline = Compose(test_pipeline)
# box_type_3d, box_mode_3d = get_box_type( box_type_3d, box_mode_3d = \
# cfg.test_dataloader.dataset.box_type_3d) get_box_type(cfg.test_dataloader.dataset.box_type_3d)
data = [] data = []
for pcd in pcds: for pcd in pcds:
...@@ -121,19 +120,17 @@ def inference_detector(model: nn.Module, ...@@ -121,19 +120,17 @@ def inference_detector(model: nn.Module,
data_ = dict( data_ = dict(
lidar_points=dict(lidar_path=pcd), lidar_points=dict(lidar_path=pcd),
# for ScanNet demo we need axis_align_matrix # for ScanNet demo we need axis_align_matrix
ann_info=dict(axis_align_matrix=np.eye(4)), axis_align_matrix=np.eye(4),
sweeps=[], box_type_3d=box_type_3d,
# set timestamp = 0 box_mode_3d=box_mode_3d)
timestamp=[0])
else: else:
# directly use loaded point cloud # directly use loaded point cloud
data_ = dict( data_ = dict(
points=pcd, points=pcd,
# for ScanNet demo we need axis_align_matrix # for ScanNet demo we need axis_align_matrix
ann_info=dict(axis_align_matrix=np.eye(4)), axis_align_matrix=np.eye(4),
sweeps=[], box_type_3d=box_type_3d,
# set timestamp = 0 box_mode_3d=box_mode_3d)
timestamp=[0])
data_ = test_pipeline(data_) data_ = test_pipeline(data_)
data.append(data_) data.append(data_)
...@@ -142,15 +139,16 @@ def inference_detector(model: nn.Module, ...@@ -142,15 +139,16 @@ def inference_detector(model: nn.Module,
results = model.test_step(data) results = model.test_step(data)
if not is_batch: if not is_batch:
return results[0] return results[0], data[0]
else: else:
return results return results, data
def inference_multi_modality_detector(model: nn.Module, def inference_multi_modality_detector(model: nn.Module,
pcds: Union[str, Sequence[str]], pcds: Union[str, Sequence[str]],
imgs: Union[str, Sequence[str]], imgs: Union[str, Sequence[str]],
ann_files: Union[str, Sequence[str]]): ann_file: Union[str, Sequence[str]],
cam_type: str = 'CAM_FRONT'):
"""Inference point cloud with the multi-modality detector. """Inference point cloud with the multi-modality detector.
Args: Args:
...@@ -159,7 +157,11 @@ def inference_multi_modality_detector(model: nn.Module, ...@@ -159,7 +157,11 @@ def inference_multi_modality_detector(model: nn.Module,
Either point cloud files or loaded point cloud. Either point cloud files or loaded point cloud.
imgs (str, Sequence[str]): imgs (str, Sequence[str]):
Either image files or loaded images. Either image files or loaded images.
ann_files (str, Sequence[str]): Annotation files. ann_file (str, Sequence[str]): Annotation files.
cam_type (str): Image of Camera chose to infer.
For kitti dataset, it should be 'CAM_2',
and for nuscenes dataset, it should be
'CAM_FRONT'. Defaults to 'CAM_FRONT'.
Returns: Returns:
:obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]: :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
...@@ -171,12 +173,10 @@ def inference_multi_modality_detector(model: nn.Module, ...@@ -171,12 +173,10 @@ def inference_multi_modality_detector(model: nn.Module,
if isinstance(pcds, (list, tuple)): if isinstance(pcds, (list, tuple)):
is_batch = True is_batch = True
assert isinstance(imgs, (list, tuple)) assert isinstance(imgs, (list, tuple))
assert isinstance(ann_files, (list, tuple)) assert len(pcds) == len(imgs)
assert len(pcds) == len(imgs) == len(ann_files)
else: else:
pcds = [pcds] pcds = [pcds]
imgs = [imgs] imgs = [imgs]
ann_files = [ann_files]
is_batch = False is_batch = False
cfg = model.cfg cfg = model.cfg
...@@ -187,44 +187,57 @@ def inference_multi_modality_detector(model: nn.Module, ...@@ -187,44 +187,57 @@ def inference_multi_modality_detector(model: nn.Module,
box_type_3d, box_mode_3d = \ box_type_3d, box_mode_3d = \
get_box_type(cfg.test_dataloader.dataset.box_type_3d) get_box_type(cfg.test_dataloader.dataset.box_type_3d)
data_list = mmcv.load(ann_file)['data_list']
assert len(imgs) == len(data_list)
data = [] data = []
for index, pcd in enumerate(pcds): for index, pcd in enumerate(pcds):
# get data info containing calib # get data info containing calib
img = imgs[index] img = imgs[index]
ann_file = ann_files[index] data_info = data_list[index]
data_info = mmcv.load(ann_file)[0] img_path = data_info['images'][cam_type]['img_path']
if osp.basename(img_path) != osp.basename(img):
raise ValueError(f'the info file of {img_path} is not provided.')
# TODO: check the name consistency of # TODO: check the name consistency of
# image file and point cloud file # image file and point cloud file
data_ = dict( data_ = dict(
lidar_points=dict(lidar_path=pcd), lidar_points=dict(lidar_path=pcd),
img_path=imgs[index], img_path=img,
img_prefix=osp.dirname(img),
img_info=dict(filename=osp.basename(img)),
box_type_3d=box_type_3d, box_type_3d=box_type_3d,
box_mode_3d=box_mode_3d) box_mode_3d=box_mode_3d)
data_ = test_pipeline(data_)
# LiDAR to image conversion for KITTI dataset # LiDAR to image conversion for KITTI dataset
if box_mode_3d == Box3DMode.LIDAR: if box_mode_3d == Box3DMode.LIDAR:
data_['lidar2img'] = data_info['images']['CAM2']['lidar2img'] data_['lidar2img'] = np.array(
data_info['images'][cam_type]['lidar2img'])
# Depth to image conversion for SUNRGBD dataset # Depth to image conversion for SUNRGBD dataset
elif box_mode_3d == Box3DMode.DEPTH: elif box_mode_3d == Box3DMode.DEPTH:
data_['depth2img'] = data_info['images']['CAM0']['depth2img'] data_['depth2img'] = np.array(
data_info['images'][cam_type]['depth2img'])
data_ = test_pipeline(data_)
data.append(data_) data.append(data_)
# forward the model # forward the model
with torch.no_grad(): with torch.no_grad():
results = model.test_step(data) results = model.test_step(data)
for index in range(len(data)):
meta_info = data[index]['data_sample'].metainfo
results[index].set_metainfo(meta_info)
if not is_batch: if not is_batch:
return results[0] return results[0], data[0]
else: else:
return results return results, data
def inference_mono_3d_detector(model: nn.Module, imgs: ImagesType, def inference_mono_3d_detector(model: nn.Module,
ann_files: Union[str, Sequence[str]]): imgs: ImagesType,
ann_file: Union[str, Sequence[str]],
cam_type: str = 'CAM_FRONT'):
"""Inference image with the monocular 3D detector. """Inference image with the monocular 3D detector.
Args: Args:
...@@ -232,6 +245,10 @@ def inference_mono_3d_detector(model: nn.Module, imgs: ImagesType, ...@@ -232,6 +245,10 @@ def inference_mono_3d_detector(model: nn.Module, imgs: ImagesType,
imgs (str, Sequence[str]): imgs (str, Sequence[str]):
Either image files or loaded images. Either image files or loaded images.
ann_files (str, Sequence[str]): Annotation files. ann_files (str, Sequence[str]): Annotation files.
cam_type (str): Image of Camera chose to infer.
For kitti dataset, it should be 'CAM_2',
and for nuscenes dataset, it should be
'CAM_FRONT'. Defaults to 'CAM_FRONT'.
Returns: Returns:
:obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]: :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
...@@ -252,23 +269,35 @@ def inference_mono_3d_detector(model: nn.Module, imgs: ImagesType, ...@@ -252,23 +269,35 @@ def inference_mono_3d_detector(model: nn.Module, imgs: ImagesType,
box_type_3d, box_mode_3d = \ box_type_3d, box_mode_3d = \
get_box_type(cfg.test_dataloader.dataset.box_type_3d) get_box_type(cfg.test_dataloader.dataset.box_type_3d)
data_list = mmcv.load(ann_file)
assert len(imgs) == len(data_list)
data = [] data = []
for index, img in enumerate(imgs): for index, img in enumerate(imgs):
ann_file = ann_files[index]
# get data info containing calib # get data info containing calib
data_info = mmcv.load(ann_file)[0] data_info = data_list[index]
img_path = data_info['images'][cam_type]['img_path']
if osp.basename(img_path) != osp.basename(img):
raise ValueError(f'the info file of {img_path} is not provided.')
# replace the img_path in data_info with img
data_info['images'][cam_type]['img_path'] = img
data_ = dict( data_ = dict(
img_path=img,
images=data_info['images'], images=data_info['images'],
box_type_3d=box_type_3d, box_type_3d=box_type_3d,
box_mode_3d=box_mode_3d) box_mode_3d=box_mode_3d)
data_ = test_pipeline(data_) data_ = test_pipeline(data_)
data.append(data_)
# forward the model # forward the model
with torch.no_grad(): with torch.no_grad():
results = model.test_step(data) results = model.test_step(data)
for index in range(len(data)):
meta_info = data[index]['data_sample'].metainfo
results[index].set_metainfo(meta_info)
if not is_batch: if not is_batch:
return results[0] return results[0]
else: else:
...@@ -301,6 +330,7 @@ def inference_segmentor(model: nn.Module, pcds: PointsType): ...@@ -301,6 +330,7 @@ def inference_segmentor(model: nn.Module, pcds: PointsType):
test_pipeline = Compose(test_pipeline) test_pipeline = Compose(test_pipeline)
data = [] data = []
# TODO: support load points array
for pcd in pcds: for pcd in pcds:
data_ = dict(lidar_points=dict(lidar_path=pcd)) data_ = dict(lidar_points=dict(lidar_path=pcd))
data_ = test_pipeline(data_) data_ = test_pipeline(data_)
...@@ -311,6 +341,6 @@ def inference_segmentor(model: nn.Module, pcds: PointsType): ...@@ -311,6 +341,6 @@ def inference_segmentor(model: nn.Module, pcds: PointsType):
results = model.test_step(data) results = model.test_step(data)
if not is_batch: if not is_batch:
return results[0] return results[0], data[0]
else: else:
return results return results, data
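A small sketch of the batched path through inference_detector (file names assumed, reusing a model built with init_model as above): passing a list of point clouds returns matching lists of predictions and pipeline data, mirroring the single-sample (results[0], data[0]) case.

pcd_files = ['path/to/000008.bin', 'path/to/000010.bin']  # placeholder files
results, data = inference_detector(model, pcd_files)
for res, sample in zip(results, data):
    points = sample['inputs']['points']
    # each (points, res) pair can be fed to the visualizer as in the demos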
...@@ -2,11 +2,10 @@ ...@@ -2,11 +2,10 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import List, Tuple from typing import List, Tuple
from mmengine.data import PixelData
from mmengine.model import BaseModel from mmengine.model import BaseModel
from torch import Tensor from torch import Tensor
from mmdet3d.structures import Det3DDataSample from mmdet3d.structures import Det3DDataSample, PointData
from mmdet3d.structures.det3d_data_sample import (ForwardResults, from mmdet3d.structures.det3d_data_sample import (ForwardResults,
OptSampleList, SampleList) OptSampleList, SampleList)
from mmdet3d.utils import OptConfigType, OptMultiConfig from mmdet3d.utils import OptConfigType, OptMultiConfig
...@@ -139,7 +138,7 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta): ...@@ -139,7 +138,7 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
"""Placeholder for augmentation test.""" """Placeholder for augmentation test."""
pass pass
def postprocess_result(self, seg_logits_list: List[dict], def postprocess_result(self, seg_pred_list: List[dict],
batch_img_metas: List[dict]) -> list: batch_img_metas: List[dict]) -> list:
""" Convert results list to `Det3DDataSample`. """ Convert results list to `Det3DDataSample`.
Args: Args:
...@@ -150,19 +149,16 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta): ...@@ -150,19 +149,16 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
list[:obj:`Det3DDataSample`]: Segmentation results of the list[:obj:`Det3DDataSample`]: Segmentation results of the
input images. Each Det3DDataSample usually contain: input images. Each Det3DDataSample usually contain:
- ``pred_pts_sem_seg``(PixelData): Prediction of 3D - ``pred_pts_seg``(PixelData): Prediction of 3D
semantic segmentation. semantic segmentation.
- ``seg_logits``(PixelData): Predicted logits of semantic
segmentation before normalization.
""" """
predictions = [] predictions = []
for i in range(len(seg_logits_list)): for i in range(len(seg_pred_list)):
img_meta = batch_img_metas[i] img_meta = batch_img_metas[i]
seg_logits = seg_logits_list[i][None], seg_pred = seg_pred_list[i]
seg_pred = seg_logits.argmax(dim=0, keepdim=True)
prediction = Det3DDataSample(**{'metainfo': img_meta}) prediction = Det3DDataSample(**{'metainfo': img_meta})
prediction.set_data( prediction.set_data(
{'pred_pts_sem_seg': PixelData(**{'data': seg_pred})}) {'pred_pts_seg': PointData(**{'pts_semantic_mask': seg_pred})})
predictions.append(prediction) predictions.append(prediction)
return predictions return predictions
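A minimal sketch of what postprocess_result now emits per sample: the predicted per-point labels are wrapped in PointData under pts_semantic_mask and attached to the Det3DDataSample as pred_pts_seg (names as in the diff; the tensor values below are made up).

import torch
from mmdet3d.structures import Det3DDataSample, PointData

seg_pred = torch.randint(0, 20, (4096,))  # assumed per-point class ids
prediction = Det3DDataSample(metainfo=dict(num_points=4096))
prediction.set_data(
    {'pred_pts_seg': PointData(**{'pts_semantic_mask': seg_pred})})
# accessed the same way the visualizer does in add_datasample
mask = prediction.pred_pts_seg.pts_semantic_mask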
...@@ -41,7 +41,7 @@ class EncoderDecoder3D(Base3DSegmentor): ...@@ -41,7 +41,7 @@ class EncoderDecoder3D(Base3DSegmentor):
.. code:: text .. code:: text
predict(): inference() -> postprocess_result() predict(): inference() -> postprocess_result()
infercen(): whole_inference()/slide_inference() inference(): whole_inference()/slide_inference()
whole_inference()/slide_inference(): encoder_decoder() whole_inference()/slide_inference(): encoder_decoder()
encoder_decoder(): extract_feat() -> decode_head.predict() encoder_decoder(): extract_feat() -> decode_head.predict()
...@@ -122,32 +122,26 @@ class EncoderDecoder3D(Base3DSegmentor): ...@@ -122,32 +122,26 @@ class EncoderDecoder3D(Base3DSegmentor):
else: else:
self.loss_regularization = MODELS.build(loss_regularization) self.loss_regularization = MODELS.build(loss_regularization)
def extract_feat(self, batch_inputs_dict: dict) -> List[Tensor]: def extract_feat(self, batch_inputs) -> List[Tensor]:
"""Extract features from points.""" """Extract features from points."""
points = batch_inputs_dict['points'] x = self.backbone(batch_inputs)
stack_points = torch.stack(points)
x = self.backbone(stack_points)
if self.with_neck: if self.with_neck:
x = self.neck(x) x = self.neck(x)
return x return x
def encode_decode(self, batch_inputs_dict: dict, def encode_decode(self, batch_inputs: torch.Tensor,
batch_input_metas: List[dict]) -> List[Tensor]: batch_input_metas: List[dict]) -> List[Tensor]:
"""Encode points with backbone and decode into a semantic segmentation """Encode points with backbone and decode into a semantic segmentation
map of the same size as input. map of the same size as input.
Args: Args:
batch_inputs_dict (dict): Input sample dict which batch_input (torch.Tensor): Input point cloud sample
includes 'points' and 'imgs' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor): Image tensor has shape (B, C, H, W).
batch_input_metas (list[dict]): Meta information of each sample. batch_input_metas (list[dict]): Meta information of each sample.
Returns: Returns:
torch.Tensor: Segmentation logits of shape [B, num_classes, N]. torch.Tensor: Segmentation logits of shape [B, num_classes, N].
""" """
x = self.extract_feat(batch_inputs_dict) x = self.extract_feat(batch_inputs)
seg_logits = self.decode_head.predict(x, batch_input_metas, seg_logits = self.decode_head.predict(x, batch_input_metas,
self.test_cfg) self.test_cfg)
return seg_logits return seg_logits
...@@ -481,7 +475,7 @@ class EncoderDecoder3D(Base3DSegmentor): ...@@ -481,7 +475,7 @@ class EncoderDecoder3D(Base3DSegmentor):
# 3D segmentation requires per-point prediction, so it's impossible # 3D segmentation requires per-point prediction, so it's impossible
# to use down-sampling to get a batch of scenes with same num_points # to use down-sampling to get a batch of scenes with same num_points
# therefore, we only support testing one scene every time # therefore, we only support testing one scene every time
seg_pred = [] seg_pred_list = []
batch_input_metas = [] batch_input_metas = []
for data_sample in batch_data_samples: for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo) batch_input_metas.append(data_sample.metainfo)
...@@ -493,10 +487,9 @@ class EncoderDecoder3D(Base3DSegmentor): ...@@ -493,10 +487,9 @@ class EncoderDecoder3D(Base3DSegmentor):
seg_map = seg_prob.argmax(0) # [N] seg_map = seg_prob.argmax(0) # [N]
# to cpu tensor for consistency with det3d # to cpu tensor for consistency with det3d
seg_map = seg_map.cpu() seg_map = seg_map.cpu()
seg_pred.append(seg_map) seg_pred_list.append(seg_map)
# warp in dict
seg_pred = [dict(semantic_mask=seg_map) for seg_map in seg_pred] return self.postprocess_result(seg_pred_list, batch_input_metas)
return seg_pred
def _forward(self, def _forward(self,
batch_inputs_dict: dict, batch_inputs_dict: dict,
...@@ -519,3 +512,7 @@ class EncoderDecoder3D(Base3DSegmentor): ...@@ -519,3 +512,7 @@ class EncoderDecoder3D(Base3DSegmentor):
""" """
x = self.extract_feat(batch_inputs_dict) x = self.extract_feat(batch_inputs_dict)
return self.decode_head.forward(x) return self.decode_head.forward(x)
def aug_test(self, batch_inputs, batch_img_metas):
"""Placeholder for augmentation test."""
pass
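With extract_feat now taking a tensor instead of the inputs dict, the per-sample point tensors are assumed to be stacked into one batch tensor before the segmentor is called; a sketch of that calling convention (shapes and values made up):

import torch

points_list = [torch.rand(4096, 6), torch.rand(4096, 6)]  # xyz + rgb, made up
batch_inputs = torch.stack(points_list)  # (B, N, C)
# feats = segmentor.extract_feat(batch_inputs)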
...@@ -21,8 +21,10 @@ from mmengine.data import InstanceData ...@@ -21,8 +21,10 @@ from mmengine.data import InstanceData
from mmengine.visualization.utils import check_type, tensor2ndarray from mmengine.visualization.utils import check_type, tensor2ndarray
from mmdet3d.registry import VISUALIZERS from mmdet3d.registry import VISUALIZERS
from mmdet3d.structures import (BaseInstance3DBoxes, DepthInstance3DBoxes, from mmdet3d.structures import (BaseInstance3DBoxes, CameraInstance3DBoxes,
Det3DDataSample, PointData) Coord3DMode, DepthInstance3DBoxes,
Det3DDataSample, LiDARInstance3DBoxes,
PointData)
from .vis_utils import (proj_camera_bbox3d_to_img, proj_depth_bbox3d_to_img, from .vis_utils import (proj_camera_bbox3d_to_img, proj_depth_bbox3d_to_img,
proj_lidar_bbox3d_to_img, to_depth_mode, write_obj, proj_lidar_bbox3d_to_img, to_depth_mode, write_obj,
write_oriented_bbox) write_oriented_bbox)
...@@ -106,6 +108,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -106,6 +108,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
line_width=line_width, line_width=line_width,
alpha=alpha) alpha=alpha)
self.o3d_vis = self._initialize_o3d_vis(vis_cfg) self.o3d_vis = self._initialize_o3d_vis(vis_cfg)
self.seg_num = 0
def _initialize_o3d_vis(self, vis_cfg) -> tuple: def _initialize_o3d_vis(self, vis_cfg) -> tuple:
"""Build open3d vis according to vis_cfg. """Build open3d vis according to vis_cfg.
...@@ -128,7 +131,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -128,7 +131,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
@master_only @master_only
def set_points(self, def set_points(self,
points: np.ndarray, points: np.ndarray,
vis_task: str, pcd_mode: int = 0,
vis_task: str = 'det',
points_color: Tuple = (0.5, 0.5, 0.5), points_color: Tuple = (0.5, 0.5, 0.5),
points_size: int = 2, points_size: int = 2,
mode: str = 'xyz') -> None: mode: str = 'xyz') -> None:
...@@ -137,6 +141,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -137,6 +141,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
Args: Args:
points (numpy.array, shape=[N, 3+C]): points (numpy.array, shape=[N, 3+C]):
points to visualize. points to visualize.
pcd_mode (int): The point cloud mode (coordinates):
0 represents LiDAR, 1 represents CAMERA, 2
represents Depth.
vis_task (str): Visualiztion task, it includes: vis_task (str): Visualiztion task, it includes:
'det', 'multi_modality-det', 'mono-det', 'seg'. 'det', 'multi_modality-det', 'mono-det', 'seg'.
point_color (tuple[float], optional): the color of points. point_color (tuple[float], optional): the color of points.
...@@ -149,7 +156,11 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -149,7 +156,11 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
assert points is not None assert points is not None
check_type('points', points, np.ndarray) check_type('points', points, np.ndarray)
if self.pcd and vis_task != 'seg': # for now we convert points into depth mode for visualization
if pcd_mode != Coord3DMode.DEPTH:
points = Coord3DMode.convert(points, pcd_mode, Coord3DMode.DEPTH)
if hasattr(self, 'pcd') and vis_task != 'seg':
self.o3d_vis.remove_geometry(self.pcd) self.o3d_vis.remove_geometry(self.pcd)
# set points size in Open3D # set points size in Open3D
...@@ -173,7 +184,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -173,7 +184,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pcd.colors = o3d.utility.Vector3dVector(points_colors) pcd.colors = o3d.utility.Vector3dVector(points_colors)
self.o3d_vis.add_geometry(pcd) self.o3d_vis.add_geometry(pcd)
self.pcd = pcd self.pcd = pcd
self.points_color = points_color self.points_colors = points_colors
# TODO: assign 3D Box color according to pred / GT labels # TODO: assign 3D Box color according to pred / GT labels
# We draw GT / pred bboxes on the same point cloud scenes # We draw GT / pred bboxes on the same point cloud scenes
...@@ -244,14 +255,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -244,14 +255,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.pcd.colors = o3d.utility.Vector3dVector(self.points_colors) self.pcd.colors = o3d.utility.Vector3dVector(self.points_colors)
self.o3d_vis.update_geometry(self.pcd) self.o3d_vis.update_geometry(self.pcd)
# TODO: set bbox color according to palette
def draw_proj_bboxes_3d(self, def draw_proj_bboxes_3d(self,
bboxes_3d: BaseInstance3DBoxes, bboxes_3d: BaseInstance3DBoxes,
input_meta: dict, input_meta: dict,
bbox_color: Tuple[float], bbox_color: Tuple[float] = 'b',
line_styles: Union[str, List[str]] = '-', line_styles: Union[str, List[str]] = '-',
line_widths: Union[Union[int, float], line_widths: Union[Union[int, float],
List[Union[int, float]]] = 2, List[Union[int, float]]] = 1):
box_mode: str = 'lidar'):
"""Draw projected 3D boxes on the image. """Draw projected 3D boxes on the image.
Args: Args:
...@@ -269,19 +280,18 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -269,19 +280,18 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
the same length with lines or just single value. the same length with lines or just single value.
If ``line_widths`` is single value, all the lines will If ``line_widths`` is single value, all the lines will
have the same linewidth. Defaults to 2. have the same linewidth. Defaults to 2.
box_mode (str): Indicate the coordinates of bbox.
""" """
check_type('bboxes', bboxes_3d, BaseInstance3DBoxes) check_type('bboxes', bboxes_3d, BaseInstance3DBoxes)
if box_mode == 'depth': if isinstance(bboxes_3d, DepthInstance3DBoxes):
proj_bbox3d_to_img = proj_depth_bbox3d_to_img proj_bbox3d_to_img = proj_depth_bbox3d_to_img
elif box_mode == 'lidar': elif isinstance(bboxes_3d, LiDARInstance3DBoxes):
proj_bbox3d_to_img = proj_lidar_bbox3d_to_img proj_bbox3d_to_img = proj_lidar_bbox3d_to_img
elif box_mode == 'camera': elif isinstance(bboxes_3d, CameraInstance3DBoxes):
proj_bbox3d_to_img = proj_camera_bbox3d_to_img proj_bbox3d_to_img = proj_camera_bbox3d_to_img
else: else:
raise NotImplementedError(f'unsupported box mode {box_mode}') raise NotImplementedError('unsupported box type!')
# (num_bboxes_3d, 8, 2) # (num_bboxes_3d, 8, 2)
proj_bboxes_3d = proj_bbox3d_to_img(bboxes_3d, input_meta) proj_bboxes_3d = proj_bbox3d_to_img(bboxes_3d, input_meta)
...@@ -304,7 +314,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -304,7 +314,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.draw_lines(x_datas, y_datas, bbox_color, line_styles, self.draw_lines(x_datas, y_datas, bbox_color, line_styles,
line_widths) line_widths)
def draw_seg_mask(self, seg_mask_colors: np.array, vis_task: str): def draw_seg_mask(self, seg_mask_colors: np.array):
"""Add segmentation mask to visualizer via per-point colorization. """Add segmentation mask to visualizer via per-point colorization.
Args: Args:
...@@ -323,7 +333,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -323,7 +333,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.o3d_vis.add_geometry(mesh_frame) self.o3d_vis.add_geometry(mesh_frame)
seg_points = copy.deepcopy(seg_mask_colors) seg_points = copy.deepcopy(seg_mask_colors)
seg_points[:, 0] += offset seg_points[:, 0] += offset
self.set_points(seg_points, vis_task, self.points_size, mode='xyzrgb') self.set_points(seg_points, vis_task='seg', pcd_mode=2, mode='xyzrgb')
def _draw_instances_3d(self, data_input: dict, instances: InstanceData, def _draw_instances_3d(self, data_input: dict, instances: InstanceData,
input_meta: dict, vis_task: str, input_meta: dict, vis_task: str,
...@@ -339,7 +349,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -339,7 +349,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
'det', 'multi_modality-det', 'mono-det'. 'det', 'multi_modality-det', 'mono-det'.
Returns: Returns:
np.ndarray: the drawn image which channel is RGB. dict: the drawn point cloud and image which channel is RGB.
""" """
bboxes_3d = instances.bboxes_3d # BaseInstance3DBoxes bboxes_3d = instances.bboxes_3d # BaseInstance3DBoxes
...@@ -359,7 +369,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -359,7 +369,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else: else:
bboxes_3d_depth = bboxes_3d.clone() bboxes_3d_depth = bboxes_3d.clone()
self.set_points(points, vis_task) self.set_points(points, pcd_mode=2, vis_task=vis_task)
self.draw_bboxes_3d(bboxes_3d_depth) self.draw_bboxes_3d(bboxes_3d_depth)
drawn_bboxes_3d = tensor2ndarray(bboxes_3d_depth.tensor) drawn_bboxes_3d = tensor2ndarray(bboxes_3d_depth.tensor)
...@@ -367,8 +377,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -367,8 +377,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if vis_task in ['mono-det', 'multi_modality-det']: if vis_task in ['mono-det', 'multi_modality-det']:
assert 'img' in data_input assert 'img' in data_input
img = data_input['img']
if isinstance(data_input['img'], Tensor): if isinstance(data_input['img'], Tensor):
img = data_input['img'].permute(1, 2, 0).numpy() img = img.permute(1, 2, 0).numpy()
img = img[..., [2, 1, 0]] # bgr to rgb img = img[..., [2, 1, 0]] # bgr to rgb
self.set_image(img) self.set_image(img)
self.draw_proj_bboxes_3d(bboxes_3d, input_meta) self.draw_proj_bboxes_3d(bboxes_3d, input_meta)
...@@ -379,16 +390,30 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -379,16 +390,30 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
return data_3d return data_3d
def _draw_pts_sem_seg(self, def _draw_pts_sem_seg(self,
points: Tensor, points: Union[Tensor, np.ndarray],
pts_seg: PointData, pts_seg: PointData,
vis_task: str,
palette: Optional[List[tuple]] = None, palette: Optional[List[tuple]] = None,
ignore_index: Optional[int] = None): ignore_index: Optional[int] = None):
"""Draw 3D semantic mask of GT or prediction.
Args:
points (Tensor | np.ndarray): The input point
cloud to draw.
pts_seg (:obj:`PointData`): Data structure for
pixel-level annotations or predictions.
palette (List[tuple], optional): Palette information
corresponding to the category. Defaults to None.
ignore_index (int, optional): Ignore category.
Defaults to None.
Returns:
dict: the drawn points with color.
"""
check_type('points', points, (np.ndarray, Tensor)) check_type('points', points, (np.ndarray, Tensor))
points = tensor2ndarray(points) points = tensor2ndarray(points)
pts_sem_seg = tensor2ndarray(pts_seg.pts_semantic_mask) pts_sem_seg = tensor2ndarray(pts_seg.pts_semantic_mask)
palette = np.array(palette)
if ignore_index is not None: if ignore_index is not None:
points = points[pts_sem_seg != ignore_index] points = points[pts_sem_seg != ignore_index]
...@@ -397,8 +422,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -397,8 +422,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pts_color = palette[pts_sem_seg] pts_color = palette[pts_sem_seg]
seg_color = np.concatenate([points[:, :3], pts_color], axis=1) seg_color = np.concatenate([points[:, :3], pts_color], axis=1)
self.set_points(points, vis_task) self.set_points(points, pcd_mode=2, vis_task='seg')
self.draw_seg_mask(seg_color, vis_task) self.draw_seg_mask(seg_color)
seg_data_3d = dict(points=points, seg_color=seg_color) seg_data_3d = dict(points=points, seg_color=seg_color)
return seg_data_3d return seg_data_3d
...@@ -416,7 +441,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -416,7 +441,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
Args: Args:
vis_task (str): Visualiztion task, it includes: vis_task (str): Visualiztion task, it includes:
'det', 'multi_modality-det', 'mono-det'. 'det', 'multi_modality-det', 'mono-det', 'seg'.
out_file (str): Output file path. out_file (str): Output file path.
drawn_img (np.ndarray, optional): The image to show. If drawn_img drawn_img (np.ndarray, optional): The image to show. If drawn_img
is None, it will show the image got by Visualizer. Defaults is None, it will show the image got by Visualizer. Defaults
...@@ -427,10 +452,10 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -427,10 +452,10 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
continue_key (str): The key for users to continue. Defaults to continue_key (str): The key for users to continue. Defaults to
the space key. the space key.
""" """
if vis_task in ['det', 'multi_modality-det']: if vis_task in ['det', 'multi_modality-det', 'seg']:
self.o3d_vis.run() self.o3d_vis.run()
if out_file is not None: if out_file is not None:
self.o3d_vis.capture_screen_image(out_file) self.o3d_vis.capture_screen_image(out_file + '.png')
self.o3d_vis.destroy_window() self.o3d_vis.destroy_window()
if vis_task in ['mono-det', 'multi_modality-det']: if vis_task in ['mono-det', 'multi_modality-det']:
...@@ -439,6 +464,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -439,6 +464,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if drawn_img is not None: if drawn_img is not None:
super().show(drawn_img, win_name, wait_time, continue_key) super().show(drawn_img, win_name, wait_time, continue_key)
# TODO: Support Visualize the 3D results from image and point cloud
# respectively
@master_only @master_only
def add_datasample(self, def add_datasample(self,
name: str, name: str,
...@@ -490,6 +517,13 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -490,6 +517,13 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
palette = self.dataset_meta.get('PALETTE', None) palette = self.dataset_meta.get('PALETTE', None)
ignore_index = self.dataset_meta.get('ignore_index', None) ignore_index = self.dataset_meta.get('ignore_index', None)
gt_data_3d = None
pred_data_3d = None
gt_seg_data_3d = None
pred_seg_data_3d = None
gt_img_data = None
pred_img_data = None
if draw_gt and gt_sample is not None: if draw_gt and gt_sample is not None:
if 'gt_instances_3d' in gt_sample: if 'gt_instances_3d' in gt_sample:
gt_data_3d = self._draw_instances_3d(data_input, gt_data_3d = self._draw_instances_3d(data_input,
...@@ -503,7 +537,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -503,7 +537,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
img = img[..., [2, 1, 0]] # bgr to rgb img = img[..., [2, 1, 0]] # bgr to rgb
gt_img_data = self._draw_instances(img, gt_sample.gt_instances, gt_img_data = self._draw_instances(img, gt_sample.gt_instances,
classes, palette) classes, palette)
if 'gt_pts_sem_seg' in gt_sample: if 'gt_pts_seg' in gt_sample:
assert classes is not None, 'class information is ' \ assert classes is not None, 'class information is ' \
'not provided when ' \ 'not provided when ' \
'visualizing panoptic ' \ 'visualizing panoptic ' \
...@@ -511,30 +545,31 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -511,30 +545,31 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
assert 'points' in data_input assert 'points' in data_input
gt_seg_data_3d = \ gt_seg_data_3d = \
self._draw_pts_sem_seg(data_input['points'], self._draw_pts_sem_seg(data_input['points'],
gt_sample.gt_pts_seg, pred_sample.pred_pts_seg,
classes, vis_task, palette, palette, ignore_index)
out_file, ignore_index)
if draw_pred and pred_sample is not None: if draw_pred and pred_sample is not None:
if 'pred_instances_3d' in pred_sample: if 'pred_instances_3d' in pred_sample:
pred_instances_3d = pred_sample.pred_instances_3d pred_instances_3d = pred_sample.pred_instances_3d
# .cpu can not be used for BaseInstancesBoxes3D
# so we need to use .to('cpu')
pred_instances_3d = pred_instances_3d[ pred_instances_3d = pred_instances_3d[
pred_instances_3d.scores_3d > pred_score_thr].cpu() pred_instances_3d.scores_3d > pred_score_thr].to('cpu')
pred_data_3d = self._draw_instances_3d(data_input, pred_data_3d = self._draw_instances_3d(data_input,
pred_instances_3d, pred_instances_3d,
pred_sample.metainfo, pred_sample.metainfo,
vis_task, palette) vis_task, palette)
if 'pred_instances' in pred_sample: if 'pred_instances' in pred_sample:
assert 'img' in data_input if 'img' in data_input and len(pred_sample.pred_instances) > 0:
pred_instances = pred_sample.pred_instances pred_instances = pred_sample.pred_instances
pred_instances = pred_instances_3d[ pred_instances = pred_instances_3d[
pred_instances.scores > pred_score_thr].cpu() pred_instances.scores > pred_score_thr].cpu()
if isinstance(data_input['img'], Tensor): if isinstance(data_input['img'], Tensor):
img = data_input['img'].permute(1, 2, 0).numpy() img = data_input['img'].permute(1, 2, 0).numpy()
img = img[..., [2, 1, 0]] # bgr to rgb img = img[..., [2, 1, 0]] # bgr to rgb
pred_img_data = self._draw_instances(img, pred_instances, pred_img_data = self._draw_instances(
classes, palette) img, pred_instances, classes, palette)
if 'pred_pts_sem_seg' in pred_sample: if 'pred_pts_seg' in pred_sample:
assert classes is not None, 'class information is ' \ assert classes is not None, 'class information is ' \
'not provided when ' \ 'not provided when ' \
'visualizing panoptic ' \ 'visualizing panoptic ' \
...@@ -543,8 +578,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -543,8 +578,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pred_seg_data_3d = \ pred_seg_data_3d = \
self._draw_pts_sem_seg(data_input['points'], self._draw_pts_sem_seg(data_input['points'],
pred_sample.pred_pts_seg, pred_sample.pred_pts_seg,
classes, palette, out_file, palette, ignore_index)
ignore_index)
# monocular 3d object detection image # monocular 3d object detection image
if gt_data_3d is not None and pred_data_3d is not None: if gt_data_3d is not None and pred_data_3d is not None:
...@@ -578,9 +612,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -578,9 +612,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if out_file is not None: if out_file is not None:
if drawn_img_3d is not None: if drawn_img_3d is not None:
mmcv.imwrite(drawn_img_3d[..., ::-1], out_file) mmcv.imwrite(drawn_img_3d[..., ::-1], out_file + '.jpg')
if drawn_img is not None: if drawn_img is not None:
mmcv.imwrite(drawn_img[..., ::-1], out_file) mmcv.imwrite(drawn_img[..., ::-1], out_file + '.jpg')
if gt_data_3d is not None: if gt_data_3d is not None:
write_obj(gt_data_3d['points'], write_obj(gt_data_3d['points'],
osp.join(out_file, 'points.obj')) osp.join(out_file, 'points.obj'))
......
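A sketch of the coordinate conversion that set_points now performs internally: LiDAR-frame points (pcd_mode=0) are converted to the depth frame before being handed to Open3D, so boxes and points line up. Array values below are made up.

import numpy as np
from mmdet3d.structures import Coord3DMode

lidar_points = np.random.rand(1000, 4).astype(np.float32)
depth_points = Coord3DMode.convert(
    lidar_points, Coord3DMode.LIDAR, Coord3DMode.DEPTH)

Note also that show() and add_datasample now append '.png' / '.jpg' to out_file themselves and write_obj joins file names onto it, so out_file is best passed as a path prefix without an extension.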