Unverified Commit 6c03a971 authored by Tai-Wang, committed by GitHub

Release v1.1.0rc1

parents 9611c2d0 ca42c312
...@@ -102,10 +102,10 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py -
python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py --task det --aug --output-dir ${OUTPUT_DIR} --online
```
If you also want to display the 2D images with projected 3D bounding boxes, you need to find a config file that supports multi-modality data loading, and then change the `--task` argument to `multi-modality_det`. An example is shown below:
```shell
python tools/misc/browse_dataset.py configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py --task multi-modality_det --output-dir ${OUTPUT_DIR} --online
```
![](../../resources/browse_dataset_multi_modality.png)
...@@ -121,7 +121,7 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/scannet-seg.py --tas
Browse the nuScenes dataset in the monocular 3D detection task:
```shell
python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task mono_det --output-dir ${OUTPUT_DIR} --online
```
![](../../resources/browse_dataset_mono.png)
......
...@@ -143,6 +143,7 @@ def inference_detector(model: nn.Module,
        # load from point cloud file
        data_ = dict(
            lidar_points=dict(lidar_path=pcd),
            timestamp=1,
            # for ScanNet demo we need axis_align_matrix
            axis_align_matrix=np.eye(4),
            box_type_3d=box_type_3d,
...@@ -151,6 +152,7 @@ def inference_detector(model: nn.Module,
        # directly use loaded point cloud
        data_ = dict(
            points=pcd,
            timestamp=1,
            # for ScanNet demo we need axis_align_matrix
            axis_align_matrix=np.eye(4),
            box_type_3d=box_type_3d,
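# Usage sketch (illustrative, not part of this diff): `init_model` is assumed
# from `mmdet3d.apis`; the config, checkpoint and point-cloud paths are
# placeholders, and the exact return layout of `inference_detector` may
# differ across versions, so only the call pattern is shown.
from mmdet3d.apis import inference_detector, init_model

model = init_model('configs/some_config.py', 'some_checkpoint.pth',
                   device='cuda:0')
# the data_ dict above (now carrying timestamp=1) is built internally
result = inference_detector(model, 'demo/data/kitti/000008.bin')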
......
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS, PIPELINES, build_dataset
from .dataset_wrappers import CBGSDataset
from .det3d_dataset import Det3DDataset
from .kitti_dataset import KittiDataset
...@@ -22,8 +21,8 @@ from .transforms import (AffineResize, BackgroundPointsFilter, GlobalAlignment,
                         ObjectNameFilter, ObjectNoise, ObjectRangeFilter,
                         ObjectSample, PointSample, PointShuffle,
                         PointsRangeFilter, RandomDropPointsColor,
                         RandomFlip3D, RandomJitterPoints, RandomResize3D,
                         RandomShiftScale, Resize3D, VoxelBasedPointSampler)
from .utils import get_loading_pipeline
from .waymo_dataset import WaymoDataset
...@@ -40,5 +39,6 @@ __all__ = [
    'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter',
    'VoxelBasedPointSampler', 'get_loading_pipeline', 'RandomDropPointsColor',
    'RandomJitterPoints', 'ObjectNameFilter', 'AffineResize',
    'RandomShiftScale', 'LoadPointsFromDict', 'PIPELINES',
    'Resize3D', 'RandomResize3D',
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import List, Optional, Tuple, Union

import numpy as np
from nuscenes.utils.geometry_utils import view_points
...@@ -11,6 +11,11 @@ from shapely.geometry import MultiPoint, box
from mmdet3d.structures import Box3DMode, CameraInstance3DBoxes, points_cam2img
from mmdet3d.structures.ops import box_np_ops

kitti_categories = ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
                    'Person_sitting', 'Tram', 'Misc')
waymo_categories = ('Car', 'Pedestrian', 'Cyclist')
nus_categories = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle',
                  'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone',
                  'barrier')
...@@ -48,8 +53,10 @@ LyftNameMapping = {
}


def get_nuscenes_2d_boxes(nusc, sample_data_token: str,
                          visibilities: List[str]):
    """Get the 2d / mono3d annotation records for a given `sample_data_token`
    of the nuScenes dataset.

    Args:
        sample_data_token (str): Sample data token belonging to a camera
...@@ -57,7 +64,7 @@ def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str]):
        visibilities (list[str]): Visibility filter.

    Return:
        list[dict]: List of 2d annotation records that belong to the input
        `sample_data_token`.
    """
...@@ -128,7 +135,7 @@ def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str]):
    # Generate dictionary record to be included in the .json file.
    repro_rec = generate_record(ann_rec, min_x, min_y, max_x, max_y,
                                'nuscenes')

    # if repro_rec is None, we do not append it into repro_recs
    if repro_rec is not None:
...@@ -178,23 +185,36 @@ def get_2d_boxes(nusc, sample_data_token: str, visibilities: List[str]):
    return repro_recs


def get_kitti_style_2d_boxes(info: dict,
                             cam_idx: int = 2,
                             occluded: Tuple[int] = (0, 1, 2, 3),
                             annos: Optional[dict] = None,
                             mono3d: bool = True,
                             dataset: str = 'kitti'):
    """Get the 2d / mono3d annotation records for a given info.

    This function is used to get 2D / Mono3D annotations when loading
    annotations from a KITTI-style dataset class, such as the KITTI and
    Waymo datasets.

    Args:
        info (dict): Information of the given sample data.
        cam_idx (int): Camera index to which the 2d / mono3d annotations
            belong. In KITTI, typically only CAM2 is used; in Waymo,
            multiple cameras can be used. Defaults to 2.
        occluded (tuple[int]): Integer (0, 1, 2, 3) indicating occlusion state:
            0 = fully visible, 1 = partly occluded, 2 = largely occluded,
            3 = unknown, -1 = DontCare. Defaults to (0, 1, 2, 3).
        annos (dict, optional): Original annotations. Defaults to None.
        mono3d (bool): Whether to get boxes with mono3d annotation.
            Defaults to True.
        dataset (str): Name of the dataset the 2d bboxes are collected from.
            Defaults to 'kitti'.

    Return:
        list[dict]: List of 2d / mono3d annotation records that belong to
        the input camera index.
    """
    # Get calibration information
    camera_intrinsic = info['calib'][f'P{cam_idx}']
...@@ -224,7 +244,6 @@ def get_waymo_2d_boxes(info, cam_idx, occluded, annos=None, mono3d=True):
        ann_rec['sample_annotation_token'] = \
            f"{info['image']['image_idx']}.{ann_idx}"
        ann_rec['sample_data_token'] = info['image']['image_idx']

        loc = ann_rec['location'][np.newaxis, :]
        dim = ann_rec['dimensions'][np.newaxis, :]
...@@ -266,9 +285,8 @@ def get_waymo_2d_boxes(info, cam_idx, occluded, annos=None, mono3d=True):
        min_x, min_y, max_x, max_y = final_coords

        # Generate dictionary record to be included in the .json file.
        repro_rec = generate_record(ann_rec, min_x, min_y, max_x, max_y,
                                    dataset)

        # If mono3d=True, add 3D annotations in camera coordinates
        if mono3d and (repro_rec is not None):
...@@ -288,11 +306,7 @@ def get_waymo_2d_boxes(info, cam_idx, occluded, annos=None, mono3d=True):
            # samples with depth < 0 will be removed
            if repro_rec['depth'] <= 0:
                continue

        repro_recs.append(repro_rec)

    return repro_recs
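# Usage sketch (illustrative): `info` stands for one entry of a KITTI-style
# info file produced by the data converters; only the call pattern is shown.
# occluded=(0, 1) keeps fully visible and partly occluded objects only.
repro_recs = get_kitti_style_2d_boxes(
    info, cam_idx=2, occluded=(0, 1), mono3d=False, dataset='kitti')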
...@@ -355,7 +369,7 @@ def post_process_coords(


def generate_record(ann_rec: dict, x1: float, y1: float, x2: float, y2: float,
                    dataset: str) -> OrderedDict:
    """Generate one 2D annotation record given various information on top of
    the 2D bounding box coordinates.

...@@ -365,112 +379,40 @@ def generate_record(ann_rec: dict, x1: float, y1: float, x2: float, y2: float,
        y1 (float): Minimum value of the y coordinate.
        x2 (float): Maximum value of the x coordinate.
        y2 (float): Maximum value of the y coordinate.
        dataset (str): Name of the dataset.

    Returns:
        dict: A sample 2d annotation record.

            - bbox_label (int): 2d box label id
            - bbox_label_3d (int): 3d box label id
            - bbox (list[float]): left x, top y, right x, bottom y of 2d box
            - bbox_3d_isvalid (bool): whether the box is valid
    """
    if dataset == 'nuscenes':
        cat_name = ann_rec['category_name']
        if cat_name not in NuScenesNameMapping:
            return None
        else:
            cat_name = NuScenesNameMapping[cat_name]
            categories = nus_categories
    else:
        if dataset == 'kitti':
            categories = kitti_categories
        elif dataset == 'waymo':
            categories = waymo_categories
        else:
            raise NotImplementedError('Unsupported dataset!')
        cat_name = ann_rec['name']
        if cat_name not in categories:
            return None

    rec = dict()
    rec['bbox_label'] = categories.index(cat_name)
    rec['bbox_label_3d'] = rec['bbox_label']
    rec['bbox'] = [x1, y1, x2, y2]
    rec['bbox_3d_isvalid'] = True

    return rec
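# Worked example (illustrative): with the KITTI branch above, 'Pedestrian'
# is index 0 of kitti_categories, so the record carries label id 0.
ann_rec = {'name': 'Pedestrian'}
rec = generate_record(ann_rec, x1=100.0, y1=120.0, x2=180.0, y2=300.0,
                      dataset='kitti')
# rec == {'bbox_label': 0, 'bbox_label_3d': 0,
#         'bbox': [100.0, 120.0, 180.0, 300.0], 'bbox_3d_isvalid': True}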
...@@ -26,11 +26,11 @@ class Det3DDataset(BaseDataset):
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(pts='velodyne', img='').
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input; it usually has the following keys:

            - use_camera: bool
            - use_lidar: bool
...@@ -40,7 +40,7 @@ class Det3DDataset(BaseDataset):
        box_type_3d (str, optional): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'LiDAR'. Available options include:

            - 'LiDAR': Box in LiDAR coordinates, usually for
              outdoor point cloud 3d detection.
...@@ -49,15 +49,15 @@ class Det3DDataset(BaseDataset):
            - 'Camera': Box in camera coordinates, usually
              for vision-based 3d detection.
        filter_empty_gt (bool, optional): Whether to filter the data with
            empty GT. Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
        load_eval_anns (bool, optional): Whether to load annotations in
            test_mode; the annotations will be saved in `eval_ann_infos`,
            which can be used by the Evaluator. Defaults to True.
        file_client_args (dict, optional): Configuration of file client.
            Defaults to dict(backend='disk').
    """

    def __init__(self,
...@@ -73,7 +73,7 @@ class Det3DDataset(BaseDataset):
                 test_mode: bool = False,
                 load_eval_anns=True,
                 file_client_args: dict = dict(backend='disk'),
                 **kwargs) -> None:
        # init file client
        self.file_client = mmengine.FileClient(**file_client_args)
        self.filter_empty_gt = filter_empty_gt
...@@ -125,7 +125,7 @@ class Det3DDataset(BaseDataset):
        self.metainfo['box_type_3d'] = box_type_3d
        self.metainfo['label_mapping'] = self.label_mapping

    def _remove_dontcare(self, ann_info: dict) -> dict:
        """Remove annotations that do not need to be cared about.

        -1 indicates dontcare in MMDet3D.
...@@ -192,7 +192,8 @@ class Det3DDataset(BaseDataset):
            'bbox_3d': 'gt_bboxes_3d',
            'depth': 'depths',
            'center_2d': 'centers_2d',
            'attr_label': 'attr_labels',
            'velocity': 'velocities',
        }
        instances = info['instances']
        # empty gt
...@@ -209,14 +210,18 @@ class Det3DDataset(BaseDataset):
                        self.label_mapping[item] for item in temp_anns
                    ]
                if ann_name in name_mapping:
                    mapped_ann_name = name_mapping[ann_name]
                else:
                    mapped_ann_name = ann_name

                if 'label' in ann_name:
                    temp_anns = np.array(temp_anns).astype(np.int64)
                elif ann_name in name_mapping:
                    temp_anns = np.array(temp_anns).astype(np.float32)
                else:
                    temp_anns = np.array(temp_anns)

                ann_info[mapped_ann_name] = temp_anns
            ann_info['instances'] = info['instances']
        return ann_info
...@@ -241,6 +246,7 @@ class Det3DDataset(BaseDataset):
                    self.data_prefix.get('pts', ''),
                    info['lidar_points']['lidar_path'])
            info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
            info['lidar_path'] = info['lidar_points']['lidar_path']
            if 'lidar_sweeps' in info:
                for sweep in info['lidar_sweeps']:
...@@ -285,7 +291,7 @@ class Det3DDataset(BaseDataset):
        return info

    def prepare_data(self, index: int) -> Optional[dict]:
        """Data preparation for both training and testing stage.

        Called by `__getitem__` of dataset.
...@@ -294,7 +300,7 @@ class Det3DDataset(BaseDataset):
            index (int): Index for accessing the target data.

        Returns:
            dict | None: Data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Union

import numpy as np
...@@ -22,11 +22,12 @@ class KittiDataset(Det3DDataset):
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to `dict(use_lidar=True)`.
        default_cam_key (str, optional): The default camera name adopted.
            Defaults to 'CAM2'.
        box_type_3d (str, optional): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'LiDAR' in this dataset. Available options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
...@@ -35,9 +36,9 @@ class KittiDataset(Det3DDataset):
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
        pcd_limit_range (list[float], optional): The range of point cloud
            used to filter invalid predicted boxes.
            Defaults to [0, -40, -3, 70.4, 40, 0.0].
    """
    # TODO: use full classes of kitti
    METAINFO = {
...@@ -49,15 +50,18 @@ class KittiDataset(Det3DDataset):
                 data_root: str,
                 ann_file: str,
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True),
                 default_cam_key: str = 'CAM2',
                 task: str = 'lidar_det',
                 box_type_3d: str = 'LiDAR',
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
                 pcd_limit_range: List[float] = [0, -40, -3, 70.4, 40, 0.0],
                 **kwargs) -> None:
        self.pcd_limit_range = pcd_limit_range
        assert task in ('lidar_det', 'mono_det')
        self.task = task
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
...@@ -107,11 +111,14 @@ class KittiDataset(Det3DDataset):
            info['plane'] = plane_lidar

        if self.task == 'mono_det':
            info['instances'] = info['cam_instances'][self.default_cam_key]

        info = super().parse_data_info(info)

        return info

    def parse_ann_info(self, info: dict) -> dict:
        """Get annotation info according to the given index.

        Args:
...@@ -135,6 +142,12 @@ class KittiDataset(Det3DDataset):
            ann_info['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros(0, dtype=np.int64)
            if self.task == 'mono_det':
                ann_info['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
                ann_info['gt_bboxes_labels'] = np.zeros(0, dtype=np.int64)
                ann_info['centers_2d'] = np.zeros((0, 2), dtype=np.float32)
                ann_info['depths'] = np.zeros((0), dtype=np.float32)

        ann_info = self._remove_dontcare(ann_info)
        # in kitti, lidar2cam = R0_rect @ Tr_velo_to_cam
        lidar2cam = np.array(info['images']['CAM2']['lidar2cam'])
......
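# Config sketch (illustrative) for the new `task` switch of KittiDataset
# above: 'mono_det' makes parse_data_info read per-camera instances from
# cam_instances[default_cam_key] instead of lidar instances. Paths are
# placeholders.
dataset = dict(
    type='KittiDataset',
    data_root='data/kitti/',
    ann_file='kitti_infos_train.pkl',
    task='mono_det',            # asserted to be 'lidar_det' or 'mono_det'
    default_cam_key='CAM2',
    box_type_3d='Camera')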
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Union

import numpy as np
...@@ -24,18 +24,18 @@ class LyftDataset(Det3DDataset):
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to dict(use_camera=False, use_lidar=True).
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'LiDAR' in this dataset. Available options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
    """
...@@ -48,8 +48,8 @@ class LyftDataset(Det3DDataset):
    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_camera=False, use_lidar=True),
                 box_type_3d: str = 'LiDAR',
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
......
# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp
from typing import Callable, List, Union

import numpy as np
...@@ -22,25 +22,26 @@ class NuScenesDataset(Det3DDataset):
    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        task (str, optional): Detection task. Defaults to 'lidar_det'.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'LiDAR' in this dataset. Available options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to dict(use_camera=False, use_lidar=True).
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
        with_velocity (bool, optional): Whether to include velocity prediction
            in the experiments. Defaults to True.
        use_valid_flag (bool, optional): Whether to use `use_valid_flag` key
            in the info file as mask to filter gt_boxes and gt_names.
            Defaults to False.
    """
...@@ -55,10 +56,10 @@ class NuScenesDataset(Det3DDataset):
    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 task: str = 'lidar_det',
                 pipeline: List[Union[dict, Callable]] = [],
                 box_type_3d: str = 'LiDAR',
                 modality: dict = dict(
                     use_camera=False,
                     use_lidar=True,
                 ),
...@@ -66,12 +67,12 @@ class NuScenesDataset(Det3DDataset):
                 test_mode: bool = False,
                 with_velocity: bool = True,
                 use_valid_flag: bool = False,
                 **kwargs) -> None:
        self.use_valid_flag = use_valid_flag
        self.with_velocity = with_velocity

        # TODO: Redesign multi-view data process in the future
        assert task in ('lidar_det', 'mono_det', 'multi-view_det')
        self.task = task

        assert box_type_3d.lower() in ('lidar', 'camera')
...@@ -85,6 +86,27 @@ class NuScenesDataset(Det3DDataset):
            test_mode=test_mode,
            **kwargs)
    def _filter_with_mask(self, ann_info: dict) -> dict:
        """Remove annotations that do not need to be cared about.

        Args:
            ann_info (dict): Dict of annotation infos.

        Returns:
            dict: Annotations after filtering.
        """
        filtered_annotations = {}
        if self.use_valid_flag:
            filter_mask = ann_info['bbox_3d_isvalid']
        else:
            filter_mask = ann_info['num_lidar_pts'] > 0
        for key in ann_info.keys():
            if key != 'instances':
                filtered_annotations[key] = ann_info[key][filter_mask]
            else:
                filtered_annotations[key] = ann_info[key]
        return filtered_annotations
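# Toy sketch (illustrative) of the filtering above with use_valid_flag=False:
# the mask is num_lidar_pts > 0; array fields are masked, 'instances' is kept.
import numpy as np

ann_info = dict(
    gt_bboxes_3d=np.zeros((3, 7), dtype=np.float32),
    gt_labels_3d=np.array([0, 1, 2]),
    num_lidar_pts=np.array([0, 5, 12]),
    instances=[dict(), dict(), dict()])
mask = ann_info['num_lidar_pts'] > 0  # [False, True, True]
filtered = {
    k: (v[mask] if k != 'instances' else v)
    for k, v in ann_info.items()
}  # keeps 2 of the 3 boxes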
    def parse_ann_info(self, info: dict) -> dict:
        """Get annotation info according to the given index.
...@@ -99,66 +121,51 @@ class NuScenesDataset(Det3DDataset):
                - gt_labels_3d (np.ndarray): Labels of ground truths.
        """
        ann_info = super().parse_ann_info(info)
        if ann_info is not None:
            ann_info = self._filter_with_mask(ann_info)

            if self.with_velocity:
                gt_bboxes_3d = ann_info['gt_bboxes_3d']
                gt_velocities = ann_info['velocities']
                nan_mask = np.isnan(gt_velocities[:, 0])
                gt_velocities[nan_mask] = [0.0, 0.0]

                gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocities],
                                              axis=-1)
                ann_info['gt_bboxes_3d'] = gt_bboxes_3d
        else:
            # empty instance
            ann_info = dict()
            if self.with_velocity:
                ann_info['gt_bboxes_3d'] = np.zeros((0, 9), dtype=np.float32)
            else:
                ann_info['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros(0, dtype=np.int64)

            if self.task == 'mono_det':
                ann_info['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
                ann_info['gt_bboxes_labels'] = np.zeros(0, dtype=np.int64)
                ann_info['attr_labels'] = np.zeros(0, dtype=np.int64)
                ann_info['centers_2d'] = np.zeros((0, 2), dtype=np.float32)
                ann_info['depths'] = np.zeros((0), dtype=np.float32)

        # the nuscenes box center is [0.5, 0.5, 0.5], we change it to be
        # the same as KITTI (0.5, 0.5, 0)
        # TODO: Unify the coordinates
        if self.task == 'mono_det':
            gt_bboxes_3d = CameraInstance3DBoxes(
                ann_info['gt_bboxes_3d'],
                box_dim=ann_info['gt_bboxes_3d'].shape[-1],
                origin=(0.5, 0.5, 0.5))
        else:
            gt_bboxes_3d = LiDARInstance3DBoxes(
                ann_info['gt_bboxes_3d'],
                box_dim=ann_info['gt_bboxes_3d'].shape[-1],
                origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        ann_info['gt_bboxes_3d'] = gt_bboxes_3d

        return ann_info
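# Numpy sketch (illustrative) of the velocity padding above: (N, 7) boxes and
# (N, 2) velocities become (N, 9) boxes, with NaN velocities zeroed first.
import numpy as np

gt_bboxes_3d = np.zeros((2, 7), dtype=np.float32)
gt_velocities = np.array([[0.5, 0.1], [np.nan, np.nan]], dtype=np.float32)
nan_mask = np.isnan(gt_velocities[:, 0])
gt_velocities[nan_mask] = [0.0, 0.0]
gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocities], axis=-1)  # (2, 9)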
    def parse_data_info(self, info: dict) -> dict:
        """Process the raw data info.
...@@ -173,7 +180,7 @@ class NuScenesDataset(Det3DDataset):
            dict: Has `ann_info` in training stage. And all paths have been
            converted to absolute paths.
        """
        if self.task == 'mono_det':
            data_list = []
            if self.modality['use_lidar']:
                info['lidar_points']['lidar_path'] = \
......
...@@ -36,7 +36,7 @@ class ScanNetDataset(Det3DDataset):
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'Depth' in this dataset. Available options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
...@@ -61,13 +61,13 @@ class ScanNetDataset(Det3DDataset):
    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points',
                     pts_instance_mask='instance_mask',
                     pts_semantic_mask='semantic_mask'),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_camera=False, use_lidar=True),
                 box_type_3d: str = 'Depth',
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
...@@ -101,7 +101,7 @@ class ScanNetDataset(Det3DDataset):
        assert self.modality['use_camera'] or self.modality['use_lidar']

    @staticmethod
    def _get_axis_align_matrix(info: dict) -> np.ndarray:
        """Get axis_align_matrix from info. If it does not exist, return an
        identity matrix.

        Args:
......
...@@ -24,25 +24,25 @@ class SUNRGBDDataset(Det3DDataset):
        ann_file (str): Path of annotation file.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict, optional): Prefix for data. Defaults to
            dict(pts='points', img='sunrgbd_trainval').
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to dict(use_camera=True, use_lidar=True).
        default_cam_key (str, optional): The default camera name adopted.
            Defaults to 'CAM0'.
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'Depth' in this dataset. Available options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
    """

    METAINFO = {
......
...@@ -11,11 +11,12 @@ from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (AffineResize, BackgroundPointsFilter,
                            GlobalAlignment, GlobalRotScaleTrans,
                            IndoorPatchPointSample, IndoorPointSample,
                            MultiViewWrapper, ObjectNameFilter, ObjectNoise,
                            ObjectRangeFilter, ObjectSample,
                            PhotoMetricDistortion3D, PointSample, PointShuffle,
                            PointsRangeFilter, RandomDropPointsColor,
                            RandomFlip3D, RandomJitterPoints, RandomResize3D,
                            RandomShiftScale, Resize3D, VoxelBasedPointSampler)

__all__ = [
    'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
...@@ -29,5 +30,6 @@ __all__ = [
    'VoxelBasedPointSampler', 'GlobalAlignment', 'IndoorPatchPointSample',
    'LoadImageFromFileMono3D', 'ObjectNameFilter', 'RandomDropPointsColor',
    'RandomJitterPoints', 'AffineResize', 'RandomShiftScale',
    'LoadPointsFromDict', 'Resize3D', 'RandomResize3D',
    'MultiViewWrapper', 'PhotoMetricDistortion3D'
]
...@@ -32,7 +32,7 @@ class Compose:
            data (dict): A result dict containing the data to transform.

        Returns:
            dict: Transformed data.
        """
        for t in self.transforms:
......
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from typing import List, Optional

import mmengine
import numpy as np
...@@ -16,18 +16,19 @@ class BatchSampler:

    Args:
        sample_list (list[dict]): List of samples.
        name (str, optional): The category of samples. Defaults to None.
        epoch (int, optional): Sampling epoch. Defaults to None.
        shuffle (bool, optional): Whether to shuffle indices.
            Defaults to True.
        drop_reminder (bool, optional): Whether to drop the remainder.
            Defaults to False.
    """

    def __init__(self,
                 sampled_list: List[dict],
                 name: Optional[str] = None,
                 epoch: Optional[int] = None,
                 shuffle: bool = True,
                 drop_reminder: bool = False) -> None:
        self._sampled_list = sampled_list
        self._indices = np.arange(len(sampled_list))
        if shuffle:
...@@ -40,7 +41,7 @@ class BatchSampler:
        self._epoch_counter = 0
        self._drop_reminder = drop_reminder

    def _sample(self, num: int) -> List[int]:
        """Sample specific number of ground truths and return indices.

        Args:
...@@ -57,7 +58,7 @@ class BatchSampler:
        self._idx += num
        return ret

    def _reset(self) -> None:
        """Reset the index of batchsampler to zero."""
        assert self._name is not None
        # print("reset", self._name)
...@@ -65,7 +66,7 @@ class BatchSampler:
            np.random.shuffle(self._indices)
        self._idx = 0
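# Usage sketch (illustrative): indices are consumed cyclically; once the pool
# is exhausted, _reset() reshuffles and sampling starts over. `db_infos` is a
# placeholder database-info dict as loaded in DataBaseSampler below.
car_sampler = BatchSampler(db_infos['Car'], name='Car', shuffle=True)
two_cars = car_sampler.sample(2)  # list with 2 sampled GT info dicts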
    def sample(self, num: int) -> List[dict]:
        """Sample specific number of ground truths.

        Args:
...@@ -88,24 +89,30 @@ class DataBaseSampler(object):
        rate (float): Rate of actual sampled over maximum sampled number.
        prepare (dict): Name of preparation functions and the input value.
        sample_groups (dict): Sampled classes and numbers.
        classes (list[str], optional): List of classes. Defaults to None.
        points_loader (dict, optional): Config of points loader. Defaults to
            dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]).
        file_client_args (dict, optional): Config dict of file clients;
            refer to
            https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
            for more details. Defaults to dict(backend='disk').
    """

    def __init__(
        self,
        info_path: str,
        data_root: str,
        rate: float,
        prepare: dict,
        sample_groups: dict,
        classes: Optional[List[str]] = None,
        points_loader: dict = dict(
            type='LoadPointsFromFile',
            coord_type='LIDAR',
            load_dim=4,
            use_dim=[0, 1, 2, 3]),
        file_client_args: dict = dict(backend='disk')
    ) -> None:
        super().__init__()
        self.data_root = data_root
        self.info_path = info_path
...@@ -118,18 +125,9 @@ class DataBaseSampler(object):
        self.file_client = mmengine.FileClient(**file_client_args)

        # load data base infos
        with self.file_client.get_local_path(info_path) as local_path:
            # loading data from a file-like object needs file format
            db_infos = mmengine.load(open(local_path, 'rb'), file_format='pkl')

        # filter database infos
        from mmengine.logging import MMLogger
...@@ -163,7 +161,7 @@ class DataBaseSampler(object):
    # TODO: No group_sampling currently

    @staticmethod
    def filter_by_difficulty(db_infos: dict, removed_difficulty: list) -> dict:
        """Filter ground truths by difficulties.

        Args:
...@@ -182,7 +180,7 @@ class DataBaseSampler(object):
        return new_db_infos

    @staticmethod
    def filter_by_min_points(db_infos: dict, min_gt_points_dict: dict) -> dict:
        """Filter ground truths by number of points in the bbox.

        Args:
...@@ -203,12 +201,19 @@ class DataBaseSampler(object):
            db_infos[name] = filtered_infos
        return db_infos
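# Toy sketch (illustrative): each database entry records how many lidar
# points fall inside its GT box ('num_points_in_gt' is assumed here, matching
# the converter key); entries below the threshold are dropped.
db_infos = {'Car': [{'num_points_in_gt': 3}, {'num_points_in_gt': 42}]}
db_infos = DataBaseSampler.filter_by_min_points(db_infos, {'Car': 5})
# -> {'Car': [{'num_points_in_gt': 42}]}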
    def sample_all(self,
                   gt_bboxes: np.ndarray,
                   gt_labels: np.ndarray,
                   img: Optional[np.ndarray] = None,
                   ground_plane: Optional[np.ndarray] = None) -> dict:
        """Sampling all categories of bboxes.

        Args:
            gt_bboxes (np.ndarray): Ground truth bounding boxes.
            gt_labels (np.ndarray): Ground truth labels of boxes.
            img (np.ndarray, optional): Image array. Defaults to None.
            ground_plane (np.ndarray, optional): Ground plane information.
                Defaults to None.

        Returns:
            dict: Dict of sampled 'pseudo ground truths'.
...@@ -301,7 +306,8 @@ class DataBaseSampler(object):
        return ret

    def sample_class_v2(self, name: str, num: int,
                        gt_bboxes: np.ndarray) -> List[dict]:
        """Sampling specific categories of bounding boxes.

        Args:
......
...@@ -63,15 +63,20 @@ class Pack3DDetInputs(BaseTransform):

    def __init__(
        self,
        keys: tuple,
        meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
                            'depth2img', 'cam2img', 'pad_shape',
                            'scale_factor', 'flip', 'pcd_horizontal_flip',
                            'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
                            'img_norm_cfg', 'num_pts_feats', 'pcd_trans',
                            'sample_idx', 'pcd_scale_factor', 'pcd_rotation',
                            'pcd_rotation_angle', 'lidar_path',
                            'transformation_3d_flow', 'trans_mat',
                            'affine_aug', 'sweep_img_metas', 'ori_cam2img',
                            'cam2global', 'crop_offset', 'img_crop_offset',
                            'resize_img_shape', 'lidar2cam', 'ori_lidar2img',
                            'num_ref_frames', 'num_views', 'ego2global')
    ) -> None:
        self.keys = keys
        self.meta_keys = meta_keys
...@@ -98,7 +103,7 @@ class Pack3DDetInputs(BaseTransform):
            - img
            - 'data_samples' (:obj:`Det3DDataSample`): The annotation info of
              the sample.
        """
        # augtest
        if isinstance(results, list):
...@@ -115,7 +120,7 @@ class Pack3DDetInputs(BaseTransform):
        else:
            raise NotImplementedError

    def pack_single_results(self, results: dict) -> dict:
        """Method to pack the single input data. When the value in this dict
        is a list, it usually is in Augmentations Testing.
...@@ -131,7 +136,7 @@ class Pack3DDetInputs(BaseTransform):
            - points
            - img
            - 'data_samples' (:obj:`Det3DDataSample`): The annotation info
              of the sample.
        """
        # Format 3D data
...@@ -219,6 +224,7 @@ class Pack3DDetInputs(BaseTransform):
        return packed_results

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(keys={self.keys})'
        repr_str += f'(meta_keys={self.meta_keys})'
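# Pipeline-config sketch (illustrative): Pack3DDetInputs is typically the
# last transform of a pipeline; `keys` selects the tensors to pack, and the
# extended meta_keys above are collected into the data sample's metainfo.
pack_transform = dict(
    type='Pack3DDetInputs',
    keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])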
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List import copy
from typing import List, Optional, Union
import mmcv import mmcv
import mmengine import mmengine
@@ -13,7 +14,7 @@ from mmdet.datasets.transforms import LoadAnnotations

@TRANSFORMS.register_module()
class LoadMultiViewImageFromFiles(BaseTransform):
    """Load multi channel images from a list of separate channel files.

    Expects results['img_filename'] to be a list of filenames.
@@ -23,13 +24,38 @@ class LoadMultiViewImageFromFiles(object):
            Defaults to False.
        color_type (str, optional): Color type of the file.
            Defaults to 'unchanged'.
        file_client_args (dict): Config dict of file clients, refer to
            https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
            for more details. Defaults to dict(backend='disk').
        num_views (int): Number of views in a frame. Defaults to 5.
        num_ref_frames (int): Number of reference frames to load.
            Defaults to -1.
        test_mode (bool): Whether it is in test mode during loading.
            Defaults to False.
        set_default_scale (bool): Whether to set the default scale.
            Defaults to True.
    """
    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'unchanged',
                 file_client_args: dict = dict(backend='disk'),
                 num_views: int = 5,
                 num_ref_frames: int = -1,
                 test_mode: bool = False,
                 set_default_scale: bool = True) -> None:
        self.to_float32 = to_float32
        self.color_type = color_type
self.file_client_args = file_client_args.copy()
self.file_client = None
self.num_views = num_views
# num_ref_frames is used for multi-sweep loading
self.num_ref_frames = num_ref_frames
# when test_mode=False, we randomly select previous frames
# otherwise, select the earliest one
self.test_mode = test_mode
self.set_default_scale = set_default_scale
    def transform(self, results: dict) -> Optional[dict]:
        """Call function to load multi-view image from files.

        Args:
@@ -47,33 +73,151 @@ class LoadMultiViewImageFromFiles(object):
            - scale_factor (float): Scale factor.
            - img_norm_cfg (dict): Normalization configuration of images.
        """
        # TODO: consider splitting the multi-sweep part out of this pipeline
# Derive the mask and transform for loading of multi-sweep data
if self.num_ref_frames > 0:
# init choice with the current frame
init_choice = np.array([0], dtype=np.int64)
num_frames = len(results['img_filename']) // self.num_views - 1
if num_frames == 0: # no previous frame, then copy cur frames
choices = np.random.choice(
1, self.num_ref_frames, replace=True)
elif num_frames >= self.num_ref_frames:
# NOTE: suppose the info is saved following the order
# from latest to earlier frames
if self.test_mode:
choices = np.arange(num_frames - self.num_ref_frames,
num_frames) + 1
# NOTE: +1 is for selecting previous frames
else:
choices = np.random.choice(
num_frames, self.num_ref_frames, replace=False) + 1
elif num_frames > 0 and num_frames < self.num_ref_frames:
if self.test_mode:
base_choices = np.arange(num_frames) + 1
random_choices = np.random.choice(
num_frames,
self.num_ref_frames - num_frames,
replace=True) + 1
choices = np.concatenate([base_choices, random_choices])
else:
choices = np.random.choice(
num_frames, self.num_ref_frames, replace=True) + 1
else:
raise NotImplementedError
choices = np.concatenate([init_choice, choices])
select_filename = []
for choice in choices:
select_filename += results['img_filename'][choice *
self.num_views:
(choice + 1) *
self.num_views]
results['img_filename'] = select_filename
for key in ['cam2img', 'lidar2cam']:
if key in results:
select_results = []
for choice in choices:
select_results += results[key][choice *
self.num_views:(choice +
1) *
self.num_views]
results[key] = select_results
for key in ['ego2global']:
if key in results:
select_results = []
for choice in choices:
select_results += [results[key][choice]]
results[key] = select_results
# Transform lidar2cam to
# [cur_lidar]2[prev_img] and [cur_lidar]2[prev_cam]
for key in ['lidar2cam']:
if key in results:
# only change matrices of previous frames
for choice_idx in range(1, len(choices)):
pad_prev_ego2global = np.eye(4)
prev_ego2global = results['ego2global'][choice_idx]
pad_prev_ego2global[:prev_ego2global.
shape[0], :prev_ego2global.
shape[1]] = prev_ego2global
pad_cur_ego2global = np.eye(4)
cur_ego2global = results['ego2global'][0]
pad_cur_ego2global[:cur_ego2global.
shape[0], :cur_ego2global.
shape[1]] = cur_ego2global
cur2prev = np.linalg.inv(pad_prev_ego2global).dot(
pad_cur_ego2global)
for result_idx in range(choice_idx * self.num_views,
(choice_idx + 1) *
self.num_views):
results[key][result_idx] = \
results[key][result_idx].dot(cur2prev)
# Support multi-view images with different shapes
# TODO: record the origin shape and padded shape
filename, cam2img, lidar2cam = [], [], []
for _, cam_item in results['images'].items():
filename.append(cam_item['img_path'])
cam2img.append(cam_item['cam2img'])
lidar2cam.append(cam_item['lidar2cam'])
results['filename'] = filename
results['cam2img'] = cam2img
results['lidar2cam'] = lidar2cam
results['ori_cam2img'] = copy.deepcopy(results['cam2img'])
if self.file_client is None:
self.file_client = mmengine.FileClient(**self.file_client_args)
        # img is of shape (h, w, c, num_views)
        # h and w can be different for different views
        img_bytes = [self.file_client.get(name) for name in filename]
imgs = [
mmcv.imfrombytes(img_byte, flag=self.color_type)
for img_byte in img_bytes
]
# handle the image with different shape
img_shapes = np.stack([img.shape for img in imgs], axis=0)
img_shape_max = np.max(img_shapes, axis=0)
img_shape_min = np.min(img_shapes, axis=0)
assert img_shape_min[-1] == img_shape_max[-1]
if not np.all(img_shape_max == img_shape_min):
pad_shape = img_shape_max[:2]
else:
pad_shape = None
if pad_shape is not None:
imgs = [
mmcv.impad(img, shape=pad_shape, pad_val=0) for img in imgs
]
img = np.stack(imgs, axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        # unravel to list, see `DefaultFormatBundle` in formatting.py
        # which will transpose each image separately and then stack into array
        results['img'] = [img[..., i] for i in range(img.shape[-1])]
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        # Set initial values for default meta_keys
        results['pad_shape'] = img.shape
        if self.set_default_scale:
            results['scale_factor'] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        results['num_views'] = self.num_views
        results['num_ref_frames'] = self.num_ref_frames
        return results
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(to_float32={self.to_float32}, '
        repr_str += f"color_type='{self.color_type}', "
        repr_str += f'num_views={self.num_views}, '
        repr_str += f'num_ref_frames={self.num_ref_frames}, '
        repr_str += f'test_mode={self.test_mode})'
        return repr_str
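The multi-sweep branch of `transform` above selects reference frames in three regimes (no history, enough history, partial history). A standalone numpy sketch of the same index logic may help when checking it; the helper name is ours, and the partial-history test-mode branch is collapsed for brevity:

```python
import numpy as np

def select_ref_frames(num_total_imgs: int, num_views: int,
                      num_ref_frames: int, test_mode: bool) -> np.ndarray:
    """Mirror of the frame-choice logic above (illustrative only)."""
    init_choice = np.array([0], dtype=np.int64)  # always keep current frame
    num_frames = num_total_imgs // num_views - 1  # previous frames available
    if num_frames == 0:
        # no history: duplicate the current frame
        choices = np.random.choice(1, num_ref_frames, replace=True)
    elif num_frames >= num_ref_frames:
        if test_mode:
            # deterministic: take the num_ref_frames most recent frames
            choices = np.arange(num_frames - num_ref_frames, num_frames) + 1
        else:
            choices = np.random.choice(
                num_frames, num_ref_frames, replace=False) + 1
    else:
        # fewer frames than requested: sample with replacement
        choices = np.random.choice(
            num_frames, num_ref_frames, replace=True) + 1
    return np.concatenate([init_choice, choices])

print(select_ref_frames(num_total_imgs=30, num_views=5,
                        num_ref_frames=2, test_mode=True))  # -> [0 4 5]
```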
@@ -139,7 +283,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
            Defaults to [0, 1, 2, 4].
        file_client_args (dict, optional): Config dict of file clients,
            refer to
            https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
            for more details. Defaults to dict(backend='disk').
        pad_empty_sweeps (bool, optional): Whether to repeat keyframe when
            sweeps is empty. Defaults to False.
@@ -151,13 +295,13 @@ class LoadPointsFromMultiSweeps(BaseTransform):
    """

    def __init__(self,
                 sweeps_num: int = 10,
                 load_dim: int = 5,
                 use_dim: List[int] = [0, 1, 2, 4],
                 file_client_args: dict = dict(backend='disk'),
                 pad_empty_sweeps: bool = False,
                 remove_close: bool = False,
                 test_mode: bool = False) -> None:
        self.load_dim = load_dim
        self.sweeps_num = sweeps_num
        self.use_dim = use_dim
@@ -167,7 +311,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
        self.remove_close = remove_close
        self.test_mode = test_mode

    def _load_points(self, pts_filename: str) -> np.ndarray:
        """Private function to load point clouds data.

        Args:
@@ -189,7 +333,9 @@ class LoadPointsFromMultiSweeps(BaseTransform):
            points = np.fromfile(pts_filename, dtype=np.float32)
        return points
    def _remove_close(self,
                      points: Union[np.ndarray, BasePoints],
                      radius: float = 1.0) -> Union[np.ndarray, BasePoints]:
        """Remove points too close to the origin, within a certain radius.

        Args:
@@ -198,7 +344,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
                Defaults to 1.0.

        Returns:
            np.ndarray | :obj:`BasePoints`: Points after removing.
        """
        if isinstance(points, np.ndarray):
            points_numpy = points
@@ -211,7 +357,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
        not_close = np.logical_not(np.logical_and(x_filt, y_filt))
        return points[not_close]
    def transform(self, results: dict) -> dict:
        """Call function to load multi-sweep point clouds from files.

        Args:
@@ -220,7 +366,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):

        Returns:
            dict: The result dict containing the multi-sweep points data.
            Updated key and value are described below.

            - points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
              cloud arrays.
@@ -290,7 +436,7 @@ class PointSegClassMapping(BaseTransform):
        others as len(valid_cat_ids).
    """

    def transform(self, results: dict) -> dict:
        """Call function to map original semantic class to valid category ids.

        Args:
@@ -322,8 +468,6 @@ class PointSegClassMapping(BaseTransform):

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        return repr_str
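The mapping is effectively a lookup table: ids listed in `valid_cat_ids` get consecutive labels and everything else maps to the ignore index `len(valid_cat_ids)`. A small sketch with made-up ids:

```python
import numpy as np

# Illustrative values: three valid ids within a maximum raw id of 40.
valid_cat_ids = (3, 4, 5)
max_cat_id = 40

# Build the lookup: unknown ids -> len(valid_cat_ids) (the ignore label).
cat_id2class = np.ones(max_cat_id + 1, dtype=np.int64) * len(valid_cat_ids)
for cls_idx, cat_id in enumerate(valid_cat_ids):
    cat_id2class[cat_id] = cls_idx

pts_semantic_mask = np.array([3, 7, 5, 4, 0])
print(cat_id2class[pts_semantic_mask])  # -> [0 3 2 1 3]
```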
@@ -385,13 +529,14 @@ class LoadPointsFromFile(BaseTransform):

    Args:
        coord_type (str): The type of coordinates of points cloud.
            Available options include:

            - 'LIDAR': Points in LiDAR coordinates.
            - 'DEPTH': Points in depth coordinates, usually for indoor dataset.
            - 'CAMERA': Points in camera coordinates.

        load_dim (int, optional): The dimension of the loaded points.
            Defaults to 6.
        use_dim (list[int] | int, optional): Which dimensions of the points
            to use. Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
            or use_dim=[0, 1, 2, 3] to use the intensity dimension.
        shift_height (bool, optional): Whether to use shifted height.
            Defaults to False.
@@ -399,7 +544,7 @@ class LoadPointsFromFile(BaseTransform):
            Defaults to False.
        file_client_args (dict, optional): Config dict of file clients,
            refer to
            https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
            for more details. Defaults to dict(backend='disk').
    """
@@ -407,7 +552,7 @@ class LoadPointsFromFile(BaseTransform):
        self,
        coord_type: str,
        load_dim: int = 6,
        use_dim: Union[int, List[int]] = [0, 1, 2],
        shift_height: bool = False,
        use_color: bool = False,
        file_client_args: dict = dict(backend='disk')
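As a concrete reading of the `use_dim` note above, a KITTI-style loading config that keeps the intensity channel might look like this (values illustrative):

```python
# Hypothetical KITTI-style point loading: 4 stored dims, keep intensity.
load_points = dict(
    type='LoadPointsFromFile',
    coord_type='LIDAR',
    load_dim=4,            # x, y, z, intensity are stored per point
    use_dim=[0, 1, 2, 3],  # equivalent to use_dim=4
    file_client_args=dict(backend='disk'))
```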
@@ -523,6 +668,7 @@ class LoadAnnotations3D(LoadAnnotations):

    Required Keys:

    - ann_info (dict)

        - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` |
          :obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`):
          3D ground truth bboxes. Only when `with_bbox_3d` is True
@@ -592,7 +738,7 @@ class LoadAnnotations3D(LoadAnnotations):
        seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks.
            Defaults to int64.
        file_client_args (dict): Config dict of file clients, refer to
            https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
            for more details.
    """
......
@@ -16,7 +16,7 @@ class MultiScaleFlipAug3D(BaseTransform):

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (tuple | list[tuple]): Images scales for resizing.
        pts_scale_ratio (float | list[float]): Points scale ratios for
            resizing.
        flip (bool, optional): Whether to apply flip augmentation.
@@ -25,11 +25,11 @@ class MultiScaleFlipAug3D(BaseTransform):
            directions for images, options are "horizontal" and "vertical".
            If flip_direction is a list, multiple flip augmentations will
            be applied. It has no effect when ``flip == False``.
            Defaults to 'horizontal'.
        pcd_horizontal_flip (bool, optional): Whether to apply horizontal
            flip augmentation to point cloud. Defaults to True.
            Note that it works only when 'flip' is turned on.
        pcd_vertical_flip (bool, optional): Whether to apply vertical flip
            augmentation to point cloud. Defaults to True.
            Note that it works only when 'flip' is turned on.
    """
@@ -46,7 +46,7 @@ class MultiScaleFlipAug3D(BaseTransform):
        self.img_scale = img_scale if isinstance(img_scale,
                                                 list) else [img_scale]
        self.pts_scale_ratio = pts_scale_ratio \
            if isinstance(pts_scale_ratio, list) else [float(pts_scale_ratio)]
        assert mmengine.is_list_of(self.img_scale, tuple)
        assert mmengine.is_list_of(self.pts_scale_ratio, float)
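A hedged sketch of a test-time augmentation entry built on `MultiScaleFlipAug3D`; the inner transforms are placeholders that depend on the model's test pipeline:

```python
tta_transform = dict(
    type='MultiScaleFlipAug3D',
    img_scale=(1333, 800),   # normalized to [(1333, 800)] internally
    pts_scale_ratio=1.0,     # normalized to [1.0] internally
    flip=False,
    transforms=[
        dict(type='GlobalRotScaleTrans',
             rot_range=[0, 0],
             scale_ratio_range=[1., 1.],
             translation_std=[0, 0, 0]),
        dict(type='RandomFlip3D'),
    ])
```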
......
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import List, Optional, Tuple, Union

import cv2
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform, RandomResize, Resize
from mmengine import is_tuple_of

from mmdet3d.models.task_modules import VoxelGenerator
@@ -14,7 +15,9 @@ from mmdet3d.structures import (CameraInstance3DBoxes, DepthInstance3DBoxes,
                                LiDARInstance3DBoxes)
from mmdet3d.structures.ops import box_np_ops
from mmdet3d.structures.points import BasePoints
from mmdet.datasets.transforms import (PhotoMetricDistortion, RandomCrop,
                                       RandomFlip)

from .compose import Compose
from .data_augment_utils import noise_per_object_v3_
@@ -76,7 +79,6 @@ class RandomFlip3D(RandomFlip):
    otherwise it will be randomly decided by a ratio specified in the init
    method.

    Required Keys:

    - points (np.float32)
@@ -96,20 +98,25 @@ class RandomFlip3D(RandomFlip):
    - pcd_scale_factor (np.float32)

    Args:
        sync_2d (bool): Whether to apply flip according to the 2D
            images. If True, it will apply the same flip as that to 2D images.
            If False, it will decide whether to flip randomly and independently
            to that of 2D images. Defaults to True.
        flip_ratio_bev_horizontal (float): The flipping probability
            in horizontal direction. Defaults to 0.0.
        flip_ratio_bev_vertical (float): The flipping probability
            in vertical direction. Defaults to 0.0.
        flip_box3d (bool): Whether to flip the bounding boxes. In most cases
            the boxes should be flipped. In camera-based BEV detection this
            is set to False, since flipping the 2D images does not influence
            the 3D boxes. Defaults to True.
    """
    def __init__(self,
                 sync_2d: bool = True,
                 flip_ratio_bev_horizontal: float = 0.0,
                 flip_ratio_bev_vertical: float = 0.0,
                 flip_box3d: bool = True,
                 **kwargs) -> None:
        # `flip_ratio_bev_horizontal` is equal to
        # for flip prob of 2d image when
@@ -119,6 +126,7 @@ class RandomFlip3D(RandomFlip):
        self.sync_2d = sync_2d
        self.flip_ratio_bev_horizontal = flip_ratio_bev_horizontal
        self.flip_ratio_bev_vertical = flip_ratio_bev_vertical
        self.flip_box3d = flip_box3d
        if flip_ratio_bev_horizontal is not None:
            assert isinstance(
                flip_ratio_bev_horizontal,
@@ -150,23 +158,21 @@ class RandomFlip3D(RandomFlip):
            updated in the result dict.
        """
        assert direction in ['horizontal', 'vertical']
        if self.flip_box3d:
            if 'gt_bboxes_3d' in input_dict:
                if 'points' in input_dict:
                    input_dict['points'] = input_dict['gt_bboxes_3d'].flip(
                        direction, points=input_dict['points'])
                else:
                    # vision-only detection
                    input_dict['gt_bboxes_3d'].flip(direction)
            else:
                input_dict['points'].flip(direction)
        if 'centers_2d' in input_dict:
            assert self.sync_2d is True and direction == 'horizontal', \
                'Only support sync_2d=True and horizontal flip with images'
            w = input_dict['img_shape'][1]
            input_dict['centers_2d'][..., 0] = \
                w - input_dict['centers_2d'][..., 0]
            # need to modify the horizontal position of camera center
@@ -176,6 +182,25 @@ class RandomFlip3D(RandomFlip):
            # https://github.com/open-mmlab/mmdetection3d/pull/744
            input_dict['cam2img'][0][2] = w - input_dict['cam2img'][0][2]
def _flip_on_direction(self, results: dict) -> None:
"""Function to flip images, bounding boxes, semantic segmentation map
and keypoints.
Add the override feature that if 'flip' is already in results, use it
to do the augmentation.
"""
if 'flip' not in results:
cur_dir = self._choose_direction()
else:
cur_dir = results['flip_direction']
if cur_dir is None:
results['flip'] = False
results['flip_direction'] = None
else:
results['flip'] = True
results['flip_direction'] = cur_dir
self._flip(results)
    def transform(self, input_dict: dict) -> dict:
        """Call function to flip points, values in the ``bbox3d_fields`` and
        also flip 2D image and its annotations.
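The horizontal-flip handling above mirrors both `centers_2d` and the principal point `cx` of `cam2img` to `w - cx`; a quick numeric check of that update (values illustrative):

```python
import numpy as np

w = 1600  # image width
cam2img = np.array([[1000., 0., 700., 0.],
                    [0., 1000., 500., 0.],
                    [0., 0., 1., 0.]])
centers_2d = np.array([[300., 450.]])

# mirror both the 2D centers and the principal point, as above
centers_2d[..., 0] = w - centers_2d[..., 0]
cam2img[0][2] = w - cam2img[0][2]
print(centers_2d[0], cam2img[0, 2])  # -> [1300. 450.] 900.0
```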
@@ -329,7 +354,7 @@ class ObjectSample(BaseTransform):

    def __init__(self,
                 db_sampler: dict,
                 sample_2d: bool = False,
                 use_ground_plane: bool = False) -> None:
        self.sampler_cfg = db_sampler
        self.sample_2d = sample_2d
        if 'type' not in db_sampler.keys():
@@ -367,11 +392,10 @@ class ObjectSample(BaseTransform):
        gt_bboxes_3d = input_dict['gt_bboxes_3d']
        gt_labels_3d = input_dict['gt_labels_3d']

        if self.use_ground_plane:
            ground_plane = input_dict.get('plane', None)
            assert ground_plane is not None, '`use_ground_plane` is True ' \
                                             'but find plane is None'
        else:
            ground_plane = None
        # change to float for blending operation
@@ -424,13 +448,9 @@ class ObjectSample(BaseTransform):

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f' sample_2d={self.sample_2d},'
        repr_str += f' use_ground_plane={self.use_ground_plane}'
        return repr_str
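A hedged sketch of what an `ObjectSample` entry with its database sampler config can look like; the paths, filter rule, and sample counts are illustrative KITTI-style values:

```python
# GT-sampling augmentation: paste pre-cropped objects into the scene.
db_sampler = dict(
    data_root='data/kitti/',
    info_path='data/kitti/kitti_dbinfos_train.pkl',
    rate=1.0,
    prepare=dict(filter_by_min_points=dict(Car=5)),
    classes=['Car'],
    sample_groups=dict(Car=15))

object_sample = dict(
    type='ObjectSample',
    db_sampler=db_sampler,
    sample_2d=False,
    use_ground_plane=False)
```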
@@ -461,10 +481,10 @@ class ObjectNoise(BaseTransform):
    """

    def __init__(self,
                 translation_std: List[float] = [0.25, 0.25, 0.25],
                 global_rot_range: List[float] = [0.0, 0.0],
                 rot_range: List[float] = [-0.15707963267, 0.15707963267],
                 num_try: int = 100) -> None:
        self.translation_std = translation_std
        self.global_rot_range = global_rot_range
        self.rot_range = rot_range
@@ -527,7 +547,7 @@ class GlobalAlignment(BaseTransform):

    def __init__(self, rotation_axis: int) -> None:
        self.rotation_axis = rotation_axis

    def _trans_points(self, results: dict, trans_factor: np.ndarray) -> None:
        """Private function to translate points.

        Args:
@@ -539,7 +559,7 @@ class GlobalAlignment(BaseTransform):
        """
        results['points'].translate(trans_factor)

    def _rot_points(self, results: dict, rot_mat: np.ndarray) -> None:
        """Private function to rotate bounding boxes and points.

        Args:
@@ -565,7 +585,7 @@ class GlobalAlignment(BaseTransform):
            is_valid &= (rot_mat[:, self.rotation_axis] == valid_array).all()
        assert is_valid, f'invalid rotation matrix {rot_mat}'

    def transform(self, results: dict) -> dict:
        """Call function to shuffle points.

        Args:
@@ -591,6 +611,7 @@ class GlobalAlignment(BaseTransform):
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(rotation_axis={self.rotation_axis})'
        return repr_str
@@ -809,6 +830,7 @@ class PointShuffle(BaseTransform):
        return input_dict

    def __repr__(self):
        """str: Return a string that describes the module."""
        return self.__class__.__name__
@@ -828,7 +850,7 @@ class ObjectRangeFilter(BaseTransform):
        point_cloud_range (list[float]): Point cloud range.
    """

    def __init__(self, point_cloud_range: List[float]):
        self.pcd_range = np.array(point_cloud_range, dtype=np.float32)

    def transform(self, input_dict: dict) -> dict:
@@ -890,7 +912,7 @@ class PointsRangeFilter(BaseTransform):
        point_cloud_range (list[float]): Point cloud range.
    """

    def __init__(self, point_cloud_range: List[float]) -> None:
        self.pcd_range = np.array(point_cloud_range, dtype=np.float32)

    def transform(self, input_dict: dict) -> dict:
@@ -943,7 +965,7 @@ class ObjectNameFilter(BaseTransform):
        classes (list[str]): List of class names to be kept for training.
    """

    def __init__(self, classes: List[str]) -> None:
        self.classes = classes
        self.labels = list(range(len(self.classes)))
@@ -1001,34 +1023,38 @@ class PointSample(BaseTransform):

    def __init__(self,
                 num_points: int,
                 sample_range: Optional[float] = None,
                 replace: bool = False) -> None:
        self.num_points = num_points
        self.sample_range = sample_range
        self.replace = replace

    def _points_random_sampling(
        self,
        points: BasePoints,
        num_samples: int,
        sample_range: Optional[float] = None,
        replace: bool = False,
        return_choices: bool = False
    ) -> Union[Tuple[BasePoints, np.ndarray], BasePoints]:
        """Points random sampling.

        Sample points to a certain number.

        Args:
            points (:obj:`BasePoints`): 3D Points.
            num_samples (int): Number of samples to be sampled.
            sample_range (float, optional): Indicating the range where the
                points will be sampled. Defaults to None.
            replace (bool, optional): Sampling with or without replacement.
                Defaults to False.
            return_choices (bool, optional): Whether return choice.
                Defaults to False.

        Returns:
            tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`:

            - points (:obj:`BasePoints`): 3D Points.
            - choices (np.ndarray, optional): The generated random samples.
        """
        if not replace:
@@ -1036,7 +1062,7 @@ class PointSample(BaseTransform):
            point_range = range(len(points))
        if sample_range is not None and not replace:
            # Only sampling the near points when len(points) >= num_samples
            dist = np.linalg.norm(points.coord.numpy(), axis=1)
            far_inds = np.where(dist >= sample_range)[0]
            near_inds = np.where(dist < sample_range)[0]
            # in case there are too many far points
@@ -1060,6 +1086,7 @@ class PointSample(BaseTransform):

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after sampling, 'points', 'pts_instance_mask'
            and 'pts_semantic_mask' keys are updated in the result dict.
return points return points
def _patch_points_sampling(self, points: BasePoints, def _patch_points_sampling(
sem_mask: np.ndarray) -> BasePoints: self, points: BasePoints,
sem_mask: np.ndarray) -> Tuple[BasePoints, np.ndarray]:
"""Patch points sampling. """Patch points sampling.
First sample a valid patch. First sample a valid patch.
...@@ -1231,7 +1259,7 @@ class IndoorPatchPointSample(BaseTransform): ...@@ -1231,7 +1259,7 @@ class IndoorPatchPointSample(BaseTransform):
sem_mask (np.ndarray): semantic segmentation mask for input points. sem_mask (np.ndarray): semantic segmentation mask for input points.
Returns: Returns:
tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`: tuple[:obj:`BasePoints`, np.ndarray]:
- points (:obj:`BasePoints`): 3D Points. - points (:obj:`BasePoints`): 3D Points.
- choices (np.ndarray): The generated random samples. - choices (np.ndarray): The generated random samples.
...@@ -1438,7 +1466,7 @@ class BackgroundPointsFilter(BaseTransform): ...@@ -1438,7 +1466,7 @@ class BackgroundPointsFilter(BaseTransform):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class VoxelBasedPointSampler(object): class VoxelBasedPointSampler(BaseTransform):
"""Voxel based point sampler. """Voxel based point sampler.
Apply voxel sampling to multiple sweep points. Apply voxel sampling to multiple sweep points.
...@@ -1450,7 +1478,10 @@ class VoxelBasedPointSampler(object): ...@@ -1450,7 +1478,10 @@ class VoxelBasedPointSampler(object):
for input points. for input points.
""" """
def __init__(self, cur_sweep_cfg, prev_sweep_cfg=None, time_dim=3): def __init__(self,
cur_sweep_cfg: dict,
prev_sweep_cfg: Optional[dict] = None,
time_dim: int = 3) -> None:
self.cur_voxel_generator = VoxelGenerator(**cur_sweep_cfg) self.cur_voxel_generator = VoxelGenerator(**cur_sweep_cfg)
self.cur_voxel_num = self.cur_voxel_generator._max_voxels self.cur_voxel_num = self.cur_voxel_generator._max_voxels
self.time_dim = time_dim self.time_dim = time_dim
...@@ -1463,7 +1494,8 @@ class VoxelBasedPointSampler(object): ...@@ -1463,7 +1494,8 @@ class VoxelBasedPointSampler(object):
self.prev_voxel_generator = None self.prev_voxel_generator = None
self.prev_voxel_num = 0 self.prev_voxel_num = 0
def _sample_points(self, points, sampler, point_dim): def _sample_points(self, points: np.ndarray, sampler: VoxelGenerator,
point_dim: int) -> np.ndarray:
"""Sample points for each points subset. """Sample points for each points subset.
Args: Args:
...@@ -1489,7 +1521,7 @@ class VoxelBasedPointSampler(object): ...@@ -1489,7 +1521,7 @@ class VoxelBasedPointSampler(object):
return sample_points return sample_points
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to sample points from multiple sweeps. """Call function to sample points from multiple sweeps.
Args: Args:
...@@ -1665,8 +1697,9 @@ class AffineResize(BaseTransform): ...@@ -1665,8 +1697,9 @@ class AffineResize(BaseTransform):
        if 'gt_bboxes' in results:
            results['gt_bboxes'] = results['gt_bboxes'][valid_index]
            if 'gt_bboxes_labels' in results:
                results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
                    valid_index]
            if 'gt_masks' in results:
                raise NotImplementedError(
                    'AffineResize only supports bbox.')
@@ -1771,6 +1804,7 @@ class AffineResize(BaseTransform):
        return ref_point3

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '
        repr_str += f'down_ratio={self.down_ratio}) '
@@ -1791,7 +1825,7 @@ class RandomShiftScale(BaseTransform):
        aug_prob (float): The shifting and scaling probability.
    """

    def __init__(self, shift_scale: Tuple[float], aug_prob: float) -> None:
        self.shift_scale = shift_scale
        self.aug_prob = aug_prob
@@ -1830,7 +1864,484 @@ class RandomShiftScale(BaseTransform):
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(shift_scale={self.shift_scale}, '
        repr_str += f'aug_prob={self.aug_prob}) '
        return repr_str
@TRANSFORMS.register_module()
class Resize3D(Resize):
def _resize_3d(self, results):
"""Resize centers_2d and modify camera intrinisc with
``results['scale']``."""
if 'centers_2d' in results:
results['centers_2d'] *= results['scale_factor'][:2]
results['cam2img'][0] *= np.array(results['scale_factor'][0])
results['cam2img'][1] *= np.array(results['scale_factor'][1])
def transform(self, results: dict) -> dict:
"""Transform function to resize images, bounding boxes, semantic
segmentation map and keypoints.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
'gt_keypoints', 'scale', 'scale_factor', 'img_shape',
and 'keep_ratio' keys are updated in result dict.
"""
super(Resize3D, self).transform(results)
self._resize_3d(results)
return results
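Scaling the first two rows of `cam2img` by the resize factors keeps projections consistent with the resized image; a small numeric check (camera values illustrative):

```python
import numpy as np

cam2img = np.array([[1000., 0., 800.],
                    [0., 1000., 450.],
                    [0., 0., 1.]])
point_cam = np.array([2.0, 1.0, 10.0])  # a 3D point in camera frame
scale_w, scale_h = 0.5, 0.5             # resize factors (w, h)

uv = cam2img @ point_cam
uv = uv[:2] / uv[2]                     # pixel before resize

cam2img_resized = cam2img.copy()
cam2img_resized[0] *= scale_w           # as in _resize_3d above
cam2img_resized[1] *= scale_h
uv2 = cam2img_resized @ point_cam
uv2 = uv2[:2] / uv2[2]

print(uv, uv2)  # uv2 == uv * [scale_w, scale_h]
```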
@TRANSFORMS.register_module()
class RandomResize3D(RandomResize):
"""The difference between RandomResize3D and RandomResize:
1. Compared to RandomResize, this class would further
check if scale is already set in results.
2. During resizing, this class would modify the centers_2d
and cam2img with ``results['scale']``.
"""
def _resize_3d(self, results):
"""Resize centers_2d and modify camera intrinisc with
``results['scale']``."""
if 'centers_2d' in results:
results['centers_2d'] *= results['scale_factor'][:2]
results['cam2img'][0] *= np.array(results['scale_factor'][0])
results['cam2img'][1] *= np.array(results['scale_factor'][1])
def transform(self, results):
"""Transform function to resize images, bounding boxes, masks, semantic
segmentation map. Compared to RandomResize, this function would further
check if scale is already set in results.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \
'keep_ratio' keys are added into result dict.
"""
if 'scale' not in results:
results['scale'] = self._random_scale()
self.resize.scale = results['scale']
results = self.resize(results)
self._resize_3d(results)
return results
@TRANSFORMS.register_module()
class RandomCrop3D(RandomCrop):
"""3D version of RandomCrop. RamdomCrop3D supports the modifications of
camera intrinsic matrix and using predefined randomness variable to do the
augmentation.
The absolute ``crop_size`` is sampled based on ``crop_type`` and
``image_size``, then the cropped results are generated.
Required Keys:
- img
- gt_bboxes (np.float32) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (np.bool) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_masks (optional)
- gt_ignore_flags (optional)
- gt_seg_map (optional)
Added Keys:
- homography_matrix
Args:
crop_size (tuple): The relative ratio or absolute pixels of
height and width.
crop_type (str): One of "relative_range", "relative",
"absolute", "absolute_range". "relative" randomly crops
(h * crop_size[0], w * crop_size[1]) part from an input of size
(h, w). "relative_range" uniformly samples relative crop size from
range [crop_size[0], 1] and [crop_size[1], 1] for height and width
respectively. "absolute" crops from an input with absolute size
(crop_size[0], crop_size[1]). "absolute_range" uniformly samples
crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w
in range [crop_size[0], min(w, crop_size[1])].
Defaults to "absolute".
allow_negative_crop (bool): Whether to allow a crop that does
not contain any bbox area. Defaults to False.
recompute_bbox (bool): Whether to re-compute the boxes based
on cropped instance masks. Defaults to False.
bbox_clip_border (bool): Whether clip the objects outside
the border of the image. Defaults to True.
        rel_offset_h (tuple): The cropping interval of image height.
            Defaults to (0., 1.).
        rel_offset_w (tuple): The cropping interval of image width.
            Defaults to (0., 1.).
Note:
- If the image is smaller than the absolute crop size, return the
original image.
- The keys for bboxes, labels and masks must be aligned. That is,
``gt_bboxes`` corresponds to ``gt_labels`` and ``gt_masks``, and
``gt_bboxes_ignore`` corresponds to ``gt_labels_ignore`` and
``gt_masks_ignore``.
- If the crop does not contain any gt-bbox region and
``allow_negative_crop`` is set to False, skip this image.
"""
def __init__(self,
crop_size,
crop_type='absolute',
allow_negative_crop=False,
recompute_bbox=False,
bbox_clip_border=True,
rel_offset_h=(0., 1.),
rel_offset_w=(0., 1.)):
super().__init__(
crop_size=crop_size,
crop_type=crop_type,
allow_negative_crop=allow_negative_crop,
recompute_bbox=recompute_bbox,
bbox_clip_border=bbox_clip_border)
# rel_offset specifies the relative offset range of cropping origin
# [0., 1.] means starting from 0*margin to 1*margin + 1
self.rel_offset_h = rel_offset_h
self.rel_offset_w = rel_offset_w
def _crop_data(self, results, crop_size, allow_negative_crop):
"""Function to randomly crop images, bounding boxes, masks, semantic
segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
crop_size (tuple): Expected absolute size after cropping, (h, w).
            allow_negative_crop (bool): Whether to allow a crop that does not
                contain any bbox area. Defaults to False.
Returns:
dict: Randomly cropped results, 'img_shape' key in result dict is
updated according to crop size.
"""
assert crop_size[0] > 0 and crop_size[1] > 0
for key in results.get('img_fields', ['img']):
img = results[key]
if 'img_crop_offset' not in results:
margin_h = max(img.shape[0] - crop_size[0], 0)
margin_w = max(img.shape[1] - crop_size[1], 0)
# TOCHECK: a little different from LIGA implementation
offset_h = np.random.randint(
self.rel_offset_h[0] * margin_h,
self.rel_offset_h[1] * margin_h + 1)
offset_w = np.random.randint(
self.rel_offset_w[0] * margin_w,
self.rel_offset_w[1] * margin_w + 1)
else:
offset_w, offset_h = results['img_crop_offset']
crop_h = min(crop_size[0], img.shape[0])
crop_w = min(crop_size[1], img.shape[1])
crop_y1, crop_y2 = offset_h, offset_h + crop_h
crop_x1, crop_x2 = offset_w, offset_w + crop_w
# crop the image
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
img_shape = img.shape
results[key] = img
results['img_shape'] = img_shape
# crop bboxes accordingly and clip to the image boundary
for key in results.get('bbox_fields', []):
# e.g. gt_bboxes and gt_bboxes_ignore
bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],
dtype=np.float32)
bboxes = results[key] - bbox_offset
if self.bbox_clip_border:
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (
bboxes[:, 3] > bboxes[:, 1])
# If the crop does not contain any gt-bbox area and
# allow_negative_crop is False, skip this image.
if (key == 'gt_bboxes' and not valid_inds.any()
and not allow_negative_crop):
return None
results[key] = bboxes[valid_inds, :]
# label fields. e.g. gt_labels and gt_labels_ignore
label_key = self.bbox2label.get(key)
if label_key in results:
results[label_key] = results[label_key][valid_inds]
# mask fields, e.g. gt_masks and gt_masks_ignore
mask_key = self.bbox2mask.get(key)
if mask_key in results:
results[mask_key] = results[mask_key][
valid_inds.nonzero()[0]].crop(
np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
if self.recompute_bbox:
results[key] = results[mask_key].get_bboxes()
# crop semantic seg
for key in results.get('seg_fields', []):
results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]
# manipulate camera intrinsic matrix
# needs to apply offset to K instead of P2 (on KITTI)
if isinstance(results['cam2img'], list):
# TODO ignore this, but should handle it in the future
pass
else:
K = results['cam2img'][:3, :3].copy()
inv_K = np.linalg.inv(K)
T = np.matmul(inv_K, results['cam2img'][:3])
K[0, 2] -= crop_x1
K[1, 2] -= crop_y1
offset_cam2img = np.matmul(K, T)
results['cam2img'][:offset_cam2img.shape[0], :offset_cam2img.
shape[1]] = offset_cam2img
results['img_crop_offset'] = [offset_w, offset_h]
return results
def transform(self, results):
"""Transform function to randomly crop images, bounding boxes, masks,
semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Randomly cropped results, 'img_shape' key in result dict is
updated according to crop size.
"""
image_size = results['img'].shape[:2]
if 'crop_size' not in results:
crop_size = self._get_crop_size(image_size)
results['crop_size'] = crop_size
else:
crop_size = results['crop_size']
results = self._crop_data(results, crop_size, self.allow_negative_crop)
return results
    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(crop_size={self.crop_size}, '
        repr_str += f'crop_type={self.crop_type}, '
        repr_str += f'allow_negative_crop={self.allow_negative_crop}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border}, '
        repr_str += f'rel_offset_h={self.rel_offset_h}, '
        repr_str += f'rel_offset_w={self.rel_offset_w})'
        return repr_str
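The intrinsic manipulation in `_crop_data` shifts the principal point by the crop offset so that projected pixels stay aligned with the cropped image. A minimal numeric check of the `K`/`T` decomposition used above (camera values illustrative):

```python
import numpy as np

cam2img = np.array([[721.5, 0., 609.6, 44.9],
                    [0., 721.5, 172.9, 0.2],
                    [0., 0., 1., 0.]])
crop_x1, crop_y1 = 100, 50

K = cam2img[:3, :3].copy()
inv_K = np.linalg.inv(K)
T = np.matmul(inv_K, cam2img[:3])  # K @ T reproduces the original cam2img
K[0, 2] -= crop_x1                 # shift the principal point by the crop
K[1, 2] -= crop_y1
offset_cam2img = np.matmul(K, T)

point = np.array([1.0, 2.0, 15.0, 1.0])  # homogeneous 3D point, camera frame
uv = cam2img[:3] @ point
uv = uv[:2] / uv[2]
uv_cropped = offset_cam2img @ point
uv_cropped = uv_cropped[:2] / uv_cropped[2]
print(uv - uv_cropped)  # -> approximately [100. 50.]
```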
@TRANSFORMS.register_module()
class PhotoMetricDistortion3D(PhotoMetricDistortion):
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
second or second to last.
PhotoMetricDistortion3D further support using predefined randomness
variable to do the augmentation.
1. random brightness
2. random contrast (mode 0)
3. convert color from BGR to HSV
4. random saturation
5. random hue
6. convert color from HSV to BGR
7. random contrast (mode 1)
8. randomly swap channels
Required Keys:
- img (np.uint8)
Modified Keys:
- img (np.float32)
Args:
brightness_delta (int): delta of brightness.
contrast_range (sequence): range of contrast.
saturation_range (sequence): range of saturation.
hue_delta (int): delta of hue.
"""
def transform(self, results: dict) -> dict:
"""Transform function to perform photometric distortion on images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images distorted.
"""
assert 'img' in results, '`img` is not found in results'
img = results['img']
img = img.astype(np.float32)
if 'photometric_param' not in results:
photometric_param = self._random_flags()
results['photometric_param'] = photometric_param
else:
photometric_param = results['photometric_param']
(mode, brightness_flag, contrast_flag, saturation_flag, hue_flag,
swap_flag, delta_value, alpha_value, saturation_value, hue_value,
swap_value) = photometric_param
# random brightness
if brightness_flag:
img += delta_value
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
if mode == 1:
if contrast_flag:
img *= alpha_value
# convert color from BGR to HSV
img = mmcv.bgr2hsv(img)
# random saturation
if saturation_flag:
img[..., 1] *= saturation_value
# random hue
if hue_flag:
img[..., 0] += hue_value
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
# convert color from HSV to BGR
img = mmcv.hsv2bgr(img)
# random contrast
if mode == 0:
if contrast_flag:
img *= alpha_value
# randomly swap channels
if swap_flag:
img = img[..., swap_value]
results['img'] = img
return results
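Because the drawn parameters are cached under `results['photometric_param']` and reused when the key is already present, the same distortion can be replayed across views. A hedged usage sketch; the import path is assumed from this repo's layout:

```python
import numpy as np

from mmdet3d.datasets.transforms import PhotoMetricDistortion3D  # assumed

distort = PhotoMetricDistortion3D()
views = [dict(img=np.random.randint(0, 255, (4, 4, 3), dtype=np.uint8))
         for _ in range(2)]

shared_param = None
for view in views:
    if shared_param is not None:
        view['photometric_param'] = shared_param  # replay the first draw
    view = distort.transform(view)
    shared_param = view['photometric_param']
```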
@TRANSFORMS.register_module()
class MultiViewWrapper(BaseTransform):
"""Wrap transformation from single-view into multi-view.
The wrapper processes the images from multi-view one by one. For each
image, it constructs a pseudo dict according to the keys specified by the
'process_fields' parameter. After the transformation is finished, desired
information can be collected by specifying the keys in the 'collected_keys'
parameter. Multi-view images share the same transformation parameters
but do not share the same magnitude when a random transformation is
conducted.
    Args:
        transforms (list[dict]): A list of dicts specifying the
            transformations for the monocular situation.
        override_aug_config (bool): Whether to use the same augmentation
            config for every view. Defaults to True.
        process_fields (list): Desired keys that the transformations should
            be conducted on. Defaults to ['img', 'cam2img', 'lidar2cam'].
        collected_keys (list): Information collected from the transformations,
            such as rotation angles, crop ROI and flip state. Defaults to
            ['scale', 'scale_factor', 'crop', 'crop_offset', 'ori_shape',
            'pad_shape', 'img_shape', 'pad_fixed_size', 'pad_size_divisor',
            'flip', 'flip_direction', 'rotate'].
        randomness_keys (list): The keys related to the randomness in the
            transformations. Defaults to ['scale', 'scale_factor',
            'crop_size', 'flip', 'flip_direction', 'photometric_param'].
    """
def __init__(self,
transforms: dict,
override_aug_config: bool = True,
process_fields: list = ['img', 'cam2img', 'lidar2cam'],
collected_keys: list = [
'scale', 'scale_factor', 'crop', 'img_crop_offset',
'ori_shape', 'pad_shape', 'img_shape', 'pad_fixed_size',
'pad_size_divisor', 'flip', 'flip_direction', 'rotate'
],
randomness_keys: list = [
'scale', 'scale_factor', 'crop_size', 'img_crop_offset',
'flip', 'flip_direction', 'photometric_param'
]):
self.transforms = Compose(transforms)
self.override_aug_config = override_aug_config
self.collected_keys = collected_keys
self.process_fields = process_fields
self.randomness_keys = randomness_keys
def transform(self, input_dict):
"""Transform function to do the transform for multiview image.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: output dict after transformtaion
"""
# store the augmentation related keys for each image.
for key in self.collected_keys:
if key not in input_dict or \
not isinstance(input_dict[key], list):
input_dict[key] = []
prev_process_dict = {}
for img_id in range(len(input_dict['img'])):
process_dict = {}
# override the process dict (e.g. scale in random scale,
# crop_size in random crop, flip, flip_direction in
# random flip)
if img_id != 0 and self.override_aug_config:
for key in self.randomness_keys:
if key in prev_process_dict:
process_dict[key] = prev_process_dict[key]
for key in self.process_fields:
if key in input_dict:
process_dict[key] = input_dict[key][img_id]
process_dict = self.transforms(process_dict)
# store the randomness variable in transformation.
prev_process_dict = process_dict
# store the related results to results_dict
for key in self.process_fields:
if key in process_dict:
input_dict[key][img_id] = process_dict[key]
# update the keys
for key in self.collected_keys:
if key in process_dict:
if len(input_dict[key]) == img_id + 1:
input_dict[key][img_id] = process_dict[key]
else:
input_dict[key].append(process_dict[key])
for key in self.collected_keys:
if len(input_dict[key]) == 0:
input_dict.pop(key)
return input_dict
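A hedged config sketch of wrapping single-view transforms for multi-view input; the inner transforms and their parameters are illustrative:

```python
multiview_transform = dict(
    type='MultiViewWrapper',
    transforms=[
        dict(type='PhotoMetricDistortion3D'),
        dict(type='RandomResize3D', scale=(1600, 900),
             ratio_range=(0.9, 1.1)),
    ],
    # share one set of random parameters across all views
    override_aug_config=True,
    process_fields=['img', 'cam2img', 'lidar2cam'])
```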
...@@ -23,8 +23,8 @@ class WaymoDataset(KittiDataset): ...@@ -23,8 +23,8 @@ class WaymoDataset(KittiDataset):
Args: Args:
data_root (str): Path of dataset root. data_root (str): Path of dataset root.
ann_file (str): Path of annotation file. ann_file (str): Path of annotation file.
data_prefix (list[dict]): data prefix for point cloud and data_prefix (dict): data prefix for point cloud and
camera data dict, default to dict( camera data dict. Default to dict(
pts='velodyne', pts='velodyne',
CAM_FRONT='image_0', CAM_FRONT='image_0',
CAM_FRONT_RIGHT='image_1', CAM_FRONT_RIGHT='image_1',
...@@ -34,13 +34,14 @@ class WaymoDataset(KittiDataset): ...@@ -34,13 +34,14 @@ class WaymoDataset(KittiDataset):
pipeline (list[dict], optional): Pipeline used for data processing. pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None. Defaults to None.
modality (dict, optional): Modality to specify the sensor data used modality (dict, optional): Modality to specify the sensor data used
as input. Defaults to `dict(use_lidar=True)`. as input. Defaults to dict(use_lidar=True).
default_cam_key (str, optional): Default camera key for lidar2img default_cam_key (str, optional): Default camera key for lidar2img
association. association. Defaults to 'CAM_FRONT'.
box_type_3d (str, optional): Type of 3D box of this dataset. box_type_3d (str, optional): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the boxes Based on the `box_type_3d`, the dataset will encapsulate the boxes
in their original format and then convert them to `box_type_3d`. in their original format and then convert them to `box_type_3d`.
Defaults to 'LiDAR' in this dataset. Available options include Defaults to 'LiDAR' in this dataset. Available options include:
- 'LiDAR': Box in LiDAR coordinates. - 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset. - 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates. - 'Camera': Box in camera coordinates.
...@@ -48,16 +49,18 @@ class WaymoDataset(KittiDataset): ...@@ -48,16 +49,18 @@ class WaymoDataset(KittiDataset):
Defaults to True. Defaults to True.
test_mode (bool, optional): Whether the dataset is in test mode. test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False. Defaults to False.
pcd_limit_range (list, optional): The range of point cloud used to pcd_limit_range (list[float], optional): The range of point cloud
filter invalid predicted boxes. used to filter invalid predicted boxes.
Default: [-85, -85, -5, 85, 85, 5]. Defaults to [-85, -85, -5, 85, 85, 5].
cam_sync_instances (bool, optional): Whether to use the camera sync cam_sync_instances (bool, optional): Whether to use the camera sync
labels supported from Waymo version 1.3.1. labels supported from Waymo version 1.3.1. Defaults to False.
load_interval (int, optional): Interval for frame loading. load_interval (int, optional): Interval for frame loading.
Defaults to 1.
task (str, optional): Task for 3D detection (lidar, mono3d). task (str, optional): Task for 3D detection (lidar_det, mono_det).
lidar: take all the ground truth in the frame. lidar_det: take all the ground truth in the frame.
mono3d: take the ground truth that can be seen in the cam. mono_det: take the ground truth that can be seen in the camera.
max_sweeps (int, optional): Max number of sweeps for each frame. Defaults to 'lidar_det'.
max_sweeps (int, optional): Max number of sweeps for each frame. Defaults to 0.
""" """
METAINFO = {'CLASSES': ('Car', 'Pedestrian', 'Cyclist')} METAINFO = {'CLASSES': ('Car', 'Pedestrian', 'Cyclist')}
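
As a quick orientation for the renamed task value, a hedged configuration sketch of instantiating this dataset follows; the paths, the pipeline variable, and load_interval are placeholders rather than values from this diff:

```python
# A minimal sketch; paths and the pipeline variable are placeholders.
train_dataset = dict(
    type='WaymoDataset',
    data_root='data/waymo/kitti_format/',
    ann_file='waymo_infos_train.pkl',
    data_prefix=dict(pts='training/velodyne', CAM_FRONT='training/image_0'),
    pipeline=train_pipeline,  # assumed to be defined elsewhere in the config
    modality=dict(use_lidar=True),
    box_type_3d='LiDAR',
    task='lidar_det',  # renamed from 'lidar' in this change
    load_interval=5)
```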
...@@ -80,7 +83,7 @@ class WaymoDataset(KittiDataset): ...@@ -80,7 +83,7 @@ class WaymoDataset(KittiDataset):
pcd_limit_range: List[float] = [0, -40, -3, 70.4, 40, 0.0], pcd_limit_range: List[float] = [0, -40, -3, 70.4, 40, 0.0],
cam_sync_instances=False, cam_sync_instances=False,
load_interval=1, load_interval=1,
task='lidar', task='lidar_det',
max_sweeps=0, max_sweeps=0,
**kwargs): **kwargs):
self.load_interval = load_interval self.load_interval = load_interval
...@@ -127,20 +130,19 @@ class WaymoDataset(KittiDataset): ...@@ -127,20 +130,19 @@ class WaymoDataset(KittiDataset):
ann_info = Det3DDataset.parse_ann_info(self, info) ann_info = Det3DDataset.parse_ann_info(self, info)
if ann_info is None: if ann_info is None:
# empty instance # empty instance
anns_results = {} ann_info = {}
anns_results['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32) ann_info['gt_bboxes_3d'] = np.zeros((0, 7), dtype=np.float32)
anns_results['gt_labels_3d'] = np.zeros(0, dtype=np.int64) ann_info['gt_labels_3d'] = np.zeros(0, dtype=np.int64)
return anns_results
ann_info = self._remove_dontcare(ann_info) ann_info = self._remove_dontcare(ann_info)
# in kitti, lidar2cam = R0_rect @ Tr_velo_to_cam # in kitti, lidar2cam = R0_rect @ Tr_velo_to_cam
# convert gt_bboxes_3d to velodyne coordinates with `lidar2cam` # convert gt_bboxes_3d to velodyne coordinates with `lidar2cam`
if 'gt_bboxes' in ann_info: if 'gt_bboxes' in ann_info:
gt_bboxes = ann_info['gt_bboxes'] gt_bboxes = ann_info['gt_bboxes']
gt_labels = ann_info['gt_labels'] gt_bboxes_labels = ann_info['gt_bboxes_labels']
else: else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32) gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64) gt_bboxes_labels = np.zeros(0, dtype=np.int64)
if 'centers_2d' in ann_info: if 'centers_2d' in ann_info:
centers_2d = ann_info['centers_2d'] centers_2d = ann_info['centers_2d']
depths = ann_info['depths'] depths = ann_info['depths']
...@@ -148,25 +150,27 @@ class WaymoDataset(KittiDataset): ...@@ -148,25 +150,27 @@ class WaymoDataset(KittiDataset):
centers_2d = np.zeros((0, 2), dtype=np.float32) centers_2d = np.zeros((0, 2), dtype=np.float32)
depths = np.zeros((0), dtype=np.float32) depths = np.zeros((0), dtype=np.float32)
if self.task == 'mono3d': if self.task == 'mono_det':
gt_bboxes_3d = CameraInstance3DBoxes( gt_bboxes_3d = CameraInstance3DBoxes(
ann_info['gt_bboxes_3d'], ann_info['gt_bboxes_3d'],
box_dim=ann_info['gt_bboxes_3d'].shape[-1], box_dim=ann_info['gt_bboxes_3d'].shape[-1],
origin=(0.5, 0.5, 0.5)) origin=(0.5, 0.5, 0.5))
else: else:
# in waymo, lidar2cam = R0_rect @ Tr_velo_to_cam
# convert gt_bboxes_3d to velodyne coordinates with `lidar2cam`
lidar2cam = np.array( lidar2cam = np.array(
info['images'][self.default_cam_key]['lidar2cam']) info['images'][self.default_cam_key]['lidar2cam'])
gt_bboxes_3d = CameraInstance3DBoxes( gt_bboxes_3d = CameraInstance3DBoxes(
ann_info['gt_bboxes_3d']).convert_to(self.box_mode_3d, ann_info['gt_bboxes_3d']).convert_to(self.box_mode_3d,
np.linalg.inv(lidar2cam)) np.linalg.inv(lidar2cam))
ann_info['gt_bboxes_3d'] = gt_bboxes_3d
anns_results = dict( anns_results = dict(
gt_bboxes_3d=gt_bboxes_3d, gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=ann_info['gt_labels_3d'], gt_labels_3d=ann_info['gt_labels_3d'],
gt_bboxes=gt_bboxes, gt_bboxes=gt_bboxes,
gt_labels=gt_labels, gt_bboxes_labels=gt_bboxes_labels,
centers_2d=centers_2d, centers_2d=centers_2d,
depths=depths) depths=depths)
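
The branch above brings camera-frame ground truth back into LiDAR coordinates by inverting `lidar2cam`. Below is a standalone sketch of the same call, with a fabricated box and an identity extrinsic standing in for `info['images'][cam]['lidar2cam']`:

```python
import numpy as np

from mmdet3d.structures import Box3DMode, CameraInstance3DBoxes

# One fabricated 7-DoF box in camera coordinates: (x, y, z, l, h, w, yaw).
cam_boxes = CameraInstance3DBoxes(
    np.array([[1.0, 1.5, 10.0, 3.9, 1.6, 1.7, 0.0]], dtype=np.float32))
# In the dataset this matrix comes from the info file; identity here.
lidar2cam = np.eye(4, dtype=np.float32)
# Inverting lidar2cam maps the boxes from camera to LiDAR coordinates,
# mirroring the `convert_to` call in `parse_ann_info`.
lidar_boxes = cam_boxes.convert_to(Box3DMode.LIDAR, np.linalg.inv(lidar2cam))
```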
...@@ -181,7 +185,7 @@ class WaymoDataset(KittiDataset): ...@@ -181,7 +185,7 @@ class WaymoDataset(KittiDataset):
def parse_data_info(self, info: dict) -> dict: def parse_data_info(self, info: dict) -> dict:
"""if task is lidar or multiview det, use super() method elif task is """if task is lidar or multiview det, use super() method elif task is
mono3d, split the info from frame-wise to img-wise.""" mono3d, split the info from frame-wise to img-wise."""
if self.task != 'mono3d': if self.task != 'mono_det':
if self.cam_sync_instances: if self.cam_sync_instances:
# use the cam sync labels # use the cam sync labels
info['instances'] = info['cam_sync_instances'] info['instances'] = info['cam_sync_instances']
...@@ -217,7 +221,7 @@ class WaymoDataset(KittiDataset): ...@@ -217,7 +221,7 @@ class WaymoDataset(KittiDataset):
# TODO check if need to modify the sample id # TODO check if need to modify the sample id
# TODO check when will use it except for evaluation. # TODO check when will use it except for evaluation.
camera_info['sample_id'] = info['sample_id'] camera_info['sample_idx'] = info['sample_idx']
if not self.test_mode: if not self.test_mode:
# used in training # used in training
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .hooks import Det3DVisualizationHook from .hooks import BenchmarkHook, Det3DVisualizationHook
__all__ = ['Det3DVisualizationHook'] __all__ = ['Det3DVisualizationHook', 'BenchmarkHook']
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .benchmark_hook import BenchmarkHook
from .visualization_hook import Det3DVisualizationHook from .visualization_hook import Det3DVisualizationHook
__all__ = ['Det3DVisualizationHook'] __all__ = ['Det3DVisualizationHook', 'BenchmarkHook']
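
With the new export in place, the hook can be enabled from a config's `custom_hooks` list. This is a hedged sketch: the hook's constructor arguments are not shown in this diff, so none are passed.

```python
# A minimal sketch: constructor arguments, if any, are omitted because
# they do not appear in this diff.
custom_hooks = [dict(type='BenchmarkHook')]
```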