Unverified Commit 9cb75e7d authored by Ziyi Wu, committed by GitHub

[Feature] Support ScanNet semantic segmentation dataset (#390)

* remove max_num_point in ScanNet data preprocessing

* add config file for ScanNet semantic segmentation dataset

* modify NormalizePointsColor in pipeline

* add visualization function for semantic segmentation

* add ignore_index to semantic segmentation visualization function

* add ignore_index to semantic segmentation evaluation function

* fix ignore_index bug in semantic segmentation evaluation function

* add test function to check ignore_index assignment in PointSegClassMapping

* fix slicing bug in BasePoints class and add unittest

* add IndoorPatchPointSample class for indoor semantic segmentation data loading and add unittest

* modify LoadPointsFromFile class and its unittest to support point color loading

* fix data path in unittest

* add setter function for coord and attributes of BasePoint and modify unittest

* modify color normalization operation to work on BasePoint class

* add unittest for ScanNet semantic segmentation data loading pipeline

* fix ignore_index bug in seg_eval function

* add ScanNet semantic segmentation dataset and unittest

* modify config file for ScanNet semantic segmentation

* fix visualization function and modify unittest

* fix a typo in seg_eval.py

* raise exception when semantic mask is not provided in train/eval data loading

* support custom computation of label weight for loss calculation

* modify seg_eval function to be more efficient

* fix small bugs & change variable names for clarity & add more cases to unittest

* move room index resampling and label weight computation to data pre-processing

* add option allowing user to determine whether to sub-sample point clouds

* fix typos & change .format to f-string & fix link in comment

* save all visualizations into .obj format for consistency

* infer num_classes from label2cat in eval_seg function

* add pre-computed room index and label weight for ScanNet dataset

* replace .ply with .obj in unittests and documents

* add TODO in case data is on ceph

* add base dataset for all semantic segmentation tasks & add ScanNet dataset inheriting from base dataset

* rename class for consistency

* fix minor typos in comment

* move Custom3DSegDataset to a new file

* modify BasePoint setter function to enable attribute adding

* add unittest for NormalizePointsColor and fix small bugs

* fix unittest for BasePoints

* modify ScanNet data pre-processing scripts

* change ignore_idx to -1 in seg_eval function

* remove sliding inference from PatchSample function and modify unittest

* remove PatchSample from scannet seg test_pipeline
parent d055876a
# dataset settings
dataset_type = 'ScanNetSegDataset'
data_root = './data/scannet/'
class_names = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table',
               'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
               'curtain', 'refrigerator', 'showercurtrain', 'toilet', 'sink',
               'bathtub', 'otherfurniture')
num_points = 8192
train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        shift_height=False,
        use_color=True,
        load_dim=6,
        use_dim=[0, 1, 2, 3, 4, 5]),
    dict(
        type='LoadAnnotations3D',
        with_bbox_3d=False,
        with_label_3d=False,
        with_mask_3d=False,
        with_seg_3d=True),
    dict(
        type='PointSegClassMapping',
        valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
                       33, 34, 36, 39)),
    dict(
        type='IndoorPatchPointSample',
        num_points=num_points,
        block_size=1.5,
        sample_rate=1.0,
        ignore_index=len(class_names),
        use_normalized_coord=True),
    dict(type='NormalizePointsColor', color_mean=None),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
]
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        shift_height=False,
        use_color=True,
        load_dim=6,
        use_dim=[0, 1, 2, 3, 4, 5]),
    dict(type='NormalizePointsColor', color_mean=None),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['points'])
]
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=data_root + 'scannet_infos_train.pkl',
        pipeline=train_pipeline,
        classes=class_names,
        test_mode=False,
        ignore_index=len(class_names),
        scene_idxs=data_root + 'seg_info/train_resampled_scene_idxs.npy',
        label_weight=data_root + 'seg_info/train_label_weight.npy'),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=data_root + 'scannet_infos_val.pkl',
        pipeline=test_pipeline,
        classes=class_names,
        test_mode=True,
        ignore_index=len(class_names)),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=data_root + 'scannet_infos_val.pkl',
        pipeline=test_pipeline,
        classes=class_names,
        test_mode=True,
        ignore_index=len(class_names)))
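For reference, a minimal sketch of how such a config is typically consumed in MMDetection3D; the config path below and the presence of the pre-processed data are assumptions for illustration, not part of this commit:

```python
# Hedged sketch: build the ScanNet segmentation dataset from this config
# using the standard MMDetection3D entry points.
from mmcv import Config
from mmdet3d.datasets import build_dataset

# assumed config path; adjust to wherever this file lives in your checkout
cfg = Config.fromfile('configs/_base_/datasets/scannet_seg-3d-20class.py')
dataset = build_dataset(cfg.data.train)

print(len(dataset))        # number of (resampled) training samples
sample = dataset[0]        # runs the train_pipeline defined above
# 'points' is a DataContainer after DefaultFormatBundle3D; with use_color
# and use_normalized_coord the tensor has 9 channels
print(sample['points'].data.shape)   # e.g. torch.Size([8192, 9])
```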
### Prepare ScanNet Data for Indoor Detection or Segmentation Task
We follow the procedure in [votenet](https://github.com/facebookresearch/votenet/).
1. Download ScanNet v2 data [HERE](https://github.com/ScanNet/ScanNet). Link or move the 'scans' folder to this level of directory.
2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`. Add the `--max_num_point 50000` flag if you only use the ScanNet data for the detection task. It will downsample the scenes to fewer points.
3. Enter the project root directory, generate training data by running
```bash
...
```
@@ -33,6 +33,11 @@ scannet
│   ├── xxxxx.bin
├── semantic_mask
│   ├── xxxxx.bin
├── seg_info
│   ├── train_label_weight.npy
│   ├── train_resampled_scene_idxs.npy
│   ├── val_label_weight.npy
│   ├── val_resampled_scene_idxs.npy
├── scannet_infos_train.pkl
├── scannet_infos_val.pkl
......
@@ -47,6 +47,7 @@ def export_one_scan(scan_name, output_filename_prefix, max_num_point,
    instance_bboxes = instance_bboxes[bbox_mask, :]
    print(f'Num of care instances: {instance_bboxes.shape[0]}')

    if max_num_point is not None:
        N = mesh_vertices.shape[0]
        if N > max_num_point:
            choices = np.random.choice(N, max_num_point, replace=False)
@@ -88,7 +89,7 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--max_num_point',
        default=None,
        help='The maximum number of the points.')
    parser.add_argument(
        '--output_folder',
......
@@ -26,7 +26,7 @@ Optional arguments:
- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
- `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset. Typically we default to using official metrics for evaluation on different datasets, so it can simply be set to `mAP` as a placeholder, which applies to nuScenes, Lyft, ScanNet and SUNRGBD. For KITTI, if we only want to evaluate the 2D detection performance, we can simply set the metric to `img_bbox` (unstable, stay tuned). For Waymo, we provide both KITTI-style evaluation (unstable) and the Waymo-style official protocol, corresponding to the metrics `kitti` and `waymo` respectively. We recommend using the default official metric for stable performance and fair comparison with other methods.
- `--show`: If specified, detection results will be plotted in silent mode. It is only applicable to single GPU testing and used for debugging and visualization. This should be used with `--show-dir`.
- `--show-dir`: If specified, detection results will be saved as `***_points.obj` and `***_pred.obj` files in the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You do NOT need a GUI available in your environment for using this option.
Examples:
......
@@ -57,7 +57,7 @@ To see the SUNRGBD, ScanNet or KITTI points and detection results, you can run the following command
```bash
python tools/test.py ${CONFIG_FILE} ${CKPT_PATH} --show --show-dir ${SHOW_DIR}
```
After running this command, you will see the plotted results `***_points.obj` and `***_pred.obj` files in `${SHOW_DIR}`.
To see the points, detection results and ground truth of SUNRGBD, ScanNet or KITTI during evaluation time, you can run the following command
@@ -65,7 +65,7 @@ To see the points, detection results and ground truth of SUNRGBD, ScanNet or KITTI during evaluation time, you can run the following command
```bash
python tools/test.py ${CONFIG_FILE} ${CKPT_PATH} --eval 'mAP' --options 'show=True' 'out_dir=${SHOW_DIR}'
```
After running this command, you will obtain `***_points.obj`, `***_pred.obj` and `***_gt.obj` files in `${SHOW_DIR}`. When `show` is enabled, [Open3D](http://www.open3d.org/) will be used to visualize the results online. You need to set `show=False` when running the test on a remote server without a GUI.
As for offline visualization, you will have two options.
To visualize the results with the `Open3D` backend, you can run the following command
@@ -76,7 +76,7 @@ python tools/misc/visualize_results.py ${CONFIG_FILE} --result ${RESULTS_PATH} -
![Open3D_visualization](../resources/open3d_visual.gif)
Or you can use 3D visualization software such as [MeshLab](http://www.meshlab.net/) to open these files under `${SHOW_DIR}` to see the 3D detection output. Specifically, open `***_points.obj` to see the input point cloud and open `***_pred.obj` to see the predicted 3D bounding boxes. This allows the inference and result generation to be done on a remote server, and users can then open them on their host machine with a GUI.
**Notice**: The visualization API is a little unstable since we plan to refactor these parts together with MMDetection in the future.
......
@@ -66,28 +66,37 @@ def get_acc_cls(hist):
    return np.nanmean(np.diag(hist) / hist.sum(axis=1))

def seg_eval(gt_labels, seg_preds, label2cat, ignore_index, logger=None):
    """Semantic Segmentation Evaluation.

    Evaluate the result of the Semantic Segmentation.

    Args:
        gt_labels (list[torch.Tensor]): Ground truth labels.
        seg_preds (list[torch.Tensor]): Predictions.
        label2cat (dict): Map from label to category name.
        ignore_index (int): Index that will be ignored in evaluation.
        logger (logging.Logger | str | None): The way to print the mAP
            summary. See `mmdet.utils.print_log()` for details. Default: None.

    Returns:
        dict[str, float]: Dict of results.
    """
    assert len(seg_preds) == len(gt_labels)
    num_classes = len(label2cat)

    hist_list = []
    for i in range(len(gt_labels)):
        gt_seg = gt_labels[i].clone().numpy().astype(np.int)
        pred_seg = seg_preds[i].clone().numpy().astype(np.int)

        # filter out ignored points
        pred_seg[gt_seg == ignore_index] = -1
        gt_seg[gt_seg == ignore_index] = -1

        # calculate one instance result
        hist_list.append(fast_hist(pred_seg, gt_seg, num_classes))

    iou = per_class_iou(sum(hist_list))
    miou = np.nanmean(iou)
    acc = get_acc(sum(hist_list))
......
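The ignore_index handling above relies on `fast_hist` dropping negative labels. A self-contained sketch of the confusion-matrix metrics this evaluation builds on; the bodies of `fast_hist` and `per_class_iou` are not shown in this diff, so the implementations below are assumptions that mirror the common one:

```python
import numpy as np

def fast_hist(preds, labels, num_classes):
    # keep only valid entries; points set to -1 above are dropped here
    k = (labels >= 0) & (labels < num_classes)
    # bincount over (label * C + pred) yields a C x C confusion matrix
    return np.bincount(
        num_classes * labels[k].astype(int) + preds[k],
        minlength=num_classes**2).reshape(num_classes, num_classes)

def per_class_iou(hist):
    # IoU = TP / (TP + FP + FN) for each class
    return np.diag(hist) / (hist.sum(0) + hist.sum(1) - np.diag(hist))

hist = fast_hist(np.array([0, 1, 1, -1]), np.array([0, 1, 0, -1]), 2)
print(per_class_iou(hist))  # [0.5 0.5]
```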
import numpy as np
import torch
import warnings
from abc import abstractmethod
@@ -46,6 +47,17 @@ class BasePoints(object):
        """torch.Tensor: Coordinates of each point with size (N, 3)."""
        return self.tensor[:, :3]

    @coord.setter
    def coord(self, tensor):
        """Set the coordinates of each point."""
        try:
            tensor = tensor.reshape(self.shape[0], 3)
        except (RuntimeError, ValueError):  # for torch.Tensor and np.ndarray
            raise ValueError(f'got unexpected shape {tensor.shape}')
        if not isinstance(tensor, torch.Tensor):
            tensor = self.tensor.new_tensor(tensor)
        self.tensor[:, :3] = tensor

    @property
    def height(self):
        """torch.Tensor: A vector with height of each point."""
@@ -55,6 +67,27 @@ class BasePoints(object):
        else:
            return None

    @height.setter
    def height(self, tensor):
        """Set the height of each point."""
        try:
            tensor = tensor.reshape(self.shape[0])
        except (RuntimeError, ValueError):  # for torch.Tensor and np.ndarray
            raise ValueError(f'got unexpected shape {tensor.shape}')
        if not isinstance(tensor, torch.Tensor):
            tensor = self.tensor.new_tensor(tensor)
        if self.attribute_dims is not None and \
                'height' in self.attribute_dims.keys():
            self.tensor[:, self.attribute_dims['height']] = tensor
        else:
            # add height attribute
            if self.attribute_dims is None:
                self.attribute_dims = dict()
            attr_dim = self.shape[1]
            self.tensor = torch.cat([self.tensor, tensor.unsqueeze(1)], dim=1)
            self.attribute_dims.update(dict(height=attr_dim))
            self.points_dim += 1

    @property
    def color(self):
        """torch.Tensor: A vector with color of each point."""
@@ -64,6 +97,30 @@ class BasePoints(object):
        else:
            return None

    @color.setter
    def color(self, tensor):
        """Set the color of each point."""
        try:
            tensor = tensor.reshape(self.shape[0], 3)
        except (RuntimeError, ValueError):  # for torch.Tensor and np.ndarray
            raise ValueError(f'got unexpected shape {tensor.shape}')
        if tensor.max() >= 256 or tensor.min() < 0:
            warnings.warn('point got color value beyond [0, 255]')
        if not isinstance(tensor, torch.Tensor):
            tensor = self.tensor.new_tensor(tensor)
        if self.attribute_dims is not None and \
                'color' in self.attribute_dims.keys():
            self.tensor[:, self.attribute_dims['color']] = tensor
        else:
            # add color attribute
            if self.attribute_dims is None:
                self.attribute_dims = dict()
            attr_dim = self.shape[1]
            self.tensor = torch.cat([self.tensor, tensor], dim=1)
            self.attribute_dims.update(
                dict(color=[attr_dim, attr_dim + 1, attr_dim + 2]))
            self.points_dim += 3

    @property
    def shape(self):
        """torch.Shape: Shape of points."""
@@ -136,8 +193,8 @@ class BasePoints(object):
                trans_vector.shape[1] == 3
        else:
            raise NotImplementedError(
                f'Unsupported translation vector of shape {trans_vector.shape}')
        self.tensor[:, :3] += trans_vector

    def in_range_3d(self, point_range):
@@ -233,8 +290,8 @@ class BasePoints(object):
        elif isinstance(item, tuple) and len(item) == 2:
            if isinstance(item[1], slice):
                start = 0 if item[1].start is None else item[1].start
                stop = self.tensor.shape[1] if \
                    item[1].stop is None else item[1].stop
                step = 1 if item[1].step is None else item[1].step
                item = list(item)
                item[1] = list(range(start, stop, step))
@@ -246,9 +303,9 @@ class BasePoints(object):
        if self.attribute_dims is not None:
            attribute_dims = self.attribute_dims.copy()
            for key in self.attribute_dims.keys():
                cur_attribute_dims = attribute_dims[key]
                if isinstance(cur_attribute_dims, int):
                    cur_attribute_dims = [cur_attribute_dims]
                intersect_attr = list(
                    set(cur_attribute_dims).intersection(set(keep_dims)))
                if len(intersect_attr) == 1:
......
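A hedged usage sketch of the new setters; `DepthPoints` is one concrete `BasePoints` subclass, and the values below are made up for illustration:

```python
import torch
from mmdet3d.core.points import DepthPoints

pts = DepthPoints(torch.rand(4, 3), points_dim=3)
# assigning color to points that lack it appends three channels
# and registers them in attribute_dims
pts.color = torch.randint(0, 256, (4, 3)).float()
print(pts.points_dim)        # 6
print(pts.attribute_dims)    # {'color': [3, 4, 5]}
# assigning coord overwrites the first three channels in place
pts.coord = torch.zeros(4, 3)
```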
from .show_result import show_result, show_seg_result

__all__ = ['show_result', 'show_seg_result']
@@ -4,8 +4,8 @@ import trimesh
from os import path as osp

def _write_obj(points, out_filename):
    """Write points into ``obj`` format for meshlab visualization.

    Args:
        points (np.ndarray): Points in shape (N, dim).
@@ -62,8 +62,8 @@ def _write_oriented_bbox(scene_bbox, out_filename):
        scene.add_geometry(convert_oriented_box_to_trimesh_fmt(box))
    mesh_list = trimesh.util.concatenate(scene.dump())

    # save to obj file
    trimesh.io.export.export_mesh(mesh_list, out_filename, file_type='obj')

    return
@@ -93,7 +93,7 @@ def show_result(points, gt_bboxes, pred_bboxes, out_dir, filename, show=True):
    mmcv.mkdir_or_exist(result_path)

    if points is not None:
        _write_obj(points, osp.join(result_path, f'{filename}_points.obj'))

    if gt_bboxes is not None:
        # bottom center to gravity center
@@ -101,7 +101,7 @@ def show_result(points, gt_bboxes, pred_bboxes, out_dir, filename, show=True):
        # the positive direction for yaw in meshlab is clockwise
        gt_bboxes[:, 6] *= -1
        _write_oriented_bbox(gt_bboxes,
                             osp.join(result_path, f'{filename}_gt.obj'))

    if pred_bboxes is not None:
        # bottom center to gravity center
@@ -109,4 +109,66 @@ def show_result(points, gt_bboxes, pred_bboxes, out_dir, filename, show=True):
        # the positive direction for yaw in meshlab is clockwise
        pred_bboxes[:, 6] *= -1
        _write_oriented_bbox(pred_bboxes,
                             osp.join(result_path, f'{filename}_pred.obj'))

def show_seg_result(points,
                    gt_seg,
                    pred_seg,
                    out_dir,
                    filename,
                    palette,
                    ignore_index=None,
                    show=False):
    """Convert results into format that is directly readable for meshlab.

    Args:
        points (np.ndarray): Points.
        gt_seg (np.ndarray): Ground truth segmentation mask.
        pred_seg (np.ndarray): Predicted segmentation mask.
        out_dir (str): Path of output directory.
        filename (str): Filename of the current frame.
        palette (np.ndarray): Mapping between class labels and colors.
        ignore_index (int, optional): The label index to be ignored, e.g. \
            unannotated points. Defaults to None.
        show (bool, optional): Visualize the results online.
            Defaults to False.
    """
    '''
    # TODO: not sure how to draw colors online, maybe we need two frames?
    from .open3d_vis import Visualizer
    if show:
        vis = Visualizer(points)
        if pred_bboxes is not None:
            vis.add_bboxes(bbox3d=pred_bboxes)
        if gt_bboxes is not None:
            vis.add_bboxes(bbox3d=gt_bboxes, bbox_color=(0, 0, 1))
        vis.show()
    '''

    # filter out ignored points
    if gt_seg is not None and ignore_index is not None:
        if points is not None:
            points = points[gt_seg != ignore_index]
        if pred_seg is not None:
            pred_seg = pred_seg[gt_seg != ignore_index]
        gt_seg = gt_seg[gt_seg != ignore_index]

    if gt_seg is not None:
        gt_seg_color = palette[gt_seg]
    if pred_seg is not None:
        pred_seg_color = palette[pred_seg]

    result_path = osp.join(out_dir, filename)
    mmcv.mkdir_or_exist(result_path)

    if points is not None:
        _write_obj(points, osp.join(result_path, f'{filename}_points.obj'))

    if gt_seg is not None:
        gt_seg = np.concatenate([points[:, :3], gt_seg_color], axis=1)
        _write_obj(gt_seg, osp.join(result_path, f'{filename}_gt.obj'))

    if pred_seg is not None:
        pred_seg = np.concatenate([points[:, :3], pred_seg_color], axis=1)
        _write_obj(pred_seg, osp.join(result_path, f'{filename}_pred.obj'))
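A minimal usage sketch for the new helper; the random inputs and the palette are placeholders, not values from this commit:

```python
import numpy as np
from mmdet3d.core import show_seg_result

points = np.random.rand(100, 6)              # xyz + rgb
gt_seg = np.random.randint(0, 3, 100)
pred_seg = np.random.randint(0, 3, 100)
palette = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]])

# writes demo_points.obj / demo_gt.obj / demo_pred.obj under ./show_dir/demo
show_seg_result(points, gt_seg, pred_seg, './show_dir', 'demo', palette,
                ignore_index=None, show=False)
```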
from mmdet.datasets.builder import build_dataloader
from .builder import DATASETS, build_dataset
from .custom_3d import Custom3DDataset
from .custom_3d_seg import Custom3DSegDataset
from .kitti_dataset import KittiDataset
from .lyft_dataset import LyftDataset
from .nuscenes_dataset import NuScenesDataset
@@ -10,7 +11,7 @@ from .pipelines import (BackgroundPointsFilter, GlobalRotScaleTrans,
                        NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
                        ObjectSample, PointShuffle, PointsRangeFilter,
                        RandomFlip3D, VoxelBasedPointSampler)
from .scannet_dataset import ScanNetDataset, ScanNetSegDataset
from .semantickitti_dataset import SemanticKITTIDataset
from .sunrgbd_dataset import SUNRGBDDataset
from .waymo_dataset import WaymoDataset
@@ -23,6 +24,7 @@ __all__ = [
    'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
    'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample',
    'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset',
    'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset',
    'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset',
    'BackgroundPointsFilter', 'VoxelBasedPointSampler'
]
import mmcv
import numpy as np
import tempfile
import torch
from os import path as osp
from torch.utils.data import Dataset

from mmdet.datasets import DATASETS
from .pipelines import Compose

@DATASETS.register_module()
class Custom3DSegDataset(Dataset):
    """Customized 3D dataset for semantic segmentation task.

    This is the base dataset of ScanNet and S3DIS dataset.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        classes (tuple[str], optional): Classes used in the dataset.
            Defaults to None.
        palette (list[list[int]], optional): The palette of segmentation map.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
        ignore_index (int, optional): The label index to be ignored, e.g. \
            unannotated points. If None is given, set to len(self.CLASSES) to
            be consistent with PointSegClassMapping function in pipeline.
            Defaults to None.
        scene_idxs (np.ndarray | str, optional): Precomputed index to load
            data. For scenes with many points, we may sample it several times.
            Defaults to None.
        label_weight (np.ndarray | str, optional): Precomputed weight to \
            balance loss calculation. If None is given, use equal weighting.
            Defaults to None.
    """
    # names of all classes used for the task
    CLASSES = None

    # class ids used for training
    VALID_CLASS_IDS = None

    # all possible class ids in the loaded segmentation mask
    ALL_CLASS_IDS = None

    # official colors for visualization
    PALETTE = None
    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 palette=None,
                 modality=None,
                 test_mode=False,
                 ignore_index=None,
                 scene_idxs=None,
                 label_weight=None):
        super().__init__()
        self.data_root = data_root
        self.ann_file = ann_file
        self.test_mode = test_mode
        self.modality = modality
        self.data_infos = self.load_annotations(self.ann_file)

        if pipeline is not None:
            self.pipeline = Compose(pipeline)

        self.ignore_index = len(self.CLASSES) if \
            ignore_index is None else ignore_index

        self.scene_idxs, self.label_weight = \
            self.get_scene_idxs_and_label_weight(scene_idxs, label_weight)
        self.CLASSES, self.PALETTE = \
            self.get_classes_and_palette(classes, palette)

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()

    def load_annotations(self, ann_file):
        """Load annotations from ann_file.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: List of annotations.
        """
        return mmcv.load(ann_file)
    def get_data_info(self, index):
        """Get data info according to the given index.

        Args:
            index (int): Index of the sample data to get.

        Returns:
            dict: Data information that will be passed to the data \
                preprocessing pipelines. It includes the following keys:

                - sample_idx (str): Sample index.
                - pts_filename (str): Filename of point clouds.
                - file_name (str): Filename of point clouds.
                - ann_info (dict): Annotation info.
        """
        info = self.data_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        pts_filename = osp.join(self.data_root, info['pts_path'])

        input_dict = dict(
            pts_filename=pts_filename,
            sample_idx=sample_idx,
            file_name=pts_filename)

        if not self.test_mode:
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos
        return input_dict

    def pre_pipeline(self, results):
        """Initialization before data preparation.

        Args:
            results (dict): Dict before data preprocessing.

                - img_fields (list): Image fields.
                - pts_mask_fields (list): Mask fields of points.
                - pts_seg_fields (list): Mask fields of point segments.
                - mask_fields (list): Fields of masks.
                - seg_fields (list): Segment fields.
        """
        results['img_fields'] = []
        results['pts_mask_fields'] = []
        results['pts_seg_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []

    def prepare_train_data(self, index):
        """Training data preparation.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict: Training data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        return example

    def prepare_test_data(self, index):
        """Prepare data for testing.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict: Testing data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        return example
    def get_classes_and_palette(self, classes=None, palette=None):
        """Get class names of current dataset.

        This function is taken from MMSegmentation.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.
                Defaults to None.
            palette (Sequence[Sequence[int]] | np.ndarray | None):
                The palette of segmentation map. If None is given, a random
                palette will be generated. Defaults to None.
        """
        if classes is None:
            self.custom_classes = False
            # map id in the loaded mask to label used for training
            self.label_map = {
                cls_id: self.ignore_index
                for cls_id in self.ALL_CLASS_IDS
            }
            self.label_map.update(
                {cls_id: i
                 for i, cls_id in enumerate(self.VALID_CLASS_IDS)})
            # map label to category name
            self.label2cat = {
                i: cat_name
                for i, cat_name in enumerate(self.CLASSES)
            }
            return self.CLASSES, self.PALETTE

        self.custom_classes = True
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if self.CLASSES:
            if not set(class_names).issubset(self.CLASSES):
                raise ValueError('classes is not a subset of CLASSES.')

            # update valid_class_ids
            self.VALID_CLASS_IDS = [
                self.VALID_CLASS_IDS[self.CLASSES.index(cls_name)]
                for cls_name in class_names
            ]

            # dictionary, its keys are the old label ids and its values
            # are the new label ids.
            # used for changing pixel labels in load_annotations.
            self.label_map = {
                cls_id: self.ignore_index
                for cls_id in self.ALL_CLASS_IDS
            }
            self.label_map.update(
                {cls_id: i
                 for i, cls_id in enumerate(self.VALID_CLASS_IDS)})
            self.label2cat = {
                i: cat_name
                for i, cat_name in enumerate(class_names)
            }

            # modify palette for visualization
            palette = [
                self.PALETTE[self.CLASSES.index(cls_name)]
                for cls_name in class_names
            ]

            # also need to modify self.label_weight
            self.label_weight = np.array([
                self.label_weight[self.CLASSES.index(cls_name)]
                for cls_name in class_names
            ]).astype(np.float32)

        return class_names, palette
    def get_scene_idxs_and_label_weight(self, scene_idxs, label_weight):
        """Compute scene_idxs for data sampling and label weight for loss \
            calculation.

        We sample more times for scenes with more points. Label_weight is
        inversely proportional to the number of class points.
        """
        if self.test_mode:
            # when testing, we load one whole scene every time
            # and we don't need label weight for loss calculation
            return np.arange(len(self.data_infos)).astype(np.int32), \
                np.ones(len(self.CLASSES)).astype(np.float32)

        if scene_idxs is None:
            scene_idxs = np.arange(len(self.data_infos))
        if isinstance(scene_idxs, str):
            scene_idxs = np.load(scene_idxs)
        else:
            scene_idxs = np.array(scene_idxs)

        if label_weight is None:
            # we don't use label weighting in training
            label_weight = np.ones(len(self.CLASSES))
        elif isinstance(label_weight, str):
            label_weight = np.load(label_weight)
        else:
            label_weight = np.array(label_weight)

        return scene_idxs.astype(np.int32), label_weight.astype(np.float32)
    def format_results(self,
                       outputs,
                       pklfile_prefix=None,
                       submission_prefix=None):
        """Format the results to pkl file.

        Args:
            outputs (list[dict]): Testing results of the dataset.
            pklfile_prefix (str | None): The prefix of pkl files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.

        Returns:
            tuple: (outputs, tmp_dir), outputs is the detection results, \
                tmp_dir is the temporary directory created for saving json \
                files when ``jsonfile_prefix`` is not specified.
        """
        if pklfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            pklfile_prefix = osp.join(tmp_dir.name, 'results')
        out = f'{pklfile_prefix}.pkl'
        mmcv.dump(outputs, out)
        return outputs, tmp_dir

    def convert_to_label(self, mask):
        """Convert class_id in segmentation mask to label."""
        # TODO: currently only support loading from local
        # TODO: may need to consider ceph data storage in the future
        if isinstance(mask, str):
            if mask.endswith('npy'):
                mask = np.load(mask)
            else:
                mask = np.fromfile(mask, dtype=np.long)
        mask_copy = mask.copy()
        for class_id, label in self.label_map.items():
            mask_copy[mask == class_id] = label
        return mask_copy
    def evaluate(self,
                 results,
                 metric=None,
                 logger=None,
                 show=False,
                 out_dir=None):
        """Evaluate.

        Evaluation in semantic segmentation protocol.

        Args:
            results (list[dict]): List of results.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Defaults to None.
            show (bool, optional): Whether to visualize.
                Defaults to False.
            out_dir (str, optional): Path to save the visualization results.
                Defaults to None.

        Returns:
            dict: Evaluation results.
        """
        from mmdet3d.core.evaluation import seg_eval
        assert isinstance(
            results, list), f'Expect results to be list, got {type(results)}.'
        assert len(results) > 0, 'Expect length of results > 0.'
        assert len(results) == len(self.data_infos)
        assert isinstance(
            results[0], dict
        ), f'Expect elements in results to be dict, got {type(results[0])}.'

        pred_sem_masks = [result['semantic_mask'] for result in results]
        gt_sem_masks = [
            torch.from_numpy(
                self.convert_to_label(
                    osp.join(self.data_root,
                             data_info['pts_semantic_mask_path'])))
            for data_info in self.data_infos
        ]
        ret_dict = seg_eval(
            gt_sem_masks,
            pred_sem_masks,
            self.label2cat,
            self.ignore_index,
            logger=logger)

        if show:
            self.show(pred_sem_masks, out_dir)

        return ret_dict
    def _rand_another(self, idx):
        """Randomly get another item with the same flag.

        Returns:
            int: Another index of item with the same flag.
        """
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __len__(self):
        """Return the length of scene_idxs.

        Returns:
            int: Length of data infos.
        """
        return len(self.scene_idxs)

    def __getitem__(self, idx):
        """Get item from infos according to the given index.

        Returns:
            dict: Data dictionary of the corresponding index.
        """
        scene_idx = self.scene_idxs[idx]  # map to scene idx
        if self.test_mode:
            return self.prepare_test_data(scene_idx)
        while True:
            data = self.prepare_train_data(scene_idx)
            if data is None:
                idx = self._rand_another(idx)
                scene_idx = self.scene_idxs[idx]  # map to scene idx
                continue
            return data

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0. In 3D datasets, they are all the same, thus are all
        zeros.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
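A toy walk-through of the label_map built in `get_classes_and_palette` (the ids below are made up, not ScanNet's): every raw id defaults to `ignore_index`, then the valid ids are remapped to train labels `0..K-1`.

```python
ALL_CLASS_IDS = (0, 1, 2, 3, 4, 5)
VALID_CLASS_IDS = (2, 4, 5)
ignore_index = len(VALID_CLASS_IDS)  # 3, mirroring len(class_names) in the config

# ids absent from VALID_CLASS_IDS map to ignore_index; valid ids map to 0..K-1
label_map = {cls_id: ignore_index for cls_id in ALL_CLASS_IDS}
label_map.update({cls_id: i for i, cls_id in enumerate(VALID_CLASS_IDS)})
print(label_map)  # {0: 3, 1: 3, 2: 0, 3: 3, 4: 1, 5: 2}
```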
@@ -6,9 +6,10 @@ from .loading import (LoadAnnotations3D, LoadMultiViewImageFromFiles,
                      NormalizePointsColor, PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (BackgroundPointsFilter, GlobalRotScaleTrans,
                            IndoorPatchPointSample, IndoorPointSample,
                            ObjectNoise, ObjectRangeFilter, ObjectSample,
                            PointShuffle, PointsRangeFilter, RandomFlip3D,
                            VoxelBasedPointSampler)

__all__ = [
    'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
@@ -17,5 +18,6 @@ __all__ = [
    'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
    'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
    'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
    'BackgroundPointsFilter', 'VoxelBasedPointSampler',
    'IndoorPatchPointSample'
]
@@ -167,8 +167,8 @@ class Collect3D(object):
    def __repr__(self):
        """str: Return a string that describes the module."""
        return self.__class__.__name__ + \
            f'(keys={self.keys}, meta_keys={self.meta_keys})'

@PIPELINES.register_module()
@@ -256,7 +256,6 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(class_names={self.class_names}, '
        repr_str += f'with_gt={self.with_gt}, with_label={self.with_label})'
        return repr_str
@@ -61,8 +61,8 @@ class LoadMultiViewImageFromFiles(object):
    def __repr__(self):
        """str: Return a string that describes the module."""
        return f'{self.__class__.__name__} (to_float32={self.to_float32}, '\
            f"color_type='{self.color_type}')"

@PIPELINES.register_module()
@@ -246,7 +246,7 @@ class PointSegClassMapping(object):
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(valid_cat_ids={self.valid_cat_ids})'
        return repr_str
@@ -274,16 +274,20 @@ class NormalizePointsColor(object):
            - points (np.ndarray): Points after color normalization.
        """
        points = results['points']
        assert points.attribute_dims is not None and \
            'color' in points.attribute_dims.keys(), \
            'Expect points have color attribute'
        if self.color_mean is not None:
            points.color = points.color - \
                points.color.new_tensor(self.color_mean)
        points.color = points.color / 255.0
        results['points'] = points
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(color_mean={self.color_mean})'
        return repr_str
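A hedged sketch of the updated transform operating on the color attribute rather than raw array slices; the synthetic points and the fixed color value 128 are made up for illustration:

```python
import torch
from mmdet3d.core.points import DepthPoints
from mmdet3d.datasets.pipelines import NormalizePointsColor

pts = DepthPoints(
    torch.cat([torch.rand(4, 3), torch.full((4, 3), 128.)], dim=1),
    points_dim=6, attribute_dims=dict(color=[3, 4, 5]))
# color_mean=None skips mean subtraction; colors are only scaled to [0, 1]
out = NormalizePointsColor(color_mean=None)(dict(points=pts))
print(out['points'].color)  # all entries 128 / 255.0
```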
@@ -294,17 +298,18 @@ class LoadPointsFromFile(object):
    Load sunrgbd and scannet points from file.

    Args:
        coord_type (str): The type of coordinates of points cloud.
            Available options include:

            - 'LIDAR': Points in LiDAR coordinates.
            - 'DEPTH': Points in depth coordinates, usually for indoor datasets.
            - 'CAMERA': Points in camera coordinates.
        load_dim (int): The dimension of the loaded points.
            Defaults to 6.
        use_dim (list[int]): Which dimensions of the points to be used.
            Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
            or use_dim=[0, 1, 2, 3] to use the intensity dimension.
        shift_height (bool): Whether to use shifted height. Defaults to False.
        use_color (bool): Whether to use color features. Defaults to False.
        file_client_args (dict): Config dict of file clients, refer to
            https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
            for more details. Defaults to dict(backend='disk').
@@ -315,8 +320,10 @@
                 load_dim=6,
                 use_dim=[0, 1, 2],
                 shift_height=False,
                 use_color=False,
                 file_client_args=dict(backend='disk')):
        self.shift_height = shift_height
        self.use_color = use_color
        if isinstance(use_dim, int):
            use_dim = list(range(use_dim))
        assert max(use_dim) < load_dim, \
@@ -373,9 +380,22 @@
        if self.shift_height:
            floor_height = np.percentile(points[:, 2], 0.99)
            height = points[:, 2] - floor_height
            points = np.concatenate(
                [points[:, :3],
                 np.expand_dims(height, 1), points[:, 3:]], 1)
            attribute_dims = dict(height=3)

        if self.use_color:
            assert len(self.use_dim) >= 6
            if attribute_dims is None:
                attribute_dims = dict()
            attribute_dims.update(
                dict(color=[
                    points.shape[1] - 3,
                    points.shape[1] - 2,
                    points.shape[1] - 1,
                ]))

        points_class = get_points_type(self.coord_type)
        points = points_class(
            points, points_dim=points.shape[-1], attribute_dims=attribute_dims)
@@ -386,10 +406,11 @@
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__ + '('
        repr_str += f'shift_height={self.shift_height}, '
        repr_str += f'use_color={self.use_color}, '
        repr_str += f'file_client_args={self.file_client_args}, '
        repr_str += f'load_dim={self.load_dim}, '
        repr_str += f'use_dim={self.use_dim})'
        return repr_str
......
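To make the resulting channel layout concrete, here is an assumed numpy mirror of the branch above for `load_dim=6`, `use_dim=[0, 1, 2, 3, 4, 5]`, `shift_height=True` and `use_color=True`; height is spliced in after xyz, so color lands in the last three channels:

```python
import numpy as np

points = np.random.rand(100, 6)                     # x y z r g b
floor_height = np.percentile(points[:, 2], 0.99)
height = points[:, 2] - floor_height
# splice the height channel in between xyz and rgb
points = np.concatenate(
    [points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1)
attribute_dims = dict(height=3)
attribute_dims.update(
    dict(color=[points.shape[1] - 3, points.shape[1] - 2,
                points.shape[1] - 1]))
print(points.shape, attribute_dims)  # (100, 7) {'height': 3, 'color': [4, 5, 6]}
```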
@@ -110,9 +110,8 @@ class RandomFlip3D(RandomFlip):
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(sync_2d={self.sync_2d},'
        repr_str += f'flip_ratio_bev_vertical={self.flip_ratio_bev_vertical})'
        return repr_str
@@ -278,10 +277,10 @@ class ObjectNoise(object):
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(num_try={self.num_try},'
        repr_str += f' translation_std={self.translation_std},'
        repr_str += f' global_rot_range={self.global_rot_range},'
        repr_str += f' rot_range={self.rot_range})'
        return repr_str
@@ -427,10 +426,10 @@ class GlobalRotScaleTrans(object):
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(rot_range={self.rot_range},'
        repr_str += f' scale_ratio_range={self.scale_ratio_range},'
        repr_str += f' translation_std={self.translation_std},'
        repr_str += f' shift_height={self.shift_height})'
        return repr_str
@@ -497,7 +496,7 @@ class ObjectRangeFilter(object):
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
        return repr_str
@@ -531,7 +530,7 @@ class PointsRangeFilter(object):
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
        return repr_str
@@ -646,7 +645,208 @@ class IndoorPointSample(object):
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(num_points={self.num_points})'
        return repr_str
@PIPELINES.register_module()
class IndoorPatchPointSample(object):
r"""Indoor point sample within a patch. Modified from `PointNet++ <https://
github.com/charlesq34/pointnet2/blob/master/scannet/scannet_dataset.py>`_.
Sampling data to a certain number for semantic segmentation.
Args:
num_points (int): Number of points to be sampled.
block_size (float, optional): Size of a block to sample points from.
Defaults to 1.5.
sample_rate (float, optional): Stride used in sliding patch generation.
Defaults to 1.0.
ignore_index (int, optional): Label index that won't be used for the
segmentation task. This is set in PointSegClassMapping as neg_cls.
Defaults to None.
use_normalized_coord (bool, optional): Whether to use normalized xyz as
additional features. Defaults to False.
num_try (int, optional): Number of times to try if the patch selected
is invalid. Defaults to 10.
"""
def __init__(self,
num_points,
block_size=1.5,
sample_rate=1.0,
ignore_index=None,
use_normalized_coord=False,
num_try=10):
self.num_points = num_points
self.block_size = block_size
self.sample_rate = sample_rate
self.ignore_index = ignore_index
self.use_normalized_coord = use_normalized_coord
self.num_try = num_try
def _input_generation(self, coords, patch_center, coord_max, attributes,
attribute_dims, point_type):
"""Generating model input.
Generate input by subtracting patch center and adding additional \
features. Currently support colors and normalized xyz as features.
Args:
coords (np.ndarray): Sampled 3D Points.
patch_center (np.ndarray): Center coordinate of the selected patch.
coord_max (np.ndarray): Max coordinate of all 3D Points.
attributes (np.ndarray): features of input points.
attribute_dims (dict): Dictionary to indicate the meaning of extra
dimension.
point_type (type): class of input points.
Returns:
np.ndarray: The generated input data.
"""
# subtract patch center, the z dimension is not centered
centered_coords = coords.copy()
centered_coords[:, 0] -= patch_center[0]
centered_coords[:, 1] -= patch_center[1]
if self.use_normalized_coord:
normalized_coord = coords / coord_max
attributes = np.concatenate([attributes, normalized_coord], axis=1)
if attribute_dims is None:
attribute_dims = dict()
attribute_dims.update(
dict(normalized_coord=[
attributes.shape[1], attributes.shape[1] +
1, attributes.shape[1] + 2
]))
points = np.concatenate([centered_coords, attributes], axis=1)
points = point_type(
points, points_dim=points.shape[1], attribute_dims=attribute_dims)
return points
def _patch_points_sampling(self, points, sem_mask, replace=None):
"""Patch points sampling.
First sample a valid patch.
Then sample points within that patch to a certain number.
Args:
points (BasePoints): 3D Points.
sem_mask (np.ndarray): semantic segmentation mask for input points.
replace (bool): Whether the sample is with or without replacement.
Defaults to None.
Returns:
tuple[np.ndarray] | np.ndarray:
- points (BasePoints): 3D Points.
- choices (np.ndarray): The generated random samples.
"""
coords = points.coord.numpy()
attributes = points.tensor[:, 3:].numpy()
attribute_dims = points.attribute_dims
point_type = type(points)
coord_max = np.amax(coords, axis=0)
coord_min = np.amin(coords, axis=0)
for i in range(self.num_try):
# random sample a point as patch center
cur_center = coords[np.random.choice(coords.shape[0])]
# boundary of a patch
cur_max = cur_center + np.array(
[self.block_size / 2.0, self.block_size / 2.0, 0.0])
cur_min = cur_center - np.array(
[self.block_size / 2.0, self.block_size / 2.0, 0.0])
cur_max[2] = coord_max[2]
cur_min[2] = coord_min[2]
cur_choice = np.sum(
(coords >= (cur_min - 0.2)) * (coords <= (cur_max + 0.2)),
axis=1) == 3
if not cur_choice.any(): # no points in this patch
continue
cur_coords = coords[cur_choice, :]
cur_sem_mask = sem_mask[cur_choice]
# two criterion for patch sampling, adopted from PointNet++
# points within selected patch shoule be scattered separately
mask = np.sum(
(cur_coords >= (cur_min - 0.01)) * (cur_coords <=
(cur_max + 0.01)),
axis=1) == 3
# not sure if 31, 31, 62 are just some big values used to transform
# coords from 3d array to 1d and then check their uniqueness
# this is used in all the ScanNet code following PointNet++
vidx = np.ceil((cur_coords[mask, :] - cur_min) /
(cur_max - cur_min) * np.array([31.0, 31.0, 62.0]))
vidx = np.unique(vidx[:, 0] * 31.0 * 62.0 + vidx[:, 1] * 62.0 +
vidx[:, 2])
flag1 = len(vidx) / 31.0 / 31.0 / 62.0 >= 0.02
# selected patch should contain enough annotated points
if self.ignore_index is None:
flag2 = True
else:
flag2 = np.sum(cur_sem_mask != self.ignore_index) / \
len(cur_sem_mask) >= 0.7
if flag1 and flag2:
break
# random sample idx
if replace is None:
replace = (cur_sem_mask.shape[0] < self.num_points)
choices = np.random.choice(
np.where(cur_choice)[0], self.num_points, replace=replace)
# construct model input
points = self._input_generation(coords[choices], cur_center, coord_max,
attributes[choices], attribute_dims,
point_type)
return points, choices
def __call__(self, results):
"""Call function to sample points to in indoor scenes.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = results['points']
assert 'pts_semantic_mask' in results.keys(), \
'semantic mask should be provided in training and evaluation'
pts_semantic_mask = results['pts_semantic_mask']
points, choices = self._patch_points_sampling(points,
pts_semantic_mask)
results['points'] = points
results['pts_semantic_mask'] = pts_semantic_mask[choices]
pts_instance_mask = results.get('pts_instance_mask', None)
if pts_instance_mask is not None:
results['pts_instance_mask'] = pts_instance_mask[choices]
return results
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(num_points={self.num_points},'
repr_str += f' block_size={self.block_size},'
repr_str += f' sample_rate={self.sample_rate},'
repr_str += f' ignore_index={self.ignore_index},'
repr_str += f' use_normalized_coord={self.use_normalized_coord},'
repr_str += f' num_try={self.num_try})'
        return repr_str
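A minimal standalone sketch of driving IndoorPatchPointSample outside a full pipeline (the DepthPoints construction, the synthetic data and the import path are illustrative assumptions, not part of this diff):

import numpy as np
from mmdet3d.core.points import DepthPoints
from mmdet3d.datasets.pipelines import IndoorPatchPointSample

# 100 synthetic points with xyz coordinates and rgb attributes
xyz = (np.random.rand(100, 3) * 3.0).astype(np.float32)
rgb = np.random.rand(100, 3).astype(np.float32)
points = DepthPoints(
    np.concatenate([xyz, rgb], axis=1),
    points_dim=6,
    attribute_dims=dict(color=[3, 4, 5]))
results = dict(
    points=points,
    pts_semantic_mask=np.random.randint(0, 20, size=100))
transform = IndoorPatchPointSample(
    num_points=8192,
    block_size=1.5,
    ignore_index=20,
    use_normalized_coord=True)
results = transform(results)
# results['points'] now holds 8192 patch points (xyz, rgb, normalized xyz)
# and results['pts_semantic_mask'] the matching 8192 labels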
...@@ -709,8 +909,7 @@ class BackgroundPointsFilter(object):
    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(bbox_enlarge_range={self.bbox_enlarge_range.tolist()})'
        return repr_str
...
import numpy as np
from os import path as osp
from mmdet3d.core import show_result, show_seg_result
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet.datasets import DATASETS
from .custom_3d import Custom3DDataset
from .custom_3d_seg import Custom3DSegDataset
@DATASETS.register_module()
class ScanNetDataset(Custom3DDataset):
    r"""ScanNet Dataset for Detection Task.
    This class serves as the API for experiments on the ScanNet Dataset.
...@@ -126,3 +127,153 @@ class ScanNetDataset(Custom3DDataset):
        pred_bboxes = result['boxes_3d'].tensor.numpy()
        show_result(points, gt_bboxes, pred_bboxes, out_dir, file_name,
                    show)
@DATASETS.register_module()
class ScanNetSegDataset(Custom3DSegDataset):
r"""ScanNet Dataset for Semantic Segmentation Task.
This class serves as the API for experiments on the ScanNet Dataset.
Please refer to the `github repo <https://github.com/ScanNet/ScanNet>`_
for data downloading.
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
classes (tuple[str], optional): Classes used in the dataset.
Defaults to None.
palette (list[list[int]], optional): The palette of segmentation map.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
as input. Defaults to None.
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
ignore_index (int, optional): The label index to be ignored, e.g. \
unannotated points. If None is given, set to len(self.CLASSES).
Defaults to None.
scene_idxs (np.ndarray | str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
label_weight (np.ndarray | str, optional): Precomputed weight to \
balance loss calculation. If None is given, compute from data.
Defaults to None.
"""
CLASSES = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table',
'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
'curtain', 'refrigerator', 'showercurtrain', 'toilet', 'sink',
'bathtub', 'otherfurniture')
VALID_CLASS_IDS = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
33, 34, 36, 39)
ALL_CLASS_IDS = tuple(range(41))
PALETTE = [
[174, 199, 232],
[152, 223, 138],
[31, 119, 180],
[255, 187, 120],
[188, 189, 34],
[140, 86, 75],
[255, 152, 150],
[214, 39, 40],
[197, 176, 213],
[148, 103, 189],
[196, 156, 148],
[23, 190, 207],
[247, 182, 210],
[219, 219, 141],
[255, 127, 14],
[158, 218, 229],
[44, 160, 44],
[112, 128, 144],
[227, 119, 194],
[82, 84, 163],
]
def __init__(self,
data_root,
ann_file,
pipeline=None,
classes=None,
palette=None,
modality=None,
test_mode=False,
ignore_index=None,
scene_idxs=None,
label_weight=None):
super().__init__(
data_root=data_root,
ann_file=ann_file,
pipeline=pipeline,
classes=classes,
palette=palette,
modality=modality,
test_mode=test_mode,
ignore_index=ignore_index,
scene_idxs=scene_idxs,
label_weight=label_weight)
def get_ann_info(self, index):
"""Get annotation info according to the given index.
Args:
index (int): Index of the annotation data to get.
Returns:
dict: annotation information consists of the following keys:
- pts_semantic_mask_path (str): Path of semantic masks.
"""
        # use index to get the annos so that the eval hook
        # can also use this API
info = self.data_infos[index]
pts_semantic_mask_path = osp.join(self.data_root,
info['pts_semantic_mask_path'])
anns_results = dict(pts_semantic_mask_path=pts_semantic_mask_path)
return anns_results
def show(self, results, out_dir, show=True):
"""Results visualization.
Args:
            results (list[dict]): List of semantic segmentation results.
out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
for i, result in enumerate(results):
data_info = self.data_infos[i]
pts_path = data_info['pts_path']
file_name = osp.split(pts_path)[-1].split('.')[0]
points = np.fromfile(
osp.join(self.data_root, pts_path),
dtype=np.float32).reshape(-1, 6)
sem_mask_path = data_info['pts_semantic_mask_path']
gt_sem_mask = self.convert_to_label(
osp.join(self.data_root, sem_mask_path))
pred_sem_mask = result['semantic_mask'].numpy()
show_seg_result(points, gt_sem_mask,
pred_sem_mask, out_dir, file_name,
np.array(self.PALETTE), self.ignore_index, show)
def get_scene_idxs_and_label_weight(self, scene_idxs, label_weight):
"""Compute scene_idxs for data sampling and label weight for loss \
calculation.
        We sample more times for scenes with more points. label_weight is
        inversely proportional to the number of points of each class.
        """
"""
# when testing, we load one whole scene every time
# and we don't need label weight for loss calculation
if not self.test_mode and scene_idxs is None:
raise NotImplementedError(
'please provide re-sampled scene indexes for training')
return super().get_scene_idxs_and_label_weight(scene_idxs,
label_weight)
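The label weight this method returns is, in spirit, an inverse-class-frequency weight. A rough sketch of one common damped variant (a hypothetical helper; the exact formula lives in the pre-processing scripts and may differ):

import numpy as np

def compute_label_weight(sem_masks, num_classes, ignore_index):
    # count per-class points over all training scenes
    counts = np.zeros(num_classes, dtype=np.float64)
    for mask in sem_masks:
        valid = mask[mask != ignore_index]
        counts += np.bincount(valid, minlength=num_classes)
    freq = counts / counts.sum()
    # log damping keeps the weights of rare classes from exploding
    return (1.0 / np.log(1.2 + freq)).astype(np.float32)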
...@@ -159,8 +159,8 @@ def test_show():
    results = [result]
    kitti_dataset.show(results, temp_dir, show=False)
    pts_file_path = osp.join(temp_dir, '000000', '000000_points.obj')
    gt_file_path = osp.join(temp_dir, '000000', '000000_gt.obj')
    pred_file_path = osp.join(temp_dir, '000000', '000000_pred.obj')
    mmcv.check_file_exist(pts_file_path)
    mmcv.check_file_exist(gt_file_path)
    mmcv.check_file_exist(pred_file_path)
...
import copy
import numpy as np
import pytest
import torch
from mmdet3d.datasets import ScanNetDataset, ScanNetSegDataset
def test_getitem():
...@@ -204,9 +205,321 @@ def test_show():
    scannet_dataset.show(results, temp_dir, show=False)
    pts_file_path = osp.join(temp_dir, 'scene0000_00',
                             'scene0000_00_points.obj')
    gt_file_path = osp.join(temp_dir, 'scene0000_00', 'scene0000_00_gt.obj')
    pred_file_path = osp.join(temp_dir, 'scene0000_00',
                              'scene0000_00_pred.obj')
    mmcv.check_file_exist(pts_file_path)
    mmcv.check_file_exist(gt_file_path)
    mmcv.check_file_exist(pred_file_path)
def test_seg_getitem():
np.random.seed(0)
root_path = './tests/data/scannet/'
ann_file = './tests/data/scannet/scannet_infos.pkl'
class_names = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table',
'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
'curtain', 'refrigerator', 'showercurtrain', 'toilet',
'sink', 'bathtub', 'otherfurniture')
palette = [
[174, 199, 232],
[152, 223, 138],
[31, 119, 180],
[255, 187, 120],
[188, 189, 34],
[140, 86, 75],
[255, 152, 150],
[214, 39, 40],
[197, 176, 213],
[148, 103, 189],
[196, 156, 148],
[23, 190, 207],
[247, 182, 210],
[219, 219, 141],
[255, 127, 14],
[158, 218, 229],
[44, 160, 44],
[112, 128, 144],
[227, 119, 194],
[82, 84, 163],
]
scene_idxs = [0 for _ in range(20)]
label_weight = [
2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086, 4.907941,
4.690836, 4.512031, 4.623311, 4.9242644, 5.358117, 5.360071, 5.019636,
4.967126, 5.3502126, 5.4023647, 5.4027233, 5.4169416, 5.3954206,
4.6971426
]
# test network inputs are (xyz, rgb, normalized_xyz)
pipelines = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)),
dict(
type='IndoorPatchPointSample',
num_points=5,
block_size=1.5,
sample_rate=1.0,
ignore_index=len(class_names),
use_normalized_coord=True),
dict(type='NormalizePointsColor', color_mean=None),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
keys=['points', 'pts_semantic_mask'],
meta_keys=['file_name', 'sample_idx'])
]
scannet_dataset = ScanNetSegDataset(
data_root=root_path,
ann_file=ann_file,
pipeline=pipelines,
classes=None,
palette=None,
modality=None,
test_mode=False,
ignore_index=None,
scene_idxs=scene_idxs,
label_weight=label_weight)
data = scannet_dataset[0]
points = data['points']._data
pts_semantic_mask = data['pts_semantic_mask']._data
file_name = data['img_metas']._data['file_name']
sample_idx = data['img_metas']._data['sample_idx']
assert file_name == './tests/data/scannet/points/scene0000_00.bin'
assert sample_idx == 'scene0000_00'
expected_points = torch.tensor([[
0.0000, 0.0000, 1.2427, 0.6118, 0.5529, 0.4471, -0.6462, -1.0046,
0.4280
],
[
0.1553, -0.0074, 1.6077, 0.5882,
0.6157, 0.5569, -0.6001, -1.0068,
0.5537
],
[
0.1518, 0.6016, 0.6548, 0.1490, 0.1059,
0.0431, -0.6012, -0.8309, 0.2255
],
[
-0.7494, 0.1033, 0.6756, 0.5216,
0.4353, 0.3333, -0.8687, -0.9748,
0.2327
],
[
-0.6836, -0.0203, 0.5884, 0.5765,
0.5020, 0.4510, -0.8491, -1.0105,
0.2027
]])
expected_pts_semantic_mask = np.array([13, 13, 12, 2, 0])
original_classes = scannet_dataset.CLASSES
original_palette = scannet_dataset.PALETTE
assert scannet_dataset.CLASSES == class_names
assert scannet_dataset.ignore_index == 20
assert torch.allclose(points, expected_points, 1e-2)
assert np.all(pts_semantic_mask.numpy() == expected_pts_semantic_mask)
assert original_classes == class_names
assert original_palette == palette
assert scannet_dataset.scene_idxs.dtype == np.int32
assert np.all(scannet_dataset.scene_idxs == np.array(scene_idxs))
assert np.allclose(scannet_dataset.label_weight, np.array(label_weight),
1e-5)
# test network inputs are (xyz, rgb)
np.random.seed(0)
new_pipelines = copy.deepcopy(pipelines)
new_pipelines[3] = dict(
type='IndoorPatchPointSample',
num_points=5,
block_size=1.5,
sample_rate=1.0,
ignore_index=len(class_names),
use_normalized_coord=False)
scannet_dataset = ScanNetSegDataset(
data_root=root_path,
ann_file=ann_file,
pipeline=new_pipelines,
scene_idxs=scene_idxs)
data = scannet_dataset[0]
points = data['points']._data
assert torch.allclose(points, expected_points[:, :6], 1e-2)
# test network inputs are (xyz, normalized_xyz)
np.random.seed(0)
new_pipelines = copy.deepcopy(pipelines)
new_pipelines[0] = dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=False,
load_dim=6,
use_dim=[0, 1, 2])
new_pipelines.remove(new_pipelines[4])
scannet_dataset = ScanNetSegDataset(
data_root=root_path,
ann_file=ann_file,
pipeline=new_pipelines,
scene_idxs=scene_idxs)
data = scannet_dataset[0]
points = data['points']._data
assert torch.allclose(points, expected_points[:, [0, 1, 2, 6, 7, 8]], 1e-2)
# test network inputs are (xyz,)
np.random.seed(0)
new_pipelines = copy.deepcopy(pipelines)
new_pipelines[0] = dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=False,
load_dim=6,
use_dim=[0, 1, 2])
new_pipelines[3] = dict(
type='IndoorPatchPointSample',
num_points=5,
block_size=1.5,
sample_rate=1.0,
ignore_index=len(class_names),
use_normalized_coord=False)
new_pipelines.remove(new_pipelines[4])
scannet_dataset = ScanNetSegDataset(
data_root=root_path,
ann_file=ann_file,
pipeline=new_pipelines,
scene_idxs=scene_idxs)
data = scannet_dataset[0]
points = data['points']._data
assert torch.allclose(points, expected_points[:, :3], 1e-2)
# test dataset with selected classes
scannet_dataset = ScanNetSegDataset(
data_root=root_path,
ann_file=ann_file,
pipeline=None,
classes=['cabinet', 'chair'],
scene_idxs=scene_idxs)
label_map = {i: 20 for i in range(41)}
label_map.update({3: 0, 5: 1})
assert scannet_dataset.CLASSES != original_classes
assert scannet_dataset.CLASSES == ['cabinet', 'chair']
assert scannet_dataset.PALETTE == [palette[2], palette[4]]
assert scannet_dataset.VALID_CLASS_IDS == [3, 5]
assert scannet_dataset.label_map == label_map
assert scannet_dataset.label2cat == {0: 'cabinet', 1: 'chair'}
assert np.all(scannet_dataset.label_weight == np.ones(2))
# test load classes from file
import tempfile
tmp_file = tempfile.NamedTemporaryFile()
with open(tmp_file.name, 'w') as f:
f.write('cabinet\nchair\n')
scannet_dataset = ScanNetSegDataset(
data_root=root_path,
ann_file=ann_file,
pipeline=None,
classes=tmp_file.name,
scene_idxs=scene_idxs)
assert scannet_dataset.CLASSES != original_classes
assert scannet_dataset.CLASSES == ['cabinet', 'chair']
assert scannet_dataset.PALETTE == [palette[2], palette[4]]
assert scannet_dataset.VALID_CLASS_IDS == [3, 5]
assert scannet_dataset.label_map == label_map
assert scannet_dataset.label2cat == {0: 'cabinet', 1: 'chair'}
# test scene_idxs in dataset
# we should input scene_idxs in train mode
with pytest.raises(NotImplementedError):
scannet_dataset = ScanNetSegDataset(
data_root=root_path,
ann_file=ann_file,
pipeline=None,
scene_idxs=None)
# test mode
scannet_dataset = ScanNetSegDataset(
data_root=root_path,
ann_file=ann_file,
pipeline=None,
test_mode=True,
scene_idxs=scene_idxs)
assert np.all(scannet_dataset.scene_idxs == np.array([0]))
assert np.all(scannet_dataset.label_weight == np.ones(len(class_names)))
def test_seg_evaluate():
if not torch.cuda.is_available():
pytest.skip()
root_path = './tests/data/scannet'
ann_file = './tests/data/scannet/scannet_infos.pkl'
scannet_dataset = ScanNetSegDataset(
data_root=root_path, ann_file=ann_file, test_mode=True)
results = []
pred_sem_mask = dict(
semantic_mask=torch.tensor([
13, 5, 1, 2, 6, 2, 13, 1, 14, 2, 0, 0, 5, 5, 3, 0, 1, 14, 0, 0, 0,
18, 6, 15, 13, 0, 2, 4, 0, 3, 16, 6, 13, 5, 13, 0, 0, 0, 0, 1, 7,
3, 19, 12, 8, 0, 11, 0, 0, 1, 2, 13, 17, 1, 1, 1, 6, 2, 13, 19, 4,
17, 0, 14, 1, 7, 2, 1, 7, 2, 0, 5, 17, 5, 0, 0, 3, 6, 5, 11, 1, 13,
13, 2, 3, 1, 0, 13, 19, 1, 14, 5, 3, 1, 13, 1, 2, 3, 2, 1
]).long())
results.append(pred_sem_mask)
ret_dict = scannet_dataset.evaluate(results)
assert abs(ret_dict['miou'] - 0.5308) < 0.01
assert abs(ret_dict['acc'] - 0.8219) < 0.01
assert abs(ret_dict['acc_cls'] - 0.7649) < 0.01
def test_seg_show():
import mmcv
import tempfile
from os import path as osp
tmp_dir = tempfile.TemporaryDirectory()
temp_dir = tmp_dir.name
root_path = './tests/data/scannet'
ann_file = './tests/data/scannet/scannet_infos.pkl'
scannet_dataset = ScanNetSegDataset(
data_root=root_path, ann_file=ann_file, scene_idxs=[0])
result = dict(
semantic_mask=torch.tensor([
13, 5, 1, 2, 6, 2, 13, 1, 14, 2, 0, 0, 5, 5, 3, 0, 1, 14, 0, 0, 0,
18, 6, 15, 13, 0, 2, 4, 0, 3, 16, 6, 13, 5, 13, 0, 0, 0, 0, 1, 7,
3, 19, 12, 8, 0, 11, 0, 0, 1, 2, 13, 17, 1, 1, 1, 6, 2, 13, 19, 4,
17, 0, 14, 1, 7, 2, 1, 7, 2, 0, 5, 17, 5, 0, 0, 3, 6, 5, 11, 1, 13,
13, 2, 3, 1, 0, 13, 19, 1, 14, 5, 3, 1, 13, 1, 2, 3, 2, 1
]).long())
results = [result]
scannet_dataset.show(results, temp_dir, show=False)
pts_file_path = osp.join(temp_dir, 'scene0000_00',
'scene0000_00_points.obj')
gt_file_path = osp.join(temp_dir, 'scene0000_00', 'scene0000_00_gt.obj')
pred_file_path = osp.join(temp_dir, 'scene0000_00',
'scene0000_00_pred.obj')
mmcv.check_file_exist(pts_file_path)
mmcv.check_file_exist(gt_file_path)
mmcv.check_file_exist(pred_file_path)
tmp_dir.cleanup()
...@@ -147,8 +147,8 @@ def test_show():
    results = [result]
    sunrgbd_dataset.show(results, temp_dir, show=False)
    pts_file_path = osp.join(temp_dir, '000001', '000001_points.obj')
    gt_file_path = osp.join(temp_dir, '000001', '000001_gt.obj')
    pred_file_path = osp.join(temp_dir, '000001', '000001_pred.obj')
    mmcv.check_file_exist(pts_file_path)
    mmcv.check_file_exist(gt_file_path)
    mmcv.check_file_exist(pred_file_path)
...@@ -104,6 +104,77 @@ def test_scannet_pipeline():
    assert np.all(pts_instance_mask.numpy() == expected_pts_instance_mask)
def test_scannet_seg_pipeline():
class_names = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table',
'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
'curtain', 'refrigerator', 'showercurtrain', 'toilet',
'sink', 'bathtub', 'otherfurniture')
np.random.seed(0)
pipelines = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)),
dict(
type='IndoorPatchPointSample',
num_points=5,
block_size=1.5,
sample_rate=1.0,
ignore_index=len(class_names),
use_normalized_coord=True),
dict(type='NormalizePointsColor', color_mean=None),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
]
pipeline = Compose(pipelines)
info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0]
results = dict()
data_path = './tests/data/scannet'
results['pts_filename'] = osp.join(data_path, info['pts_path'])
results['ann_info'] = dict()
results['ann_info']['pts_semantic_mask_path'] = osp.join(
data_path, info['pts_semantic_mask_path'])
results['pts_seg_fields'] = []
results = pipeline(results)
points = results['points']._data
pts_semantic_mask = results['pts_semantic_mask']._data
# build sampled points
scannet_points = np.fromfile(
osp.join(data_path, info['pts_path']), dtype=np.float32).reshape(
(-1, 6))
scannet_choices = np.array([87, 34, 58, 9, 18])
scannet_center = np.array([-2.1772466, -3.4789145, 1.242711])
scannet_center[2] = 0.0
scannet_coord_max = np.amax(scannet_points[:, :3], axis=0)
expected_points = np.concatenate([
scannet_points[scannet_choices, :3] - scannet_center,
scannet_points[scannet_choices, 3:] / 255.,
scannet_points[scannet_choices, :3] / scannet_coord_max
],
axis=1)
expected_pts_semantic_mask = np.array([13, 13, 12, 2, 0])
assert np.allclose(points.numpy(), expected_points, atol=1e-6)
assert np.all(pts_semantic_mask.numpy() == expected_pts_semantic_mask)
def test_sunrgbd_pipeline():
    class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
                   'dresser', 'night_stand', 'bookshelf', 'bathtub')
...