Commit 7d227c79 authored by wangtai's avatar wangtai Committed by zhangwenwei
Browse files

Update mmdet3d/datasets/scannet_dataset.py, mmdet3d/datasets/kitti_dataset.py,...

Update mmdet3d/datasets/scannet_dataset.py, mmdet3d/datasets/kitti_dataset.py, mmdet3d/datasets/sunrgbd_dataset.py, mmdet3d/datasets/nuscenes_dataset.py files
parent addc86ad
......@@ -67,9 +67,31 @@ class Custom3DDataset(Dataset):
self._set_group_flag()
def load_annotations(self, ann_file):
"""Load annotations from ann_file.
Args:
ann_file (str): Path of the annotation file.
Returns:
list[dict]: List of annotations.
"""
return mmcv.load(ann_file)
def get_data_info(self, index):
"""Get data info according to the given index.
Args:
index (int): Index of the sample data to get.
Returns:
dict: Standard input_dict consists of the
data information.
- sample_idx (str): sample index
- pts_filename (str): filename of point clouds
- file_name (str): filename of point clouds
- ann_info (dict): annotation info
"""
info = self.data_infos[index]
sample_idx = info['point_cloud']['lidar_idx']
pts_filename = osp.join(self.data_root, info['pts_path'])
......@@ -87,6 +109,21 @@ class Custom3DDataset(Dataset):
return input_dict
def pre_pipeline(self, results):
"""Initialization before data preparation.
Args:
results (dict): Dict before data preprocessing.
- img_fields (list): image fields
- bbox3d_fields (list): 3D bounding boxes fields
- pts_mask_fields (list): mask fields of points
- pts_seg_fields (list): mask fields of point segments
- bbox_fields (list): fields of bounding boxes
- mask_fields (list): fields of masks
- seg_fields (list): segment fields
- box_type_3d (str): 3D box type
- box_mode_3d (str): 3D box mode
"""
results['img_fields'] = []
results['bbox3d_fields'] = []
results['pts_mask_fields'] = []
......@@ -98,6 +135,14 @@ class Custom3DDataset(Dataset):
results['box_mode_3d'] = self.box_mode_3d
def prepare_train_data(self, index):
"""Training data preparation.
Args:
index (int): Index for accessing the target data.
Returns:
dict: Training data dict corresponding to the index.
"""
input_dict = self.get_data_info(index)
if input_dict is None:
return None
......@@ -109,6 +154,14 @@ class Custom3DDataset(Dataset):
return example
def prepare_test_data(self, index):
"""Prepare data for testing.
Args:
index (int): Index for accessing the target data.
Returns:
dict: Testing data dict corresponding to the index.
"""
input_dict = self.get_data_info(index)
self.pre_pipeline(input_dict)
example = self.pipeline(input_dict)
......@@ -145,6 +198,19 @@ class Custom3DDataset(Dataset):
outputs,
pklfile_prefix=None,
submission_prefix=None):
"""Format the results to pkl file.
Args:
outputs (list[dict]): Testing results of the dataset.
pklfile_prefix (str | None): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
Returns:
tuple: (outputs, tmp_dir), outputs is the detection results,
tmp_dir is the temporary directory created for saving pkl
files when pklfile_prefix is not specified.
"""
if pklfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
pklfile_prefix = osp.join(tmp_dir.name, 'results')
......
......@@ -50,6 +50,14 @@ class Kitti2DDataset(CustomDataset):
"""
def load_annotations(self, ann_file):
"""Load annotations from ann_file.
Args:
ann_file (str): Path of the annotation file.
Returns:
list[dict]: List of annotations.
"""
self.data_infos = mmcv.load(ann_file)
self.cat2label = {
cat_name: i
......@@ -66,6 +74,18 @@ class Kitti2DDataset(CustomDataset):
return valid_inds
def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: Standard annotation dictionary
consists of the data information.
- bboxes (np.ndarray): ground truth bboxes
- labels (np.ndarray): labels of ground truths
"""
# Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index]
annos = info['annos']
......@@ -87,6 +107,15 @@ class Kitti2DDataset(CustomDataset):
return anns_results
def prepare_train_img(self, idx):
"""Training image preparation.
Args:
idx (int): Index for accessing the target image data.
Returns:
dict: Training image data dict after preprocessing
corresponding to the index.
"""
img_raw_info = self.data_infos[idx]['image']
img_info = dict(filename=img_raw_info['image_path'])
ann_info = self.get_ann_info(idx)
......@@ -99,6 +128,15 @@ class Kitti2DDataset(CustomDataset):
return self.pipeline(results)
def prepare_test_img(self, idx):
"""Prepare data for testing.
Args:
idx (int): Index for accessing the target image data.
Returns:
dict: Testing image data dict after preprocessing
corresponding to the index.
"""
img_raw_info = self.data_infos[idx]['image']
img_info = dict(filename=img_raw_info['image_path'])
results = dict(img_info=img_info)
......@@ -108,11 +146,29 @@ class Kitti2DDataset(CustomDataset):
return self.pipeline(results)
def drop_arrays_by_name(self, gt_names, used_classes):
"""Drop irrelevant ground truths by name.
Args:
gt_names (list[str]): Names of ground truths.
used_classes (list[str]): Classes of interest.
Returns:
np.ndarray: Indices of ground truths that will be dropped.
"""
inds = [i for i, x in enumerate(gt_names) if x not in used_classes]
inds = np.array(inds, dtype=np.int64)
return inds
def keep_arrays_by_name(self, gt_names, used_classes):
"""Keep useful ground truths by name.
Args:
gt_names (list[str]): Names of ground truths.
used_classes (list[str]): Classes of interest.
Returns:
np.ndarray: Indices of ground truths that will be kept.
"""
inds = [i for i, x in enumerate(gt_names) if x in used_classes]
inds = np.array(inds, dtype=np.int64)
return inds
......@@ -125,6 +181,17 @@ class Kitti2DDataset(CustomDataset):
return result_files
def evaluate(self, result_files, eval_types=None):
"""Evaluation in KITTI protocol.
Args:
result_files (str): Path of result files.
eval_types (str): Types of evaluation. Default: None.
The KITTI dataset only supports the 'bbox' evaluation type.
Returns:
tuple (str, dict): Average precision results in str format
and average precision results in dict format.
"""
from mmdet3d.core.evaluation import kitti_eval
eval_types = ['bbox'] if not eval_types else eval_types
assert eval_types in ('bbox', ['bbox'
......
......@@ -83,6 +83,23 @@ class KittiDataset(Custom3DDataset):
return pts_filename
def get_data_info(self, index):
"""Get data info according to the given index.
Args:
index (int): Index of the sample data to get.
Returns:
dict: Standard input_dict consists of the
data information.
- sample_idx (str): sample index
- pts_filename (str): filename of point clouds
- img_prefix (str | None): prefix of image files
- img_info (dict): image info
- lidar2img (list[np.ndarray], optional): transformations from
lidar to different cameras
- ann_info (dict): annotation info
"""
info = self.data_infos[index]
sample_idx = info['image']['image_idx']
img_filename = os.path.join(self.data_root,
......@@ -109,6 +126,22 @@ class KittiDataset(Custom3DDataset):
return input_dict
def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: Standard annotation dictionary
consists of the data information.
- gt_bboxes_3d (:obj:``LiDARInstance3DBoxes``):
3D ground truth bboxes
- gt_labels_3d (np.ndarray): labels of ground truths
- gt_bboxes (np.ndarray): 2D ground truth bboxes
- gt_labels (np.ndarray): labels of ground truths
- gt_names (list[str]): class names of ground truths
"""
# Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index]
rect = info['calib']['R0_rect'].astype(np.float32)
......@@ -153,16 +186,43 @@ class KittiDataset(Custom3DDataset):
return anns_results
def drop_arrays_by_name(self, gt_names, used_classes):
"""Drop irrelevant ground truths by name.
Args:
gt_names (list[str]): Names of ground truths.
used_classes (list[str]): Classes of interest.
Returns:
np.ndarray: Indices of ground truths that will be dropped.
"""
inds = [i for i, x in enumerate(gt_names) if x not in used_classes]
inds = np.array(inds, dtype=np.int64)
return inds
def keep_arrays_by_name(self, gt_names, used_classes):
"""Keep useful ground truths by name.
Args:
gt_names (list[str]): Names of ground truths.
used_classes (list[str]): Classes of interest.
Returns:
np.ndarray: Indices of ground truths that will be kept.
"""
inds = [i for i, x in enumerate(gt_names) if x in used_classes]
inds = np.array(inds, dtype=np.int64)
return inds
def remove_dontcare(self, ann_info):
"""Remove annotations that do not need to be cared.
Args:
ann_info (dict): Dict of annotation infos. The ``'DontCare'``
annotations will be removed according to ann_file['name'].
Returns:
dict: Annotations after filtering.
"""
img_filtered_annotations = {}
relevant_annotation_indices = [
i for i, x in enumerate(ann_info['name']) if x != 'DontCare'
......@@ -176,6 +236,23 @@ class KittiDataset(Custom3DDataset):
outputs,
pklfile_prefix=None,
submission_prefix=None):
"""Format the results to pkl file.
Args:
outputs (list[dict]): Testing results of the dataset.
pklfile_prefix (str | None): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
submission_prefix (str | None): The prefix of submitted files. It
includes the file path and the prefix of filename, e.g.,
"a/b/prefix". If not specified, a temp file will be created.
Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a dict containing
the file paths, tmp_dir is the temporary directory created
for saving files when pklfile_prefix is not specified.
"""
if pklfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
pklfile_prefix = osp.join(tmp_dir.name, 'results')
......@@ -390,7 +467,7 @@ class KittiDataset(Custom3DDataset):
pklfile_prefix (str | None): The prefix of pkl file.
submission_prefix (str | None): The prefix of submission file.
Return:
Returns:
list[dict]: A list of dicts in the KITTI format.
"""
assert len(net_outputs) == len(self.data_infos)
......@@ -553,6 +630,12 @@ class KittiDataset(Custom3DDataset):
)
def show(self, results, out_dir):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
for i, result in enumerate(results):
example = self.prepare_test_data(i)
......
......@@ -139,6 +139,14 @@ class NuScenesDataset(Custom3DDataset):
)
def load_annotations(self, ann_file):
"""Load annotations from ann_file.
Args:
ann_file (str): Path of the annotation file.
Returns:
list[dict]: List of annotations sorted by timestamps.
"""
data = mmcv.load(ann_file)
data_infos = list(sorted(data['infos'], key=lambda e: e['timestamp']))
data_infos = data_infos[::self.load_interval]
......@@ -147,6 +155,24 @@ class NuScenesDataset(Custom3DDataset):
return data_infos
def get_data_info(self, index):
"""Get data info according to the given index.
Args:
index (int): Index of the sample data to get.
Returns:
dict: Standard input_dict consists of the
data information.
- sample_idx (str): sample index
- pts_filename (str): filename of point clouds
- sweeps (list[dict]): infos of sweeps
- timestamp (float): sample timestamp
- img_filename (str, optional): image filename
- lidar2img (list[np.ndarray], optional): transformations from
lidar to different cameras
- ann_info (dict): annotation info
"""
info = self.data_infos[index]
# standard protocol modified from SECOND.Pytorch
......@@ -188,6 +214,20 @@ class NuScenesDataset(Custom3DDataset):
return input_dict
def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: Standard annotation dictionary
consists of the data information.
- gt_bboxes_3d (:obj:``LiDARInstance3DBoxes``):
3D ground truth bboxes
- gt_labels_3d (np.ndarray): labels of ground truths
- gt_names (list[str]): class names of ground truths
"""
info = self.data_infos[index]
# filter out bbox containing no points
mask = info['num_lidar_pts'] > 0
......@@ -221,6 +261,17 @@ class NuScenesDataset(Custom3DDataset):
return anns_results
def _format_bbox(self, results, jsonfile_prefix=None):
"""Convert the results to the standard format.
Args:
results (list[dict]): Testing results of the dataset.
jsonfile_prefix (str): The prefix of the output jsonfile.
You can specify the output directory/filename by
modifying the jsonfile_prefix. Default: None.
Returns:
str: Path of the output json file.
"""
nusc_annos = {}
mapped_class_names = self.CLASSES
......@@ -283,6 +334,19 @@ class NuScenesDataset(Custom3DDataset):
logger=None,
metric='bbox',
result_name='pts_bbox'):
"""Evaluation for a single model in nuScenes protocol.
Args:
result_path (str): Path of the result file.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
metric (str): Metric name used for evaluation. Default: 'bbox'.
result_name (str): Result name in the metric prefix.
Default: 'pts_bbox'.
Returns:
dict: Dictionary of evaluation details.
"""
from nuscenes import NuScenes
from nuscenes.eval.detection.evaluate import NuScenesEval
......@@ -400,6 +464,12 @@ class NuScenesDataset(Custom3DDataset):
return results_dict
def show(self, results, out_dir):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
"""
for i, result in enumerate(results):
example = self.prepare_test_data(i)
points = example['points'][0]._data.numpy()
......@@ -421,6 +491,18 @@ class NuScenesDataset(Custom3DDataset):
def output_to_nusc_box(detection):
"""Convert the output to the box class in the nuScenes.
Args:
detection (dict): Detection results.
- boxes_3d (:obj:``BaseInstance3DBoxes``): detection bbox
- scores_3d (torch.Tensor): detection scores
- labels_3d (torch.Tensor): predicted box labels
Returns:
list[:obj:``NuScenesBox``]: List of standard NuScenesBoxes.
"""
box3d = detection['boxes_3d']
scores = detection['scores_3d'].numpy()
labels = detection['labels_3d'].numpy()
......@@ -456,6 +538,21 @@ def lidar_nusc_box_to_global(info,
classes,
eval_configs,
eval_version='detection_cvpr_2019'):
"""Convert the box from ego to global coordinate.
Args:
info (dict): Info for a specific sample data, including the
calibration information.
boxes (list[:obj:``NuScenesBox``]): List of predicted NuScenesBoxes.
classes (list[str]): Mapped classes in the evaluation.
eval_configs (object): Evaluation configuration object.
eval_version (str): Evaluation version.
Default: 'detection_cvpr_2019'
Returns:
list: List of standard NuScenesBoxes in the global
coordinate.
"""
box_list = []
for box in boxes:
# Move box to ego vehicle coord system
......
......@@ -51,6 +51,16 @@ class BatchSampler:
@OBJECTSAMPLERS.register_module()
class DataBaseSampler(object):
"""Class for sampling data from the ground truth database.
Args:
info_path (str): Path of groundtruth database info.
data_root (str): Path of groundtruth database.
rate (float): Rate of actual sampled over maximum sampled number.
prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers.
classes (list[str]): List of classes. Default: None.
"""
def __init__(self,
info_path,
......@@ -104,6 +114,15 @@ class DataBaseSampler(object):
@staticmethod
def filter_by_difficulty(db_infos, removed_difficulty):
"""Filter ground truths by difficulties.
Args:
db_infos (dict): Info of groundtruth database.
removed_difficulty (list): Difficulties that are not qualified.
Returns:
dict: Info of database after filtering.
"""
new_db_infos = {}
for key, dinfos in db_infos.items():
new_db_infos[key] = [
......@@ -114,6 +133,16 @@ class DataBaseSampler(object):
@staticmethod
def filter_by_min_points(db_infos, min_gt_points_dict):
"""Filter ground truths by number of points in the bbox.
Args:
db_infos (dict): Info of groundtruth database.
min_gt_points_dict (dict): Different number of minimum points
needed for different categories of ground truths.
Returns:
dict: Info of database after filtering.
"""
for name, min_num in min_gt_points_dict.items():
min_num = int(min_num)
if min_num > 0:
......@@ -125,6 +154,22 @@ class DataBaseSampler(object):
return db_infos
def sample_all(self, gt_bboxes, gt_labels, img=None):
"""Sampling all categories of bboxes.
Args:
gt_bboxes (np.ndarray): Ground truth bounding boxes.
gt_labels (np.ndarray): Labels of boxes.
Returns:
dict: Dict of sampled 'pseudo ground truths'.
- gt_labels_3d (np.ndarray): labels of sampled ground truths
- gt_bboxes_3d (:obj:``BaseInstance3DBoxes``):
sampled 3D bounding boxes
- points (np.ndarray): sampled points
- group_ids (np.ndarray): ids of sampled ground truths
"""
sampled_num_dict = {}
sample_num_per_class = []
for class_name, max_sample_num in zip(self.sample_classes,
......@@ -198,6 +243,16 @@ class DataBaseSampler(object):
return ret
def sample_class_v2(self, name, num, gt_bboxes):
"""Sampling specific categories of bounding boxes.
Args:
name (str): Class of objects to be sampled.
num (int): Number of sampled bboxes.
gt_bboxes (np.ndarray): Ground truth boxes.
Returns:
list[dict]: Valid samples after collision test.
"""
sampled = self.sampler_dict[name].sample(num)
sampled = copy.deepcopy(sampled)
num_gt = gt_bboxes.shape[0]
......
......@@ -114,7 +114,7 @@ class PointSegClassMapping(object):
others as len(valid_cat_ids).
Args:
valid_cat_ids (tuple[int): A tuple of valid category.
valid_cat_ids (tuple[int]): A tuple of valid category.
"""
def __init__(self, valid_cat_ids):
......
......@@ -171,12 +171,12 @@ class ObjectNoise(object):
"""Apply noise to each GT objects in the scene.
Args:
translation_std (list, optional): Standard deviation of the
translation_std (list[float], optional): Standard deviation of the
distribution where translation noise are sampled from.
Defaults to [0.25, 0.25, 0.25].
global_rot_range (list, optional): Global rotation to the scene.
global_rot_range (list[float], optional): Global rotation to the scene.
Defaults to [0.0, 0.0].
rot_range (list, optional): Object rotation range.
rot_range (list[float], optional): Object rotation range.
Defaults to [-0.15707963267, 0.15707963267].
num_try (int, optional): Number of times to try if the noise applied is
invalid. Defaults to 100.
......@@ -429,8 +429,9 @@ class IndoorPointSample(object):
return_choices (bool): Whether return choice.
Returns:
points (ndarray): 3D Points.
choices (ndarray): The generated random samples.
tuple (np.ndarray, np.ndarray) | np.ndarray:
points (np.ndarray): 3D Points.
choices (np.ndarray, optional): The generated random samples.
"""
if replace is None:
replace = (points.shape[0] < num_samples)
......
......@@ -64,6 +64,21 @@ class ScanNetDataset(Custom3DDataset):
test_mode=test_mode)
def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: Standard annotation dictionary
consists of the data information.
- gt_bboxes_3d (:obj:``DepthInstance3DBoxes``):
3D ground truth bboxes
- gt_labels_3d (np.ndarray): labels of ground truths
- pts_instance_mask_path (str): path of instance masks
- pts_semantic_mask_path (str): path of semantic masks
"""
# Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index]
if info['annos']['gt_num'] != 0:
......@@ -94,6 +109,12 @@ class ScanNetDataset(Custom3DDataset):
return anns_results
def show(self, results, out_dir):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
for i, result in enumerate(results):
data_info = self.data_infos[i]
......
......@@ -62,6 +62,21 @@ class SUNRGBDDataset(Custom3DDataset):
test_mode=test_mode)
def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: Standard annotation dictionary
consists of the data information.
- gt_bboxes_3d (:obj:``DepthInstance3DBoxes``):
3D ground truth bboxes
- gt_labels_3d (np.ndarray): labels of ground truths
- pts_instance_mask_path (str): path of instance masks
- pts_semantic_mask_path (str): path of semantic masks
"""
# Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index]
if info['annos']['gt_num'] != 0:
......@@ -81,6 +96,12 @@ class SUNRGBDDataset(Custom3DDataset):
return anns_results
def show(self, results, out_dir):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
for i, result in enumerate(results):
data_info = self.data_infos[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment