Commit 7d227c79 authored by wangtai's avatar wangtai Committed by zhangwenwei
Browse files

Update mmdet3d/datasets/scannet_dataset.py, mmdet3d/datasets/kitti_dataset.py,...

Update mmdet3d/datasets/scannet_dataset.py, mmdet3d/datasets/kitti_dataset.py, mmdet3d/datasets/sunrgbd_dataset.py, mmdet3d/datasets/nuscenes_dataset.py files
parent addc86ad
...@@ -67,9 +67,31 @@ class Custom3DDataset(Dataset): ...@@ -67,9 +67,31 @@ class Custom3DDataset(Dataset):
self._set_group_flag() self._set_group_flag()
def load_annotations(self, ann_file): def load_annotations(self, ann_file):
"""Load annotations from ann_file.
Args:
ann_file (str): Path of the annotation file.
Returns:
list[dict]: List of annotations.
"""
return mmcv.load(ann_file) return mmcv.load(ann_file)
def get_data_info(self, index): def get_data_info(self, index):
"""Get data info according to the given index.
Args:
index (int): Index of the sample data to get.
Returns:
dict: Standard input_dict consisting of the
data information.
- sample_idx (str): sample index
- pts_filename (str): filename of point clouds
- file_name (str): filename of point clouds
- ann_info (dict): annotation info
"""
info = self.data_infos[index] info = self.data_infos[index]
sample_idx = info['point_cloud']['lidar_idx'] sample_idx = info['point_cloud']['lidar_idx']
pts_filename = osp.join(self.data_root, info['pts_path']) pts_filename = osp.join(self.data_root, info['pts_path'])
...@@ -87,6 +109,21 @@ class Custom3DDataset(Dataset): ...@@ -87,6 +109,21 @@ class Custom3DDataset(Dataset):
return input_dict return input_dict
def pre_pipeline(self, results): def pre_pipeline(self, results):
"""Initialization before data preparation.
Args:
results (dict): Dict before data preprocessing.
- img_fields (list): image fields
- bbox3d_fields (list): 3D bounding boxes fields
- pts_mask_fields (list): mask fields of points
- pts_seg_fields (list): mask fields of point segments
- bbox_fields (list): fields of bounding boxes
- mask_fields (list): fields of masks
- seg_fields (list): segment fields
- box_type_3d (str): 3D box type
- box_mode_3d (str): 3D box mode
"""
results['img_fields'] = [] results['img_fields'] = []
results['bbox3d_fields'] = [] results['bbox3d_fields'] = []
results['pts_mask_fields'] = [] results['pts_mask_fields'] = []
...@@ -98,6 +135,14 @@ class Custom3DDataset(Dataset): ...@@ -98,6 +135,14 @@ class Custom3DDataset(Dataset):
results['box_mode_3d'] = self.box_mode_3d results['box_mode_3d'] = self.box_mode_3d
def prepare_train_data(self, index): def prepare_train_data(self, index):
"""Training data preparation.
Args:
index (int): Index for accessing the target data.
Returns:
dict: Training data dict corresponding to the index.
"""
input_dict = self.get_data_info(index) input_dict = self.get_data_info(index)
if input_dict is None: if input_dict is None:
return None return None
...@@ -109,6 +154,14 @@ class Custom3DDataset(Dataset): ...@@ -109,6 +154,14 @@ class Custom3DDataset(Dataset):
return example return example
def prepare_test_data(self, index): def prepare_test_data(self, index):
"""Prepare data for testing.
Args:
index (int): Index for accessing the target data.
Returns:
dict: Testing data dict corresponding to the index.
"""
input_dict = self.get_data_info(index) input_dict = self.get_data_info(index)
self.pre_pipeline(input_dict) self.pre_pipeline(input_dict)
example = self.pipeline(input_dict) example = self.pipeline(input_dict)
...@@ -145,6 +198,19 @@ class Custom3DDataset(Dataset): ...@@ -145,6 +198,19 @@ class Custom3DDataset(Dataset):
outputs, outputs,
pklfile_prefix=None, pklfile_prefix=None,
submission_prefix=None): submission_prefix=None):
"""Format the results to pkl file.
Args:
outputs (list[dict]): Testing results of the dataset.
pklfile_prefix (str | None): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
Returns:
tuple: (outputs, tmp_dir), outputs is the detection results,
tmp_dir is the temporary directory created for saving pkl
files when pklfile_prefix is not specified.
"""
if pklfile_prefix is None: if pklfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory() tmp_dir = tempfile.TemporaryDirectory()
pklfile_prefix = osp.join(tmp_dir.name, 'results') pklfile_prefix = osp.join(tmp_dir.name, 'results')
......
...@@ -50,6 +50,14 @@ class Kitti2DDataset(CustomDataset): ...@@ -50,6 +50,14 @@ class Kitti2DDataset(CustomDataset):
""" """
def load_annotations(self, ann_file): def load_annotations(self, ann_file):
"""Load annotations from ann_file.
Args:
ann_file (str): Path of the annotation file.
Returns:
list[dict]: List of annotations.
"""
self.data_infos = mmcv.load(ann_file) self.data_infos = mmcv.load(ann_file)
self.cat2label = { self.cat2label = {
cat_name: i cat_name: i
...@@ -66,6 +74,18 @@ class Kitti2DDataset(CustomDataset): ...@@ -66,6 +74,18 @@ class Kitti2DDataset(CustomDataset):
return valid_inds return valid_inds
def get_ann_info(self, index): def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: Standard annotation dictionary
consisting of the data information.
- bboxes (np.ndarray): ground truth bboxes
- labels (np.ndarray): labels of ground truths
"""
# Use index to get the annos, thus the evalhook could also use this api # Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index] info = self.data_infos[index]
annos = info['annos'] annos = info['annos']
...@@ -87,6 +107,15 @@ class Kitti2DDataset(CustomDataset): ...@@ -87,6 +107,15 @@ class Kitti2DDataset(CustomDataset):
return anns_results return anns_results
def prepare_train_img(self, idx): def prepare_train_img(self, idx):
"""Training image preparation.
Args:
idx (int): Index for accessing the target image data.
Returns:
dict: Training image data dict after preprocessing
corresponding to the index.
"""
img_raw_info = self.data_infos[idx]['image'] img_raw_info = self.data_infos[idx]['image']
img_info = dict(filename=img_raw_info['image_path']) img_info = dict(filename=img_raw_info['image_path'])
ann_info = self.get_ann_info(idx) ann_info = self.get_ann_info(idx)
...@@ -99,6 +128,15 @@ class Kitti2DDataset(CustomDataset): ...@@ -99,6 +128,15 @@ class Kitti2DDataset(CustomDataset):
return self.pipeline(results) return self.pipeline(results)
def prepare_test_img(self, idx): def prepare_test_img(self, idx):
"""Prepare data for testing.
Args:
idx (int): Index for accessing the target image data.
Returns:
dict: Testing image data dict after preprocessing
corresponding to the index.
"""
img_raw_info = self.data_infos[idx]['image'] img_raw_info = self.data_infos[idx]['image']
img_info = dict(filename=img_raw_info['image_path']) img_info = dict(filename=img_raw_info['image_path'])
results = dict(img_info=img_info) results = dict(img_info=img_info)
...@@ -108,11 +146,29 @@ class Kitti2DDataset(CustomDataset): ...@@ -108,11 +146,29 @@ class Kitti2DDataset(CustomDataset):
return self.pipeline(results) return self.pipeline(results)
def drop_arrays_by_name(self, gt_names, used_classes): def drop_arrays_by_name(self, gt_names, used_classes):
"""Drop irrelevant ground truths by name.
Args:
gt_names (list[str]): Names of ground truths.
used_classes (list[str]): Classes of interest.
Returns:
np.ndarray: Indices of ground truths that will be dropped.
"""
inds = [i for i, x in enumerate(gt_names) if x not in used_classes] inds = [i for i, x in enumerate(gt_names) if x not in used_classes]
inds = np.array(inds, dtype=np.int64) inds = np.array(inds, dtype=np.int64)
return inds return inds
def keep_arrays_by_name(self, gt_names, used_classes): def keep_arrays_by_name(self, gt_names, used_classes):
"""Keep useful ground truths by name.
Args:
gt_names (list[str]): Names of ground truths.
used_classes (list[str]): Classes of interest.
Returns:
np.ndarray: Indices of ground truths that will be kept.
"""
inds = [i for i, x in enumerate(gt_names) if x in used_classes] inds = [i for i, x in enumerate(gt_names) if x in used_classes]
inds = np.array(inds, dtype=np.int64) inds = np.array(inds, dtype=np.int64)
return inds return inds
...@@ -125,6 +181,17 @@ class Kitti2DDataset(CustomDataset): ...@@ -125,6 +181,17 @@ class Kitti2DDataset(CustomDataset):
return result_files return result_files
def evaluate(self, result_files, eval_types=None): def evaluate(self, result_files, eval_types=None):
"""Evaluation in KITTI protocol.
Args:
result_files (str): Path of result files.
eval_types (str): Types of evaluation. Default: None.
KITTI dataset only supports 'bbox' evaluation type.
Returns:
tuple (str, dict): Average precision results in str format
and average precision results in dict format.
"""
from mmdet3d.core.evaluation import kitti_eval from mmdet3d.core.evaluation import kitti_eval
eval_types = ['bbox'] if not eval_types else eval_types eval_types = ['bbox'] if not eval_types else eval_types
assert eval_types in ('bbox', ['bbox' assert eval_types in ('bbox', ['bbox'
......
...@@ -83,6 +83,23 @@ class KittiDataset(Custom3DDataset): ...@@ -83,6 +83,23 @@ class KittiDataset(Custom3DDataset):
return pts_filename return pts_filename
def get_data_info(self, index): def get_data_info(self, index):
"""Get data info according to the given index.
Args:
index (int): Index of the sample data to get.
Returns:
dict: Standard input_dict consisting of the
data information.
- sample_idx (str): sample index
- pts_filename (str): filename of point clouds
- img_prefix (str | None): prefix of image files
- img_info (dict): image info
- lidar2img (list[np.ndarray], optional): transformations from
lidar to different cameras
- ann_info (dict): annotation info
"""
info = self.data_infos[index] info = self.data_infos[index]
sample_idx = info['image']['image_idx'] sample_idx = info['image']['image_idx']
img_filename = os.path.join(self.data_root, img_filename = os.path.join(self.data_root,
...@@ -109,6 +126,22 @@ class KittiDataset(Custom3DDataset): ...@@ -109,6 +126,22 @@ class KittiDataset(Custom3DDataset):
return input_dict return input_dict
def get_ann_info(self, index): def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: Standard annotation dictionary
consisting of the data information.
- gt_bboxes_3d (:obj:``LiDARInstance3DBoxes``):
3D ground truth bboxes
- gt_labels_3d (np.ndarray): labels of ground truths
- gt_bboxes (np.ndarray): 2D ground truth bboxes
- gt_labels (np.ndarray): labels of ground truths
- gt_names (list[str]): class names of ground truths
"""
# Use index to get the annos, thus the evalhook could also use this api # Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index] info = self.data_infos[index]
rect = info['calib']['R0_rect'].astype(np.float32) rect = info['calib']['R0_rect'].astype(np.float32)
...@@ -153,16 +186,43 @@ class KittiDataset(Custom3DDataset): ...@@ -153,16 +186,43 @@ class KittiDataset(Custom3DDataset):
return anns_results return anns_results
def drop_arrays_by_name(self, gt_names, used_classes): def drop_arrays_by_name(self, gt_names, used_classes):
"""Drop irrelevant ground truths by name.
Args:
gt_names (list[str]): Names of ground truths.
used_classes (list[str]): Classes of interest.
Returns:
np.ndarray: Indices of ground truths that will be dropped.
"""
inds = [i for i, x in enumerate(gt_names) if x not in used_classes] inds = [i for i, x in enumerate(gt_names) if x not in used_classes]
inds = np.array(inds, dtype=np.int64) inds = np.array(inds, dtype=np.int64)
return inds return inds
def keep_arrays_by_name(self, gt_names, used_classes): def keep_arrays_by_name(self, gt_names, used_classes):
"""Keep useful ground truths by name.
Args:
gt_names (list[str]): Names of ground truths.
used_classes (list[str]): Classes of interest.
Returns:
np.ndarray: Indices of ground truths that will be kept.
"""
inds = [i for i, x in enumerate(gt_names) if x in used_classes] inds = [i for i, x in enumerate(gt_names) if x in used_classes]
inds = np.array(inds, dtype=np.int64) inds = np.array(inds, dtype=np.int64)
return inds return inds
def remove_dontcare(self, ann_info): def remove_dontcare(self, ann_info):
"""Remove annotations that should be ignored.
Args:
ann_info (dict): Dict of annotation infos. The ``'DontCare'``
annotations will be removed according to ann_info['name'].
Returns:
dict: Annotations after filtering.
"""
img_filtered_annotations = {} img_filtered_annotations = {}
relevant_annotation_indices = [ relevant_annotation_indices = [
i for i, x in enumerate(ann_info['name']) if x != 'DontCare' i for i, x in enumerate(ann_info['name']) if x != 'DontCare'
...@@ -176,6 +236,23 @@ class KittiDataset(Custom3DDataset): ...@@ -176,6 +236,23 @@ class KittiDataset(Custom3DDataset):
outputs, outputs,
pklfile_prefix=None, pklfile_prefix=None,
submission_prefix=None): submission_prefix=None):
"""Format the results to pkl file.
Args:
outputs (list[dict]): Testing results of the dataset.
pklfile_prefix (str | None): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
submission_prefix (str | None): The prefix of submitted files. It
includes the file path and the prefix of filename, e.g.,
"a/b/prefix". If not specified, a temp file will be created.
Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a dict containing
the json filepaths, tmp_dir is the temporary directory created
for saving result files when pklfile_prefix is not specified.
"""
if pklfile_prefix is None: if pklfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory() tmp_dir = tempfile.TemporaryDirectory()
pklfile_prefix = osp.join(tmp_dir.name, 'results') pklfile_prefix = osp.join(tmp_dir.name, 'results')
...@@ -390,7 +467,7 @@ class KittiDataset(Custom3DDataset): ...@@ -390,7 +467,7 @@ class KittiDataset(Custom3DDataset):
pklfile_prefix (str | None): The prefix of pkl file. pklfile_prefix (str | None): The prefix of pkl file.
submission_prefix (str | None): The prefix of submission file. submission_prefix (str | None): The prefix of submission file.
Return: Returns:
List[dict]: A list of dict have the kitti format List[dict]: A list of dict have the kitti format
""" """
assert len(net_outputs) == len(self.data_infos) assert len(net_outputs) == len(self.data_infos)
...@@ -553,6 +630,12 @@ class KittiDataset(Custom3DDataset): ...@@ -553,6 +630,12 @@ class KittiDataset(Custom3DDataset):
) )
def show(self, results, out_dir): def show(self, results, out_dir):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
"""
assert out_dir is not None, 'Expect out_dir, got none.' assert out_dir is not None, 'Expect out_dir, got none.'
for i, result in enumerate(results): for i, result in enumerate(results):
example = self.prepare_test_data(i) example = self.prepare_test_data(i)
......
...@@ -139,6 +139,14 @@ class NuScenesDataset(Custom3DDataset): ...@@ -139,6 +139,14 @@ class NuScenesDataset(Custom3DDataset):
) )
def load_annotations(self, ann_file): def load_annotations(self, ann_file):
"""Load annotations from ann_file.
Args:
ann_file (str): Path of the annotation file.
Returns:
list[dict]: List of annotations sorted by timestamps.
"""
data = mmcv.load(ann_file) data = mmcv.load(ann_file)
data_infos = list(sorted(data['infos'], key=lambda e: e['timestamp'])) data_infos = list(sorted(data['infos'], key=lambda e: e['timestamp']))
data_infos = data_infos[::self.load_interval] data_infos = data_infos[::self.load_interval]
...@@ -147,6 +155,24 @@ class NuScenesDataset(Custom3DDataset): ...@@ -147,6 +155,24 @@ class NuScenesDataset(Custom3DDataset):
return data_infos return data_infos
def get_data_info(self, index): def get_data_info(self, index):
"""Get data info according to the given index.
Args:
index (int): Index of the sample data to get.
Returns:
dict: Standard input_dict consisting of the
data information.
- sample_idx (str): sample index
- pts_filename (str): filename of point clouds
- sweeps (list[dict]): infos of sweeps
- timestamp (float): sample timestamp
- img_filename (str, optional): image filename
- lidar2img (list[np.ndarray], optional): transformations from
lidar to different cameras
- ann_info (dict): annotation info
"""
info = self.data_infos[index] info = self.data_infos[index]
# standard protocol modified from SECOND.Pytorch # standard protocol modified from SECOND.Pytorch
...@@ -188,6 +214,20 @@ class NuScenesDataset(Custom3DDataset): ...@@ -188,6 +214,20 @@ class NuScenesDataset(Custom3DDataset):
return input_dict return input_dict
def get_ann_info(self, index): def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: Standard annotation dictionary
consisting of the data information.
- gt_bboxes_3d (:obj:``LiDARInstance3DBoxes``):
3D ground truth bboxes
- gt_labels_3d (np.ndarray): labels of ground truths
- gt_names (list[str]): class names of ground truths
"""
info = self.data_infos[index] info = self.data_infos[index]
# filter out bbox containing no points # filter out bbox containing no points
mask = info['num_lidar_pts'] > 0 mask = info['num_lidar_pts'] > 0
...@@ -221,6 +261,17 @@ class NuScenesDataset(Custom3DDataset): ...@@ -221,6 +261,17 @@ class NuScenesDataset(Custom3DDataset):
return anns_results return anns_results
def _format_bbox(self, results, jsonfile_prefix=None): def _format_bbox(self, results, jsonfile_prefix=None):
"""Convert the results to the standard format.
Args:
results (list[dict]): Testing results of the dataset.
jsonfile_prefix (str): The prefix of the output jsonfile.
You can specify the output directory/filename by
modifying the jsonfile_prefix. Default: None.
Returns:
str: Path of the output json file.
"""
nusc_annos = {} nusc_annos = {}
mapped_class_names = self.CLASSES mapped_class_names = self.CLASSES
...@@ -283,6 +334,19 @@ class NuScenesDataset(Custom3DDataset): ...@@ -283,6 +334,19 @@ class NuScenesDataset(Custom3DDataset):
logger=None, logger=None,
metric='bbox', metric='bbox',
result_name='pts_bbox'): result_name='pts_bbox'):
"""Evaluation for a single model in nuScenes protocol.
Args:
result_path (str): Path of the result file.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
metric (str): Metric name used for evaluation. Default: 'bbox'.
result_name (str): Result name in the metric prefix.
Default: 'pts_bbox'.
Returns:
dict: Dictionary of evaluation details.
"""
from nuscenes import NuScenes from nuscenes import NuScenes
from nuscenes.eval.detection.evaluate import NuScenesEval from nuscenes.eval.detection.evaluate import NuScenesEval
...@@ -400,6 +464,12 @@ class NuScenesDataset(Custom3DDataset): ...@@ -400,6 +464,12 @@ class NuScenesDataset(Custom3DDataset):
return results_dict return results_dict
def show(self, results, out_dir): def show(self, results, out_dir):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
"""
for i, result in enumerate(results): for i, result in enumerate(results):
example = self.prepare_test_data(i) example = self.prepare_test_data(i)
points = example['points'][0]._data.numpy() points = example['points'][0]._data.numpy()
...@@ -421,6 +491,18 @@ class NuScenesDataset(Custom3DDataset): ...@@ -421,6 +491,18 @@ class NuScenesDataset(Custom3DDataset):
def output_to_nusc_box(detection): def output_to_nusc_box(detection):
"""Convert the output to the box class in the nuScenes.
Args:
detection (dict): Detection results.
- boxes_3d (:obj:``BaseInstance3DBoxes``): detection bbox
- scores_3d (torch.Tensor): detection scores
- labels_3d (torch.Tensor): predicted box labels
Returns:
list[:obj:``NuScenesBox``]: List of standard NuScenesBoxes.
"""
box3d = detection['boxes_3d'] box3d = detection['boxes_3d']
scores = detection['scores_3d'].numpy() scores = detection['scores_3d'].numpy()
labels = detection['labels_3d'].numpy() labels = detection['labels_3d'].numpy()
...@@ -456,6 +538,21 @@ def lidar_nusc_box_to_global(info, ...@@ -456,6 +538,21 @@ def lidar_nusc_box_to_global(info,
classes, classes,
eval_configs, eval_configs,
eval_version='detection_cvpr_2019'): eval_version='detection_cvpr_2019'):
"""Convert the box from lidar to global coordinate.
Args:
info (dict): Info for a specific sample data, including the
calibration information.
boxes (list[:obj:``NuScenesBox``]): List of predicted NuScenesBoxes.
classes (list[str]): Mapped classes in the evaluation.
eval_configs (object): Evaluation configuration object.
eval_version (str): Evaluation version.
Default: 'detection_cvpr_2019'
Returns:
list: List of standard NuScenesBoxes in the global
coordinate.
"""
box_list = [] box_list = []
for box in boxes: for box in boxes:
# Move box to ego vehicle coord system # Move box to ego vehicle coord system
......
...@@ -51,6 +51,16 @@ class BatchSampler: ...@@ -51,6 +51,16 @@ class BatchSampler:
@OBJECTSAMPLERS.register_module() @OBJECTSAMPLERS.register_module()
class DataBaseSampler(object): class DataBaseSampler(object):
"""Class for sampling data from the ground truth database.
Args:
info_path (str): Path of groundtruth database info.
data_root (str): Path of groundtruth database.
rate (float): Rate of actual sampled over maximum sampled number.
prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers.
classes (list[str]): List of classes. Default: None.
"""
def __init__(self, def __init__(self,
info_path, info_path,
...@@ -104,6 +114,15 @@ class DataBaseSampler(object): ...@@ -104,6 +114,15 @@ class DataBaseSampler(object):
@staticmethod @staticmethod
def filter_by_difficulty(db_infos, removed_difficulty): def filter_by_difficulty(db_infos, removed_difficulty):
"""Filter ground truths by difficulties.
Args:
db_infos (dict): Info of groundtruth database.
removed_difficulty (list): Difficulties that are not qualified.
Returns:
dict: Info of database after filtering.
"""
new_db_infos = {} new_db_infos = {}
for key, dinfos in db_infos.items(): for key, dinfos in db_infos.items():
new_db_infos[key] = [ new_db_infos[key] = [
...@@ -114,6 +133,16 @@ class DataBaseSampler(object): ...@@ -114,6 +133,16 @@ class DataBaseSampler(object):
@staticmethod @staticmethod
def filter_by_min_points(db_infos, min_gt_points_dict): def filter_by_min_points(db_infos, min_gt_points_dict):
"""Filter ground truths by number of points in the bbox.
Args:
db_infos (dict): Info of groundtruth database.
min_gt_points_dict (dict): Different number of minimum points
needed for different categories of ground truths.
Returns:
dict: Info of database after filtering.
"""
for name, min_num in min_gt_points_dict.items(): for name, min_num in min_gt_points_dict.items():
min_num = int(min_num) min_num = int(min_num)
if min_num > 0: if min_num > 0:
...@@ -125,6 +154,22 @@ class DataBaseSampler(object): ...@@ -125,6 +154,22 @@ class DataBaseSampler(object):
return db_infos return db_infos
def sample_all(self, gt_bboxes, gt_labels, img=None): def sample_all(self, gt_bboxes, gt_labels, img=None):
"""Sampling all categories of bboxes.
Args:
gt_bboxes (np.ndarray): Ground truth bounding boxes.
gt_labels (np.ndarray): Labels of boxes.
Returns:
dict: Dict of sampled 'pseudo ground truths'.
- gt_labels_3d (np.ndarray): labels of sampled ground truths
- gt_bboxes_3d (:obj:``BaseInstance3DBoxes``):
sampled 3D bounding boxes
- points (np.ndarray): sampled points
- group_ids (np.ndarray): ids of sampled ground truths
"""
sampled_num_dict = {} sampled_num_dict = {}
sample_num_per_class = [] sample_num_per_class = []
for class_name, max_sample_num in zip(self.sample_classes, for class_name, max_sample_num in zip(self.sample_classes,
...@@ -198,6 +243,16 @@ class DataBaseSampler(object): ...@@ -198,6 +243,16 @@ class DataBaseSampler(object):
return ret return ret
def sample_class_v2(self, name, num, gt_bboxes): def sample_class_v2(self, name, num, gt_bboxes):
"""Sampling specific categories of bounding boxes.
Args:
name (str): Class of objects to be sampled.
num (int): Number of sampled bboxes.
gt_bboxes (np.ndarray): Ground truth boxes.
Returns:
list[dict]: Valid samples after collision test.
"""
sampled = self.sampler_dict[name].sample(num) sampled = self.sampler_dict[name].sample(num)
sampled = copy.deepcopy(sampled) sampled = copy.deepcopy(sampled)
num_gt = gt_bboxes.shape[0] num_gt = gt_bboxes.shape[0]
......
...@@ -114,7 +114,7 @@ class PointSegClassMapping(object): ...@@ -114,7 +114,7 @@ class PointSegClassMapping(object):
others as len(valid_cat_ids). others as len(valid_cat_ids).
Args: Args:
valid_cat_ids (tuple[int): A tuple of valid category. valid_cat_ids (tuple[int]): A tuple of valid category.
""" """
def __init__(self, valid_cat_ids): def __init__(self, valid_cat_ids):
......
...@@ -171,12 +171,12 @@ class ObjectNoise(object): ...@@ -171,12 +171,12 @@ class ObjectNoise(object):
"""Apply noise to each GT objects in the scene. """Apply noise to each GT objects in the scene.
Args: Args:
translation_std (list, optional): Standard deviation of the translation_std (list[float], optional): Standard deviation of the
distribution where translation noise are sampled from. distribution where translation noise are sampled from.
Defaults to [0.25, 0.25, 0.25]. Defaults to [0.25, 0.25, 0.25].
global_rot_range (list, optional): Global rotation to the scene. global_rot_range (list[float], optional): Global rotation to the scene.
Defaults to [0.0, 0.0]. Defaults to [0.0, 0.0].
rot_range (list, optional): Object rotation range. rot_range (list[float], optional): Object rotation range.
Defaults to [-0.15707963267, 0.15707963267]. Defaults to [-0.15707963267, 0.15707963267].
num_try (int, optional): Number of times to try if the noise applied is num_try (int, optional): Number of times to try if the noise applied is
invalid. Defaults to 100. invalid. Defaults to 100.
...@@ -429,8 +429,9 @@ class IndoorPointSample(object): ...@@ -429,8 +429,9 @@ class IndoorPointSample(object):
return_choices (bool): Whether return choice. return_choices (bool): Whether return choice.
Returns: Returns:
points (ndarray): 3D Points. tuple (np.ndarray, np.ndarray) | np.ndarray:
choices (ndarray): The generated random samples. points (np.ndarray): 3D Points.
choices (np.ndarray, optional): The generated random samples.
""" """
if replace is None: if replace is None:
replace = (points.shape[0] < num_samples) replace = (points.shape[0] < num_samples)
......
...@@ -64,6 +64,21 @@ class ScanNetDataset(Custom3DDataset): ...@@ -64,6 +64,21 @@ class ScanNetDataset(Custom3DDataset):
test_mode=test_mode) test_mode=test_mode)
def get_ann_info(self, index): def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: Standard annotation dictionary
consisting of the data information.
- gt_bboxes_3d (:obj:``DepthInstance3DBoxes``):
3D ground truth bboxes
- gt_labels_3d (np.ndarray): labels of ground truths
- pts_instance_mask_path (str): path of instance masks
- pts_semantic_mask_path (str): path of semantic masks
"""
# Use index to get the annos, thus the evalhook could also use this api # Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index] info = self.data_infos[index]
if info['annos']['gt_num'] != 0: if info['annos']['gt_num'] != 0:
...@@ -94,6 +109,12 @@ class ScanNetDataset(Custom3DDataset): ...@@ -94,6 +109,12 @@ class ScanNetDataset(Custom3DDataset):
return anns_results return anns_results
def show(self, results, out_dir): def show(self, results, out_dir):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
"""
assert out_dir is not None, 'Expect out_dir, got none.' assert out_dir is not None, 'Expect out_dir, got none.'
for i, result in enumerate(results): for i, result in enumerate(results):
data_info = self.data_infos[i] data_info = self.data_infos[i]
......
...@@ -62,6 +62,21 @@ class SUNRGBDDataset(Custom3DDataset): ...@@ -62,6 +62,21 @@ class SUNRGBDDataset(Custom3DDataset):
test_mode=test_mode) test_mode=test_mode)
def get_ann_info(self, index): def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: Standard annotation dictionary
consisting of the data information.
- gt_bboxes_3d (:obj:``DepthInstance3DBoxes``):
3D ground truth bboxes
- gt_labels_3d (np.ndarray): labels of ground truths
- pts_instance_mask_path (str): path of instance masks
- pts_semantic_mask_path (str): path of semantic masks
"""
# Use index to get the annos, thus the evalhook could also use this api # Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index] info = self.data_infos[index]
if info['annos']['gt_num'] != 0: if info['annos']['gt_num'] != 0:
...@@ -81,6 +96,12 @@ class SUNRGBDDataset(Custom3DDataset): ...@@ -81,6 +96,12 @@ class SUNRGBDDataset(Custom3DDataset):
return anns_results return anns_results
def show(self, results, out_dir): def show(self, results, out_dir):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
"""
assert out_dir is not None, 'Expect out_dir, got none.' assert out_dir is not None, 'Expect out_dir, got none.'
for i, result in enumerate(results): for i, result in enumerate(results):
data_info = self.data_infos[i] data_info = self.data_infos[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment