Unverified commit 7d5c5a33 authored by ChaimZhu, committed by GitHub

[Fix] fix some mono3d related bugs (#1816)

* fix mono3d related bugs

* update kitti-mono3d script

* update mono3d task

* update resize3d and randomresize3d

* fix

* update dataset converter script

* fix part of comments

* unify the task name in datasets and visualization

* fix comments

* rename 3d to lidar_det

* fix ci

* change boxlist to boxtype

* change default value to lidar_det

* fix bugs
parent 4b73569e
@@ -35,6 +35,10 @@ test_pipeline = [
     dict(type='Resize', scale=(1242, 375), keep_ratio=True),
     dict(type='Pack3DDetInputs', keys=['img'])
 ]
+eval_pipeline = [
+    dict(type='LoadImageFromFileMono3D'),
+    dict(type='Pack3DDetInputs', keys=['img'])
+]
 train_dataloader = dict(
     batch_size=2,
...
@@ -65,7 +65,7 @@ train_dataloader = dict(
             CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT',
             CAM_BACK_LEFT='samples/CAM_BACK_LEFT'),
         ann_file='nuscenes_infos_train.pkl',
-        task='mono3d',
+        task='mono_det',
         pipeline=train_pipeline,
         metainfo=metainfo,
         modality=input_modality,
@@ -92,7 +92,7 @@ val_dataloader = dict(
             CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT',
             CAM_BACK_LEFT='samples/CAM_BACK_LEFT'),
         ann_file='nuscenes_infos_val.pkl',
-        task='mono3d',
+        task='mono_det',
         pipeline=test_pipeline,
         modality=input_modality,
         metainfo=metainfo,
...
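The commit unifies the previously inconsistent task strings ('mono3d', '3d', 'lidar', ...) across datasets. Below is a minimal, self-contained sketch of the renamed flag in a config; the paths and the trimmed pipeline are hypothetical, while `ann_file`, `task` and the pipeline keys come from the diff above:

```python
# Hypothetical minimal config sketch using the renamed task flag.
train_pipeline = [
    dict(type='LoadImageFromFileMono3D'),
    dict(type='Pack3DDetInputs', keys=['img'])
]
train_dataloader = dict(
    batch_size=2,
    dataset=dict(
        type='NuScenesDataset',
        data_root='data/nuscenes/',  # assumed path
        ann_file='nuscenes_infos_train.pkl',
        task='mono_det',  # renamed from 'mono3d' in this commit
        pipeline=train_pipeline))
```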
@@ -113,7 +113,7 @@ val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
 # optimizer
 optim_wrapper = dict(
-    optimizer=dict(lr=0.01),
+    optimizer=dict(lr=0.001),
     paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.),
     clip_grad=dict(max_norm=35, norm_type=2))
@@ -134,4 +134,4 @@ param_scheduler = [
         gamma=0.1)
 ]
-train_cfg = dict(max_epochs=48)
+train_cfg = dict(max_epochs=48, val_interval=2)
@@ -47,8 +47,10 @@ train_dataloader = dict(
 test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
 val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
-# training schedule for 1x
-train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
+# training schedule for 6x
+max_epochs = 72
+train_cfg = dict(
+    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=5)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
@@ -57,9 +59,9 @@ param_scheduler = [
     dict(
         type='MultiStepLR',
         begin=0,
-        end=12,
+        end=max_epochs,
         by_epoch=True,
-        milestones=[8, 11],
+        milestones=[50],
         gamma=0.1)
 ]
@@ -68,3 +70,5 @@ optim_wrapper = dict(
     type='OptimWrapper',
     optimizer=dict(type='Adam', lr=2.5e-4),
     clip_grad=None)
+
+find_unused_parameters = True
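Under this 6x schedule, MultiStepLR keeps the Adam rate at 2.5e-4 until epoch 50 and then decays it once by gamma=0.1. A runnable sketch of the resulting step function, using only the values configured above:

```python
# Effective learning rate under MultiStepLR(milestones=[50], gamma=0.1)
# with base lr 2.5e-4.
base_lr, gamma, milestones = 2.5e-4, 0.1, [50]

def lr_at(epoch: int) -> float:
    """Learning rate used during a given (0-indexed) epoch."""
    decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** decays

assert lr_at(49) == 2.5e-4               # before the milestone
assert abs(lr_at(50) - 2.5e-5) < 1e-12   # decayed once at epoch 50
```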
@@ -66,7 +66,7 @@ def main(args):
         wait_time=0,
         out_file=args.out_dir,
         pred_score_thr=args.score_thr,
-        vis_task='mono-det')
+        vis_task='mono_det')

 if __name__ == '__main__':
...
@@ -67,7 +67,7 @@ def main(args):
         wait_time=0,
         out_file=args.out_dir,
         pred_score_thr=args.score_thr,
-        vis_task='multi_modality-det')
+        vis_task='multi-modality_det')

 if __name__ == '__main__':
...
@@ -56,7 +56,7 @@ def main(args):
         wait_time=0,
         out_file=args.out_dir,
         pred_score_thr=args.score_thr,
-        vis_task='det')
+        vis_task='lidar_det')

 if __name__ == '__main__':
...
@@ -51,7 +51,7 @@ def main(args):
         show=True,
         wait_time=0,
         out_file=args.out_dir,
-        vis_task='seg')
+        vis_task='lidar_seg')

 if __name__ == '__main__':
...
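Taken together, the four demo-script hunks above spell out the full rename of the visualization task strings. As a quick reference, grounded only in the strings changed in this commit:

```python
# Old -> new visualization task names unified by this commit.
TASK_RENAMES = {
    'det': 'lidar_det',
    'seg': 'lidar_seg',
    'mono-det': 'mono_det',
    'multi_modality-det': 'multi-modality_det',
}
assert all('_det' in v or '_seg' in v for v in TASK_RENAMES.values())
```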
@@ -102,10 +102,10 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py -
 python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py --task det --aug --output-dir ${OUTPUT_DIR} --online
 ```
-If you also want to display 2D images with projected 3D bounding boxes, you need to find a config file that supports multi-modality data loading, then change the `--task` argument to `multi_modality-det`. An example is shown below:
+If you also want to display 2D images with projected 3D bounding boxes, you need to find a config file that supports multi-modality data loading, then change the `--task` argument to `multi-modality_det`. An example is shown below:
 ```shell
-python tools/misc/browse_dataset.py configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py --task multi_modality-det --output-dir ${OUTPUT_DIR} --online
+python tools/misc/browse_dataset.py configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py --task multi-modality_det --output-dir ${OUTPUT_DIR} --online
 ```
 ![](../../resources/browse_dataset_multi_modality.png)
@@ -121,7 +121,7 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/scannet-seg.py --tas
 Browse the nuScenes dataset in the monocular 3D detection task:
 ```shell
-python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task mono-det --output-dir ${OUTPUT_DIR} --online
+python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task mono_det --output-dir ${OUTPUT_DIR} --online
 ```
 ![](../../resources/browse_dataset_mono.png)
...
 # Copyright (c) OpenMMLab. All rights reserved.
 from .builder import DATASETS, PIPELINES, build_dataset
-from .convert_utils import get_2d_boxes
 from .dataset_wrappers import CBGSDataset
 from .det3d_dataset import Det3DDataset
 from .kitti_dataset import KittiDataset
@@ -22,8 +21,8 @@ from .transforms import (AffineResize, BackgroundPointsFilter, GlobalAlignment,
                          ObjectNameFilter, ObjectNoise, ObjectRangeFilter,
                          ObjectSample, PointSample, PointShuffle,
                          PointsRangeFilter, RandomDropPointsColor,
-                         RandomFlip3D, RandomJitterPoints, RandomShiftScale,
-                         VoxelBasedPointSampler)
+                         RandomFlip3D, RandomJitterPoints, RandomResize3D,
+                         RandomShiftScale, Resize3D, VoxelBasedPointSampler)
 from .utils import get_loading_pipeline
 from .waymo_dataset import WaymoDataset
@@ -40,5 +39,6 @@ __all__ = [
     'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter',
     'VoxelBasedPointSampler', 'get_loading_pipeline', 'RandomDropPointsColor',
     'RandomJitterPoints', 'ObjectNameFilter', 'AffineResize',
-    'RandomShiftScale', 'LoadPointsFromDict', 'PIPELINES', 'get_2d_boxes'
+    'RandomShiftScale', 'LoadPointsFromDict', 'PIPELINES',
+    'Resize3D', 'RandomResize3D',
 ]
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 from collections import OrderedDict
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 from nuscenes.utils.geometry_utils import view_points
@@ -11,6 +11,11 @@ from shapely.geometry import MultiPoint, box
 from mmdet3d.structures import Box3DMode, CameraInstance3DBoxes, points_cam2img
 from mmdet3d.structures.ops import box_np_ops

+kitti_categories = ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
+                    'Person_sitting', 'Tram', 'Misc')
+waymo_categories = ('Car', 'Pedestrian', 'Cyclist')
+
 nus_categories = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle',
                   'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone',
                   'barrier')
@@ -48,8 +53,10 @@ LyftNameMapping = {
 }

-def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str]):
-    """Get the 2D annotation records for a given `sample_data_token`.
+def get_nuscenes_2d_boxes(nusc, sample_data_token: str,
+                          visibilities: List[str]):
+    """Get the 2d / mono3d annotation records for a given
+    `sample_data_token` of the nuScenes dataset.

     Args:
         sample_data_token (str): Sample data token belonging to a camera
@@ -57,7 +64,7 @@ def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str]):
         visibilities (list[str]): Visibility filter.

     Return:
-        list[dict]: List of 2D annotation records that belong to the input
+        list[dict]: List of 2d annotation records that belong to the input
             `sample_data_token`.
     """
@@ -128,7 +135,7 @@ def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str]):
     # Generate dictionary record to be included in the .json file.
     repro_rec = generate_record(ann_rec, min_x, min_y, max_x, max_y,
-                                sample_data_token, sd_rec['filename'])
+                                'nuscenes')

     # if repro_rec is None, we do not append it into repro_recs
     if repro_rec is not None:
@@ -178,23 +185,36 @@ def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str]):
     return repro_recs

-def get_waymo_2d_boxes(info, cam_idx, occluded, annos=None, mono3d=True):
-    """Get the 2D annotation records for a given info.
+def get_kitti_style_2d_boxes(info: dict,
+                             cam_idx: int = 2,
+                             occluded: Tuple[int] = (0, 1, 2, 3),
+                             annos: Optional[dict] = None,
+                             mono3d: bool = True,
+                             dataset: str = 'kitti'):
+    """Get the 2d / mono3d annotation records for a given info.

-    This function is used to get 2D annotations when loading annotations from
-    a dataset class. The original version in the data converter will be
-    deprecated in the future.
+    This function is used to get 2D/Mono3D annotations when loading
+    annotations from a KITTI-style dataset class, such as the KITTI and
+    Waymo datasets.

     Args:
-        info: Information of the given sample data.
-        occluded: Integer (0, 1, 2, 3) indicating occlusion state:
+        info (dict): Information of the given sample data.
+        cam_idx (int): Camera id to which the obtained 2d / mono3d
+            annotations belong. In KITTI, typically only CAM2 is used,
+            while in Waymo multiple cameras can be used. Defaults to 2.
+        occluded (tuple[int]): Integer (0, 1, 2, 3) indicating occlusion state:
             0 = fully visible, 1 = partly occluded, 2 = largely occluded,
-            3 = unknown, -1 = DontCare
+            3 = unknown, -1 = DontCare. Defaults to (0, 1, 2, 3).
+        annos (dict, optional): Original annotations. Defaults to None.
         mono3d (bool): Whether to get boxes with mono3d annotation.
+            Defaults to True.
+        dataset (str): Name of the dataset from which to get the 2d bboxes.
+            Defaults to 'kitti'.

     Return:
-        list[dict]: List of 2D annotation records that belong to the input
-            `sample_data_token`.
+        list[dict]: List of 2d / mono3d annotation records that belong to
+            the input camera id.
     """
     # Get calibration information
     camera_intrinsic = info['calib'][f'P{cam_idx}']
@@ -224,7 +244,6 @@ def get_waymo_2d_boxes(info, cam_idx, occluded, annos=None, mono3d=True):
         ann_rec['sample_annotation_token'] = \
             f"{info['image']['image_idx']}.{ann_idx}"
         ann_rec['sample_data_token'] = info['image']['image_idx']
-        sample_data_token = info['image']['image_idx']

         loc = ann_rec['location'][np.newaxis, :]
         dim = ann_rec['dimensions'][np.newaxis, :]
@@ -266,9 +285,8 @@ def get_waymo_2d_boxes(info, cam_idx, occluded, annos=None, mono3d=True):
         min_x, min_y, max_x, max_y = final_coords

         # Generate dictionary record to be included in the .json file.
-        repro_rec = generate_waymo_mono3d_record(ann_rec, min_x, min_y, max_x,
-                                                 max_y, sample_data_token,
-                                                 info['image']['image_path'])
+        repro_rec = generate_record(ann_rec, min_x, min_y, max_x, max_y,
+                                    dataset)

         # If mono3d=True, add 3D annotations in camera coordinates
         if mono3d and (repro_rec is not None):
@@ -288,11 +306,7 @@ def get_waymo_2d_boxes(info, cam_idx, occluded, annos=None, mono3d=True):
             # samples with depth < 0 will be removed
             if repro_rec['depth'] <= 0:
                 continue
-
-            repro_rec['attribute_name'] = -1  # no attribute in KITTI
-            repro_rec['attribute_id'] = -1
-
-        repro_recs.append(repro_rec)
+            repro_recs.append(repro_rec)

     return repro_recs
@@ -355,120 +369,50 @@ def post_process_coords(

-def generate_record(ann_rec: dict, x1: float, y1: float, x2: float, y2: float,
-                    sample_data_token: str, filename: str) -> OrderedDict:
-    """Generate one 2D annotation record given various information on top of
-    the 2D bounding box coordinates.
-
-    Args:
-        ann_rec (dict): Original 3d annotation record.
-        x1 (float): Minimum value of the x coordinate.
-        y1 (float): Minimum value of the y coordinate.
-        x2 (float): Maximum value of the x coordinate.
-        y2 (float): Maximum value of the y coordinate.
-        sample_data_token (str): Sample data token.
-        filename (str): The corresponding image file where the annotation
-            is present.
-
-    Returns:
-        dict: A sample mono3D annotation record.
-            - bbox_label (int): 2d box label id
-            - bbox_label_3d (int): 3d box label id
-            - bbox (list[float]): left x, top y, right x, bottom y of 2d box
-            - bbox_3d_isvalid (bool): whether the box is valid
-    """
-    repro_rec = OrderedDict()
-    repro_rec['sample_data_token'] = sample_data_token
-    coco_rec = dict()
-
-    relevant_keys = [
-        'attribute_tokens',
-        'category_name',
-        'instance_token',
-        'next',
-        'num_lidar_pts',
-        'num_radar_pts',
-        'prev',
-        'sample_annotation_token',
-        'sample_data_token',
-        'visibility_token',
-    ]
-
-    for key, value in ann_rec.items():
-        if key in relevant_keys:
-            repro_rec[key] = value
-
-    repro_rec['bbox_corners'] = [x1, y1, x2, y2]
-    repro_rec['filename'] = filename
-
-    if repro_rec['category_name'] not in NuScenesNameMapping:
-        return None
-    cat_name = NuScenesNameMapping[repro_rec['category_name']]
-    coco_rec['bbox_label'] = nus_categories.index(cat_name)
-    coco_rec['bbox_label_3d'] = nus_categories.index(cat_name)
-    coco_rec['bbox'] = [x1, y1, x2, y2]
-    coco_rec['bbox_3d_isvalid'] = True
-
-    return coco_rec
-
-
-def generate_waymo_mono3d_record(ann_rec, x1, y1, x2, y2, sample_data_token,
-                                 filename):
-    """Generate one 2D annotation record given various information on top of
-    the 2D bounding box coordinates.
-
-    The original version in the data converter will be deprecated in the
-    future.
-
-    Args:
-        ann_rec (dict): Original 3d annotation record.
-        x1 (float): Minimum value of the x coordinate.
-        y1 (float): Minimum value of the y coordinate.
-        x2 (float): Maximum value of the x coordinate.
-        y2 (float): Maximum value of the y coordinate.
-        sample_data_token (str): Sample data token.
-        filename (str): The corresponding image file where the annotation
-            is present.
-
-    Returns:
-        dict: A sample 2D annotation record.
-            - file_name (str): file name
-            - image_id (str): sample data token
-            - area (float): 2d box area
-            - category_name (str): category name
-            - category_id (int): category id
-            - bbox (list[float]): left x, top y, x_size, y_size of 2d box
-            - iscrowd (int): whether the area is crowd
-    """
-    kitti_categories = ('Car', 'Pedestrian', 'Cyclist')
-    repro_rec = OrderedDict()
-    repro_rec['sample_data_token'] = sample_data_token
-    coco_rec = dict()
-
-    key_mapping = {
-        'name': 'category_name',
-        'num_points_in_gt': 'num_lidar_pts',
-        'sample_annotation_token': 'sample_annotation_token',
-        'sample_data_token': 'sample_data_token',
-    }
-
-    for key, value in ann_rec.items():
-        if key in key_mapping.keys():
-            repro_rec[key_mapping[key]] = value
-
-    repro_rec['bbox_corners'] = [x1, y1, x2, y2]
-    repro_rec['filename'] = filename
-
-    coco_rec['image_id'] = sample_data_token
-    coco_rec['area'] = (y2 - y1) * (x2 - x1)
-
-    if repro_rec['category_name'] not in kitti_categories:
-        return None
-    cat_name = repro_rec['category_name']
-    coco_rec['category_id'] = kitti_categories.index(cat_name)
-    coco_rec['bbox_label'] = coco_rec['category_id']
-    coco_rec['bbox_label_3d'] = coco_rec['bbox_label']
-    coco_rec['bbox'] = [x1, y1, x2 - x1, y2 - y1]
-    coco_rec['iscrowd'] = 0
-
-    return coco_rec
+def generate_record(ann_rec: dict, x1: float, y1: float, x2: float, y2: float,
+                    dataset: str) -> OrderedDict:
+    """Generate one 2D annotation record given various information on top of
+    the 2D bounding box coordinates.
+
+    Args:
+        ann_rec (dict): Original 3d annotation record.
+        x1 (float): Minimum value of the x coordinate.
+        y1 (float): Minimum value of the y coordinate.
+        x2 (float): Maximum value of the x coordinate.
+        y2 (float): Maximum value of the y coordinate.
+        dataset (str): Name of dataset.
+
+    Returns:
+        dict: A sample 2d annotation record.
+            - bbox_label (int): 2d box label id
+            - bbox_label_3d (int): 3d box label id
+            - bbox (list[float]): left x, top y, right x, bottom y of 2d box
+            - bbox_3d_isvalid (bool): whether the box is valid
+    """
+    if dataset == 'nuscenes':
+        cat_name = ann_rec['category_name']
+        if cat_name not in NuScenesNameMapping:
+            return None
+        cat_name = NuScenesNameMapping[cat_name]
+        categories = nus_categories
+    else:
+        if dataset == 'kitti':
+            categories = kitti_categories
+        elif dataset == 'waymo':
+            categories = waymo_categories
+        else:
+            raise NotImplementedError('Unsupported dataset!')
+        cat_name = ann_rec['name']
+        if cat_name not in categories:
+            return None
+
+    rec = dict()
+    rec['bbox_label'] = categories.index(cat_name)
+    rec['bbox_label_3d'] = rec['bbox_label']
+    rec['bbox'] = [x1, y1, x2, y2]
+    rec['bbox_3d_isvalid'] = True
+
+    return rec
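For reference, a sketch of what the unified `generate_record` would return for a minimal, hypothetical KITTI annotation (real records carry many more keys):

```python
# Expected return value (sketch) of
#   generate_record({'name': 'Car'}, 10., 20., 110., 80., dataset='kitti')
# given the module-level kitti_categories above, where 'Car' has index 2.
expected = {
    'bbox_label': 2,
    'bbox_label_3d': 2,
    'bbox': [10.0, 20.0, 110.0, 80.0],
    'bbox_3d_isvalid': True,
}
```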
@@ -52,6 +52,7 @@ class KittiDataset(Det3DDataset):
                  pipeline: List[Union[dict, Callable]] = [],
                  modality: dict = dict(use_lidar=True),
                  default_cam_key: str = 'CAM2',
+                 task: str = 'lidar_det',
                  box_type_3d: str = 'LiDAR',
                  filter_empty_gt: bool = True,
                  test_mode: bool = False,
@@ -59,6 +60,8 @@ class KittiDataset(Det3DDataset):
                  **kwargs) -> None:
         self.pcd_limit_range = pcd_limit_range
+        assert task in ('lidar_det', 'mono_det')
+        self.task = task
         super().__init__(
             data_root=data_root,
             ann_file=ann_file,
@@ -108,6 +111,9 @@ class KittiDataset(Det3DDataset):
             info['plane'] = plane_lidar

+        if self.task == 'mono_det':
+            info['instances'] = info['cam_instances'][self.default_cam_key]
+
         info = super().parse_data_info(info)

         return info
@@ -136,6 +142,12 @@ class KittiDataset(Det3DDataset):
             ann_info['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
             ann_info['gt_labels_3d'] = np.zeros(0, dtype=np.int64)
+            if self.task == 'mono_det':
+                ann_info['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
+                ann_info['gt_bboxes_labels'] = np.zeros(0, dtype=np.int64)
+                ann_info['centers_2d'] = np.zeros((0, 2), dtype=np.float32)
+                ann_info['depths'] = np.zeros((0), dtype=np.float32)

         ann_info = self._remove_dontcare(ann_info)
         # in kitti, lidar2cam = R0_rect @ Tr_velo_to_cam
         lidar2cam = np.array(info['images']['CAM2']['lidar2cam'])
...
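With the new flag, the same KittiDataset class serves both LiDAR and monocular pipelines. A hedged config sketch; the paths and box type are assumptions, while the `task` and `default_cam_key` semantics come straight from the hunks above:

```python
# Hypothetical KITTI mono3d dataset config using the new task flag.
dataset = dict(
    type='KittiDataset',
    data_root='data/kitti/',          # assumed path
    ann_file='kitti_infos_train.pkl',
    default_cam_key='CAM2',  # mono instances come from cam_instances['CAM2']
    task='mono_det',         # must be 'lidar_det' or 'mono_det'
    box_type_3d='Camera')    # assumed; mono3d boxes live in camera coords
```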
@@ -22,7 +22,7 @@ class NuScenesDataset(Det3DDataset):
     Args:
         data_root (str): Path of dataset root.
         ann_file (str): Path of annotation file.
-        task (str, optional): Detection task. Defaults to '3d'.
+        task (str, optional): Detection task. Defaults to 'lidar_det'.
         pipeline (list[dict], optional): Pipeline used for data processing.
             Defaults to None.
         box_type_3d (str): Type of 3D box of this dataset.
@@ -56,7 +56,7 @@ class NuScenesDataset(Det3DDataset):
     def __init__(self,
                  data_root: str,
                  ann_file: str,
-                 task: str = '3d',
+                 task: str = 'lidar_det',
                  pipeline: List[Union[dict, Callable]] = [],
                  box_type_3d: str = 'LiDAR',
                  modality: dict = dict(
@@ -72,7 +72,7 @@ class NuScenesDataset(Det3DDataset):
         self.with_velocity = with_velocity

         # TODO: Redesign multi-view data process in the future
-        assert task in ('3d', 'mono3d', 'multi-view')
+        assert task in ('lidar_det', 'mono_det', 'multi-view_det')
         self.task = task
         assert box_type_3d.lower() in ('lidar', 'camera')
@@ -152,7 +152,7 @@ class NuScenesDataset(Det3DDataset):
         # the nuscenes box center is [0.5, 0.5, 0.5], we change it to be
         # the same as KITTI (0.5, 0.5, 0)
         # TODO: Unify the coordinates
-        if self.task == 'mono3d':
+        if self.task == 'mono_det':
             gt_bboxes_3d = CameraInstance3DBoxes(
                 ann_info['gt_bboxes_3d'],
                 box_dim=ann_info['gt_bboxes_3d'].shape[-1],
@@ -180,7 +180,7 @@ class NuScenesDataset(Det3DDataset):
             dict: Has `ann_info` in training stage. And
                 all paths have been converted to absolute paths.
         """
-        if self.task == 'mono3d':
+        if self.task == 'mono_det':
             data_list = []
             if self.modality['use_lidar']:
                 info['lidar_points']['lidar_path'] = \
...
@@ -14,8 +14,8 @@ from .transforms_3d import (AffineResize, BackgroundPointsFilter,
                             ObjectNameFilter, ObjectNoise, ObjectRangeFilter,
                             ObjectSample, PointSample, PointShuffle,
                             PointsRangeFilter, RandomDropPointsColor,
-                            RandomFlip3D, RandomJitterPoints, RandomShiftScale,
-                            VoxelBasedPointSampler)
+                            RandomFlip3D, RandomJitterPoints, RandomResize3D,
+                            RandomShiftScale, Resize3D, VoxelBasedPointSampler)

 __all__ = [
     'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
@@ -29,5 +29,5 @@ __all__ = [
     'VoxelBasedPointSampler', 'GlobalAlignment', 'IndoorPatchPointSample',
     'LoadImageFromFileMono3D', 'ObjectNameFilter', 'RandomDropPointsColor',
     'RandomJitterPoints', 'AffineResize', 'RandomShiftScale',
-    'LoadPointsFromDict'
+    'LoadPointsFromDict', 'Resize3D', 'RandomResize3D'
 ]
@@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Union
 import cv2
 import numpy as np
-from mmcv.transforms import BaseTransform
+from mmcv.transforms import BaseTransform, RandomResize, Resize
 from mmengine import is_tuple_of

 from mmdet3d.models.task_modules import VoxelGenerator
@@ -163,9 +163,7 @@ class RandomFlip3D(RandomFlip):
             if 'centers_2d' in input_dict:
                 assert self.sync_2d is True and direction == 'horizontal', \
                     'Only support sync_2d=True and horizontal flip with images'
-                # TODO fix this ori_shape and other keys in vision based model
-                # TODO ori_shape to img_shape
-                w = input_dict['ori_shape'][1]
+                w = input_dict['img_shape'][1]
                 input_dict['centers_2d'][..., 0] = \
                     w - input_dict['centers_2d'][..., 0]
                 # need to modify the horizontal position of camera center
@@ -1671,8 +1669,9 @@ class AffineResize(BaseTransform):
         if 'gt_bboxes' in results:
             results['gt_bboxes'] = results['gt_bboxes'][valid_index]

-            if 'gt_labels' in results:
-                results['gt_labels'] = results['gt_labels'][valid_index]
+            if 'gt_bboxes_labels' in results:
+                results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
+                    valid_index]

             if 'gt_masks' in results:
                 raise NotImplementedError(
                     'AffineResize only supports bbox.')
@@ -1842,3 +1841,71 @@ class RandomShiftScale(BaseTransform):
         repr_str += f'(shift_scale={self.shift_scale}, '
         repr_str += f'aug_prob={self.aug_prob}) '
         return repr_str
+
+
+@TRANSFORMS.register_module()
+class Resize3D(Resize):
+
+    def _resize_3d(self, results):
+        """Resize centers_2d and modify camera intrinsic with
+        ``results['scale']``."""
+        if 'centers_2d' in results:
+            results['centers_2d'] *= results['scale_factor'][:2]
+        results['cam2img'][0] *= np.array(results['scale_factor'][0])
+        results['cam2img'][1] *= np.array(results['scale_factor'][1])
+
+    def transform(self, results: dict) -> dict:
+        """Transform function to resize images, bounding boxes, semantic
+        segmentation map and keypoints.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
+            'gt_keypoints', 'scale', 'scale_factor', 'img_shape',
+            and 'keep_ratio' keys are updated in result dict.
+        """
+        super(Resize3D, self).transform(results)
+        self._resize_3d(results)
+        return results
+
+
+@TRANSFORMS.register_module()
+class RandomResize3D(RandomResize):
+    """The difference between RandomResize3D and RandomResize:
+
+    1. Compared to RandomResize, this class would further
+       check if scale is already set in results.
+    2. During resizing, this class would modify the centers_2d
+       and cam2img with ``results['scale']``.
+    """
+
+    def _resize_3d(self, results):
+        """Resize centers_2d and modify camera intrinsic with
+        ``results['scale']``."""
+        if 'centers_2d' in results:
+            results['centers_2d'] *= results['scale_factor'][:2]
+        results['cam2img'][0] *= np.array(results['scale_factor'][0])
+        results['cam2img'][1] *= np.array(results['scale_factor'][1])
+
+    def transform(self, results):
+        """Call function to resize images, bounding boxes, masks and the
+        semantic segmentation map.
+
+        Compared to RandomResize, this function would further
+        check if scale is already set in results.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
+            'keep_ratio' keys are added into result dict.
+        """
+        if 'scale' not in results:
+            results['scale'] = self._random_scale()
+        self.resize.scale = results['scale']
+        results = self.resize(results)
+        self._resize_3d(results)
+        return results
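What `_resize_3d` does is plain coordinate bookkeeping: the projected 2D centers and the first two rows of the camera intrinsic matrix must be scaled by the same per-axis factors as the image. A runnable numeric sketch with made-up values:

```python
import numpy as np

# Hypothetical inputs: a 2x resize in both axes.
results = dict(
    scale_factor=np.array([2.0, 2.0]),  # (w_scale, h_scale)
    cam2img=np.array([[700., 0., 600., 0.],
                      [0., 700., 180., 0.],
                      [0., 0., 1., 0.]]),
    centers_2d=np.array([[600., 180.]]))

# Mirrors the body of _resize_3d above.
results['centers_2d'] *= results['scale_factor'][:2]
results['cam2img'][0] *= results['scale_factor'][0]
results['cam2img'][1] *= results['scale_factor'][1]

assert results['cam2img'][0, 0] == 1400.  # fx doubled with the image width
assert (results['centers_2d'] == np.array([[1200., 360.]])).all()
```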
@@ -83,7 +83,7 @@ class WaymoDataset(KittiDataset):
                  pcd_limit_range: List[float] = [0, -40, -3, 70.4, 40, 0.0],
                  cam_sync_instances=False,
                  load_interval=1,
-                 task='lidar',
+                 task='lidar_det',
                  max_sweeps=0,
                  **kwargs):
         self.load_interval = load_interval
@@ -151,7 +151,7 @@ class WaymoDataset(KittiDataset):
             centers_2d = np.zeros((0, 2), dtype=np.float32)
             depths = np.zeros((0), dtype=np.float32)

-        if self.task == 'mono3d':
+        if self.task == 'mono_det':
             gt_bboxes_3d = CameraInstance3DBoxes(
                 ann_info['gt_bboxes_3d'],
                 box_dim=ann_info['gt_bboxes_3d'].shape[-1],
@@ -184,7 +184,7 @@ class WaymoDataset(KittiDataset):
     def parse_data_info(self, info: dict) -> dict:
         """If the task is lidar or multi-view det, use the super() method;
         if it is mono3d, split the info from frame-wise to img-wise."""
-        if self.task != 'mono3d':
+        if self.task != 'mono_det':
             if self.cam_sync_instances:
                 # use the cam sync labels
                 info['instances'] = info['cam_sync_instances']
...
@@ -342,13 +342,13 @@ class WaymoMetric(KittiMetric):
             image_shape = (info['images'][self.default_cam_key]['height'],
                            info['images'][self.default_cam_key]['width'])

-            if self.task == 'mono3d':
+            if self.task == 'mono_det':
                 if idx % self.num_cams == 0:
                     box_dict_per_frame = []
                     cam0_idx = idx

             box_dict = self.convert_valid_bboxes(pred_dicts, info)
-            if self.task == 'mono3d':
+            if self.task == 'mono_det':
                 box_dict_per_frame.append(box_dict)
                 if (idx + 1) % self.num_cams != 0:
                     continue
...
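The mono_det branch of WaymoMetric relies on predictions arriving image-wise in camera order and collects `num_cams` consecutive results back into one frame before conversion. A self-contained sketch of that grouping logic (the `num_cams` value here is hypothetical):

```python
# Mimics the idx % num_cams bookkeeping in the WaymoMetric hunk above.
num_cams = 5
frames = []
for idx in range(10):                # 10 image-wise predictions = 2 frames
    if idx % num_cams == 0:
        box_dict_per_frame = []      # start a new frame
    box_dict_per_frame.append(f'box_{idx}')
    if (idx + 1) % num_cams != 0:
        continue                     # frame not complete yet
    frames.append(box_dict_per_frame)

assert [len(f) for f in frames] == [5, 5]
```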
@@ -133,7 +133,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
     def set_points(self,
                    points: np.ndarray,
                    pcd_mode: int = 0,
-                   vis_task: str = 'det',
+                   vis_task: str = 'lidar_det',
                    points_color: Tuple = (0.5, 0.5, 0.5),
                    points_size: int = 2,
                    mode: str = 'xyz') -> None:
@@ -146,7 +146,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                 0 represents LiDAR, 1 represents CAMERA, 2
                 represents Depth.
             vis_task (str): Visualization task, it includes:
-                'det', 'multi_modality-det', 'mono-det', 'seg'.
+                'lidar_det', 'multi-modality_det', 'mono_det', 'lidar_seg'.
             point_color (tuple[float], optional): the color of points.
                 Default: (0.5, 0.5, 0.5).
             points_size (int, optional): the size of points to show
@@ -161,7 +161,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
         if pcd_mode != Coord3DMode.DEPTH:
             points = Coord3DMode.convert(points, pcd_mode, Coord3DMode.DEPTH)

-        if hasattr(self, 'pcd') and vis_task != 'seg':
+        if hasattr(self, 'pcd') and vis_task != 'lidar_seg':
             self.o3d_vis.remove_geometry(self.pcd)

         # set points size in Open3D
@@ -334,7 +334,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
         self.o3d_vis.add_geometry(mesh_frame)
         seg_points = copy.deepcopy(seg_mask_colors)
         seg_points[:, 0] += offset
-        self.set_points(seg_points, vis_task='seg', pcd_mode=2, mode='xyzrgb')
+        self.set_points(
+            seg_points, vis_task='lidar_seg', pcd_mode=2, mode='xyzrgb')

     def _draw_instances_3d(self, data_input: dict, instances: InstanceData,
                            input_meta: dict, vis_task: str,
@@ -347,7 +348,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                 instance-level annotations or predictions.
             metainfo (dict): Meta information.
             vis_task (str): Visualization task, it includes:
-                'det', 'multi_modality-det', 'mono-det'.
+                'lidar_det', 'multi-modality_det', 'mono_det'.

         Returns:
             dict: the drawn point cloud and image whose channel is RGB.
@@ -357,7 +358,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
         data_3d = dict()

-        if vis_task in ['det', 'multi_modality-det']:
+        if vis_task in ['lidar_det', 'multi-modality_det']:
             assert 'points' in data_input
             points = data_input['points']
             check_type('points', points, (np.ndarray, Tensor))
@@ -374,7 +375,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
             data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor)
             data_3d['points'] = points

-        if vis_task in ['mono-det', 'multi_modality-det']:
+        if vis_task in ['mono_det', 'multi-modality_det']:
             assert 'img' in data_input
             img = data_input['img']
             if isinstance(data_input['img'], Tensor):
@@ -382,6 +383,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                 img = img[..., [2, 1, 0]]  # bgr to rgb
             self.set_image(img)
             self.draw_proj_bboxes_3d(bboxes_3d, input_meta)
+            if vis_task == 'mono_det' and hasattr(instances, 'centers_2d'):
+                centers_2d = instances.centers_2d
+                self.draw_points(centers_2d)
             drawn_img = self.get_image()
             data_3d['img'] = drawn_img
@@ -420,7 +424,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
             pts_color = palette[pts_sem_seg]
             seg_color = np.concatenate([points[:, :3], pts_color], axis=1)

-            self.set_points(points, pcd_mode=2, vis_task='seg')
+            self.set_points(points, pcd_mode=2, vis_task='lidar_seg')
             self.draw_seg_mask(seg_color)
             seg_data_3d = dict(points=points, seg_color=seg_color)
@@ -439,7 +443,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
         Args:
             vis_task (str): Visualization task, it includes:
-                'det', 'multi_modality-det', 'mono-det', 'seg'.
+                'lidar_det', 'multi-modality_det', 'mono_det', 'lidar_seg'.
             out_file (str): Output file path.
             drawn_img (np.ndarray, optional): The image to show. If drawn_img
                 is None, it will show the image got by Visualizer. Defaults
@@ -450,13 +454,13 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
             continue_key (str): The key for users to continue. Defaults to
                 the space key.
         """
-        if vis_task in ['det', 'multi_modality-det', 'seg']:
+        if vis_task in ['lidar_det', 'multi-modality_det', 'lidar_seg']:
             self.o3d_vis.run()
             if out_file is not None:
                 self.o3d_vis.capture_screen_image(out_file + '.png')
             self.o3d_vis.destroy_window()

-        if vis_task in ['mono-det', 'multi_modality-det']:
+        if vis_task in ['mono_det', 'multi-modality_det']:
             super().show(drawn_img_3d, win_name, wait_time, continue_key)

         if drawn_img is not None:
@@ -474,7 +478,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                        show: bool = False,
                        wait_time: float = 0,
                        out_file: Optional[str] = None,
-                       vis_task: str = 'mono-det',
+                       vis_task: str = 'mono_det',
                        pred_score_thr: float = 0.3,
                        step: int = 0) -> None:
         """Draw datasample and save to all backends.
@@ -502,7 +506,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                 image. Defaults to False.
             wait_time (float): The interval of show (s). Defaults to 0.
             out_file (str): Path to output file. Defaults to None.
-            vis_task (str): Visualization task. Defaults to 'mono-det'.
+            vis_task (str): Visualization task. Defaults to 'mono_det'.
             pred_score_thr (float): The threshold to visualize the bboxes
                 and masks. Defaults to 0.3.
             step (int): Global step value to record. Defaults to 0.
@@ -564,7 +568,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                 img = img[..., [2, 1, 0]]  # bgr to rgb
             pred_img_data = self._draw_instances(
                 img, pred_instances, classes, palette)
-        if 'pred_pts_seg' in data_sample:
+        if 'pred_pts_seg' in data_sample and vis_task == 'lidar_seg':
             assert classes is not None, 'class information is ' \
                                         'not provided when ' \
                                         'visualizing panoptic ' \
...
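For context, the demo scripts earlier in this commit pass these strings straight into the visualizer. A hedged usage sketch; `visualizer`, `img` and `result` are assumed to be built elsewhere, and the method name `add_datasample` is the usual MMEngine visualizer entry point rather than something this diff shows directly:

```python
visualizer.add_datasample(
    'mono3d-result',            # hypothetical window/record name
    img,
    data_sample=result,
    show=True,
    wait_time=0,
    out_file='outputs/frame_0000',
    pred_score_thr=0.3,
    vis_task='mono_det')        # was 'mono-det' before this commit
```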
@@ -17,8 +17,9 @@ import mmengine
 import numpy as np
 from nuscenes.nuscenes import NuScenes

-from mmdet3d.datasets.convert_utils import (convert_annos, get_2d_boxes,
-                                            get_waymo_2d_boxes)
+from mmdet3d.datasets.convert_utils import (convert_annos,
+                                            get_kitti_style_2d_boxes,
+                                            get_nuscenes_2d_boxes)
 from mmdet3d.datasets.utils import convert_quaternion_to_matrix
 from mmdet3d.structures import points_cam2img
@@ -218,7 +219,7 @@ def clear_data_info_unused_keys(data_info):
     return data_info, empty_flag

-def generate_camera_instances(info, nusc):
+def generate_nuscenes_camera_instances(info, nusc):

     # get bbox annotations for camera
     camera_types = [
@@ -235,7 +236,7 @@ def generate_camera_instances(info, nusc):
     for cam in camera_types:
         cam_info = info['cams'][cam]
         # list[dict]
-        ann_infos = get_2d_boxes(
+        ann_infos = get_nuscenes_2d_boxes(
             nusc,
             cam_info['sample_data_token'],
             visibilities=['', '1', '2', '3', '4'])
@@ -357,7 +358,7 @@ def update_nuscenes_infos(pkl_path, out_dir):
         empty_instance['bbox_3d_isvalid'] = ori_info_dict['valid_flag'][i]
         empty_instance = clear_instance_unused_keys(empty_instance)
         temp_data_info['instances'].append(empty_instance)
-        temp_data_info['cam_instances'] = generate_camera_instances(
+        temp_data_info['cam_instances'] = generate_nuscenes_camera_instances(
             ori_info_dict, nusc)
         temp_data_info, _ = clear_data_info_unused_keys(temp_data_info)
         converted_list.append(temp_data_info)
@@ -487,6 +488,8 @@ def update_kitti_infos(pkl_path, out_dir):
             empty_instance = clear_instance_unused_keys(empty_instance)
             instance_list.append(empty_instance)
         temp_data_info['instances'] = instance_list
+        cam_instances = generate_kitti_camera_instances(ori_info_dict)
+        temp_data_info['cam_instances'] = cam_instances
         temp_data_info, _ = clear_data_info_unused_keys(temp_data_info)
         converted_list.append(temp_data_info)

     pkl_name = pkl_path.split('/')[-1]
@@ -997,6 +1000,18 @@ def update_waymo_infos(pkl_path, out_dir):
     mmengine.dump(converted_data_info, out_path, 'pkl')

+def generate_kitti_camera_instances(ori_info_dict):
+
+    cam_key = 'CAM2'
+    empty_camera_instances = get_empty_multicamera_instances([cam_key])
+    annos = copy.deepcopy(ori_info_dict['annos'])
+    ann_infos = get_kitti_style_2d_boxes(
+        ori_info_dict, occluded=[0, 1, 2, 3], annos=annos)
+    empty_camera_instances[cam_key] = ann_infos
+
+    return empty_camera_instances
+
+
 def generate_waymo_camera_instances(ori_info_dict, cam_keys):

     empty_multicamera_instances = get_empty_multicamera_instances(cam_keys)
@@ -1006,8 +1021,8 @@ def generate_waymo_camera_instances(ori_info_dict, cam_keys):
         if cam_idx != 0:
             annos = convert_annos(ori_info_dict, cam_idx)

-        ann_infos = get_waymo_2d_boxes(
-            ori_info_dict, cam_idx, occluded=[0], annos=annos)
+        ann_infos = get_kitti_style_2d_boxes(
+            ori_info_dict, cam_idx, occluded=[0], annos=annos, dataset='waymo')

         empty_multicamera_instances[cam_key] = ann_infos

     return empty_multicamera_instances
@@ -1019,7 +1034,7 @@ def parse_args():
     parser.add_argument(
         '--dataset', type=str, default='kitti', help='name of dataset')
     parser.add_argument(
-        '--pkl',
+        '--pkl-path',
         type=str,
         default='./data/kitti/kitti_infos_train.pkl',
         help='specify the path of the info pkl to be updated')
...
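One side effect of the flag rename worth noting: argparse converts dashes to underscores, so the attribute changes from `args.pkl` to `args.pkl_path` and any downstream read must follow. A runnable sketch:

```python
# Sketch: argparse maps the renamed '--pkl-path' flag to args.pkl_path.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pkl-path', type=str,
                    default='./data/kitti/kitti_infos_train.pkl')
args = parser.parse_args(['--pkl-path', './data/kitti/kitti_infos_val.pkl'])
assert args.pkl_path == './data/kitti/kitti_infos_val.pkl'
```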
@@ -26,7 +26,10 @@ def parse_args():
     parser.add_argument(
         '--task',
         type=str,
-        choices=['det', 'seg', 'multi_modality-det', 'mono-det'],
+        choices=[
+            'mono_det', 'multi-view_det', 'lidar_det', 'lidar_seg',
+            'multi-modality_det'
+        ],
         help='Determine the visualization method depending on the task.')
     parser.add_argument(
         '--aug',
@@ -107,7 +110,7 @@ def main():
     dataset = DATASETS.build(cfg.train_dataloader.dataset)

     # configure visualization mode
-    vis_task = args.task  # 'det', 'seg', 'multi_modality-det', 'mono-det'
+    vis_task = args.task
     visualizer = VISUALIZERS.build(cfg.visualizer)
     visualizer.dataset_meta = dataset.metainfo
...