Unverified Commit ef13e5a2 authored by Xiang Xu's avatar Xiang Xu Committed by GitHub
Browse files

[Feature] Add inferencer for lidar-based segmentation (#2304)

* add lidar_seg_inferencer

* fix randomness caused in slide_inference

* Update semantickitti.py

* fix

* add BaseSeg3DInferencer

* refactor

* rename BaseDet3DInferencer to Base3DInferencer

* fix import error

* update doc
parent 06b56888
...@@ -68,6 +68,7 @@ Models: ...@@ -68,6 +68,7 @@ Models:
Weights: https://download.openmmlab.com/mmdetection3d/v0.1.0_models/pointnet2/pointnet2_msg_16x2_cosine_250e_scannet_seg-3d-20class/pointnet2_msg_16x2_cosine_250e_scannet_seg-3d-20class_20210514_144009-24477ab1.pth Weights: https://download.openmmlab.com/mmdetection3d/v0.1.0_models/pointnet2/pointnet2_msg_16x2_cosine_250e_scannet_seg-3d-20class/pointnet2_msg_16x2_cosine_250e_scannet_seg-3d-20class_20210514_144009-24477ab1.pth
- Name: pointnet2_ssg_2xb16-cosine-50e_s3dis-seg - Name: pointnet2_ssg_2xb16-cosine-50e_s3dis-seg
Alias: pointnet2-ssg_s3dis-seg
In Collection: PointNet++ In Collection: PointNet++
Config: configs/pointnet2/pointnet2_ssg_2xb16-cosine-50e_s3dis-seg.py Config: configs/pointnet2/pointnet2_ssg_2xb16-cosine-50e_s3dis-seg.py
Metadata: Metadata:
......
...@@ -3,12 +3,12 @@ from .inference import (convert_SyncBN, inference_detector, ...@@ -3,12 +3,12 @@ from .inference import (convert_SyncBN, inference_detector,
inference_mono_3d_detector, inference_mono_3d_detector,
inference_multi_modality_detector, inference_segmentor, inference_multi_modality_detector, inference_segmentor,
init_model) init_model)
from .inferencers import (BaseDet3DInferencer, LidarDet3DInferencer, from .inferencers import (Base3DInferencer, LidarDet3DInferencer,
MonoDet3DInferencer) LidarSeg3DInferencer, MonoDet3DInferencer)
__all__ = [ __all__ = [
'inference_detector', 'init_model', 'inference_mono_3d_detector', 'inference_detector', 'init_model', 'inference_mono_3d_detector',
'convert_SyncBN', 'inference_multi_modality_detector', 'convert_SyncBN', 'inference_multi_modality_detector',
'inference_segmentor', 'BaseDet3DInferencer', 'MonoDet3DInferencer', 'inference_segmentor', 'Base3DInferencer', 'MonoDet3DInferencer',
'LidarDet3DInferencer' 'LidarDet3DInferencer', 'LidarSeg3DInferencer'
] ]
...@@ -76,16 +76,16 @@ def init_model(config: Union[str, Path, Config], ...@@ -76,16 +76,16 @@ def init_model(config: Union[str, Path, Config],
elif 'CLASSES' in checkpoint.get('meta', {}): elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x # < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES'] classes = checkpoint['meta']['CLASSES']
model.dataset_meta = {'CLASSES': classes} model.dataset_meta = {'classes': classes}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE'] model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
else: else:
# < mmdet3d 1.x # < mmdet3d 1.x
model.dataset_meta = {'CLASSES': config.class_names} model.dataset_meta = {'classes': config.class_names}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE'] model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
model.cfg = config # save the config in the model for convenience model.cfg = config # save the config in the model for convenience
if device != 'cpu': if device != 'cpu':
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .base_det3d_inferencer import BaseDet3DInferencer from .base_3d_inferencer import Base3DInferencer
from .lidar_det3d_inferencer import LidarDet3DInferencer from .lidar_det3d_inferencer import LidarDet3DInferencer
from .lidar_seg3d_inferencer import LidarSeg3DInferencer
from .mono_det3d_inferencer import MonoDet3DInferencer from .mono_det3d_inferencer import MonoDet3DInferencer
__all__ = [ __all__ = [
'BaseDet3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer' 'Base3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer',
'LidarSeg3DInferencer'
] ]
...@@ -23,8 +23,8 @@ ImgType = Union[np.ndarray, Sequence[np.ndarray]] ...@@ -23,8 +23,8 @@ ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
class BaseDet3DInferencer(BaseInferencer): class Base3DInferencer(BaseInferencer):
"""Base 3D object detection inferencer. """Base 3D model inferencer.
Args: Args:
model (str, optional): Path to the config file or the model name model (str, optional): Path to the config file or the model name
...@@ -39,7 +39,7 @@ class BaseDet3DInferencer(BaseInferencer): ...@@ -39,7 +39,7 @@ class BaseDet3DInferencer(BaseInferencer):
from metafile. Defaults to None. from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None. device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to mmdet3d. scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str): Color palette used for visualization. The order of palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'. priority is palette -> config -> checkpoint. Defaults to 'none'.
""" """
...@@ -58,7 +58,7 @@ class BaseDet3DInferencer(BaseInferencer): ...@@ -58,7 +58,7 @@ class BaseDet3DInferencer(BaseInferencer):
model: Union[ModelType, str, None] = None, model: Union[ModelType, str, None] = None,
weights: Optional[str] = None, weights: Optional[str] = None,
device: Optional[str] = None, device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d', scope: str = 'mmdet3d',
palette: str = 'none') -> None: palette: str = 'none') -> None:
self.palette = palette self.palette = palette
init_default_scope(scope) init_default_scope(scope)
...@@ -97,16 +97,16 @@ class BaseDet3DInferencer(BaseInferencer): ...@@ -97,16 +97,16 @@ class BaseDet3DInferencer(BaseInferencer):
elif 'CLASSES' in checkpoint.get('meta', {}): elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x # < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES'] classes = checkpoint['meta']['CLASSES']
model.dataset_meta = {'CLASSES': classes} model.dataset_meta = {'classes': classes}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE'] model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
else: else:
# < mmdet3d 1.x # < mmdet3d 1.x
model.dataset_meta = {'CLASSES': cfg.class_names} model.dataset_meta = {'classes': cfg.class_names}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE'] model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
model.cfg = cfg # save the config in the model for convenience model.cfg = cfg # save the config in the model for convenience
model.to(device) model.to(device)
...@@ -130,8 +130,8 @@ class BaseDet3DInferencer(BaseInferencer): ...@@ -130,8 +130,8 @@ class BaseDet3DInferencer(BaseInferencer):
Args: Args:
inputs (Union[dict, list]): Inputs for the inferencer. inputs (Union[dict, list]): Inputs for the inferencer.
modality_key (Union[str, List[str]], optional): The key of the modality_key (Union[str, List[str]]): The key of the modality.
modality. Defaults to 'points'. Defaults to 'points'.
Returns: Returns:
list: List of input for the :meth:`preprocess`. list: List of input for the :meth:`preprocess`.
...@@ -187,6 +187,7 @@ class BaseDet3DInferencer(BaseInferencer): ...@@ -187,6 +187,7 @@ class BaseDet3DInferencer(BaseInferencer):
pred_out_file: str = '', pred_out_file: str = '',
**kwargs) -> dict: **kwargs) -> dict:
"""Call the inferencer. """Call the inferencer.
Args: Args:
inputs (InputsType): Inputs for the inferencer. inputs (InputsType): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as return_datasamples (bool): Whether to return results as
...@@ -205,7 +206,7 @@ class BaseDet3DInferencer(BaseInferencer): ...@@ -205,7 +206,7 @@ class BaseDet3DInferencer(BaseInferencer):
If left as empty, no file will be saved. Defaults to ''. If left as empty, no file will be saved. Defaults to ''.
print_result (bool): Whether to print the inference result w/o print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False. visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved. visualization. If left as empty, no file will be saved.
Defaults to ''. Defaults to ''.
**kwargs: Other keyword arguments passed to :meth:`preprocess`, **kwargs: Other keyword arguments passed to :meth:`preprocess`,
...@@ -213,6 +214,7 @@ class BaseDet3DInferencer(BaseInferencer): ...@@ -213,6 +214,7 @@ class BaseDet3DInferencer(BaseInferencer):
Each key in kwargs should be in the corresponding set of Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``. and ``postprocess_kwargs``.
Returns: Returns:
dict: Inference and visualization results. dict: Inference and visualization results.
""" """
...@@ -240,23 +242,30 @@ class BaseDet3DInferencer(BaseInferencer): ...@@ -240,23 +242,30 @@ class BaseDet3DInferencer(BaseInferencer):
) -> Union[ResType, Tuple[ResType, np.ndarray]]: ) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Process the predictions and visualization results from ``forward`` """Process the predictions and visualization results from ``forward``
and ``visualize``. and ``visualize``.
This method should be responsible for the following tasks: This method should be responsible for the following tasks:
1. Convert datasamples into a json-serializable dict if needed. 1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them. 2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions. 3. Dump or log the predictions.
Args: Args:
preds (List[Dict]): Predictions of the model. preds (List[Dict]): Predictions of the model.
visualization (Optional[np.ndarray]): Visualized predictions. visualization (np.ndarray, optional): Visualized predictions.
Defaults to None.
return_datasample (bool): Whether to use Datasample to store return_datasample (bool): Whether to use Datasample to store
inference results. If False, dict will be used. inference results. If False, dict will be used.
Defaults to False.
print_result (bool): Whether to print the inference result w/o print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False. visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved. visualization. If left as empty, no file will be saved.
Defaults to ''. Defaults to ''.
Returns: Returns:
dict: Inference and visualization results with key ``predictions`` dict: Inference and visualization results with key ``predictions``
and ``visualization``. and ``visualization``.
- ``visualization`` (Any): Returned by :meth:`visualize`. - ``visualization`` (Any): Returned by :meth:`visualize`.
- ``predictions`` (dict or DataSample): Returned by - ``predictions`` (dict or DataSample): Returned by
:meth:`forward` and processed in :meth:`postprocess`. :meth:`forward` and processed in :meth:`postprocess`.
...@@ -286,11 +295,18 @@ class BaseDet3DInferencer(BaseInferencer): ...@@ -286,11 +295,18 @@ class BaseDet3DInferencer(BaseInferencer):
It's better to contain only basic data elements such as strings and It's better to contain only basic data elements such as strings and
numbers in order to guarantee it's json-serializable. numbers in order to guarantee it's json-serializable.
""" """
pred_instances = data_sample.pred_instances_3d.numpy() result = {}
if 'pred_instances_3d' in data_sample:
pred_instances_3d = data_sample.pred_instances_3d.numpy()
result = { result = {
'bboxes_3d': pred_instances.bboxes_3d.tensor.cpu().tolist(), 'bboxes_3d': pred_instances_3d.bboxes_3d.tensor.cpu().tolist(),
'labels_3d': pred_instances.labels_3d.tolist(), 'labels_3d': pred_instances_3d.labels_3d.tolist(),
'scores_3d': pred_instances.scores_3d.tolist() 'scores_3d': pred_instances_3d.scores_3d.tolist()
} }
if 'pred_pts_seg' in data_sample:
pred_pts_seg = data_sample.pred_pts_seg.numpy()
result['pts_semantic_mask'] = \
pred_pts_seg.pts_semantic_mask.tolist()
return result return result
...@@ -10,7 +10,7 @@ from mmengine.structures import InstanceData ...@@ -10,7 +10,7 @@ from mmengine.structures import InstanceData
from mmdet3d.registry import INFERENCERS from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType from mmdet3d.utils import ConfigType
from .base_det3d_inferencer import BaseDet3DInferencer from .base_3d_inferencer import Base3DInferencer
InstanceList = List[InstanceData] InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray] InputType = Union[str, np.ndarray]
...@@ -22,7 +22,7 @@ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] ...@@ -22,7 +22,7 @@ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
@INFERENCERS.register_module(name='det3d-lidar') @INFERENCERS.register_module(name='det3d-lidar')
@INFERENCERS.register_module() @INFERENCERS.register_module()
class LidarDet3DInferencer(BaseDet3DInferencer): class LidarDet3DInferencer(Base3DInferencer):
"""The inferencer of LiDAR-based detection. """The inferencer of LiDAR-based detection.
Args: Args:
...@@ -38,8 +38,9 @@ class LidarDet3DInferencer(BaseDet3DInferencer): ...@@ -38,8 +38,9 @@ class LidarDet3DInferencer(BaseDet3DInferencer):
from metafile. Defaults to None. from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None. device will be automatically used. Defaults to None.
scope (str, optional): The scope of registry. scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str, optional): The palette of visualization. palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
""" """
preprocess_kwargs: set = set() preprocess_kwargs: set = set()
...@@ -56,14 +57,17 @@ class LidarDet3DInferencer(BaseDet3DInferencer): ...@@ -56,14 +57,17 @@ class LidarDet3DInferencer(BaseDet3DInferencer):
model: Union[ModelType, str, None] = None, model: Union[ModelType, str, None] = None,
weights: Optional[str] = None, weights: Optional[str] = None,
device: Optional[str] = None, device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d', scope: str = 'mmdet3d',
palette: str = 'none') -> None: palette: str = 'none') -> None:
# A global counter tracking the number of frames processed, for # A global counter tracking the number of frames processed, for
# naming of the output results # naming of the output results
self.num_visualized_frames = 0 self.num_visualized_frames = 0
self.palette = palette super(LidarDet3DInferencer, self).__init__(
super().__init__( model=model,
model=model, weights=weights, device=device, scope=scope) weights=weights,
device=device,
scope=scope,
palette=palette)
def _inputs_to_list(self, inputs: Union[dict, list]) -> list: def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
"""Preprocess the inputs to a list. """Preprocess the inputs to a list.
...@@ -129,6 +133,7 @@ class LidarDet3DInferencer(BaseDet3DInferencer): ...@@ -129,6 +133,7 @@ class LidarDet3DInferencer(BaseDet3DInferencer):
Defaults to 0.3. Defaults to 0.3.
img_out_dir (str): Output directory of visualization results. img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''. If left as empty, no file will be saved. Defaults to ''.
Returns: Returns:
List[np.ndarray] or None: Returns visualization results only if List[np.ndarray] or None: Returns visualization results only if
applicable. applicable.
......
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence, Union
import mmengine
import numpy as np
from mmengine.dataset import Compose
from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData
from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType
from .base_3d_inferencer import Base3DInferencer
InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
@INFERENCERS.register_module(name='seg3d-lidar')
@INFERENCERS.register_module()
class LidarSeg3DInferencer(Base3DInferencer):
    """The inferencer of LiDAR-based segmentation.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. For example, it could be
            "pointnet2-ssg_s3dis-seg" or
            "configs/pointnet2/pointnet2_ssg_2xb16-cosine-50e_s3dis-seg.py".
            If model is not specified, user must provide the
            `weights` saved by MMEngine which contains the config string.
            Defaults to None.
        weights (str, optional): Path to the checkpoint. If it is not specified
            and model is a model name of metafile, the weights will be loaded
            from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        scope (str): The scope of the model. Defaults to 'mmdet3d'.
        palette (str): Color palette used for visualization. The order of
            priority is palette -> config -> checkpoint. Defaults to 'none'.
    """

    # Keyword arguments accepted by each stage of ``__call__``; the base
    # inferencer dispatches **kwargs to the stages according to these sets.
    preprocess_kwargs: set = set()
    forward_kwargs: set = set()
    visualize_kwargs: set = {
        'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
        'img_out_dir'
    }
    postprocess_kwargs: set = {
        'print_result', 'pred_out_file', 'return_datasample'
    }

    def __init__(self,
                 model: Union[ModelType, str, None] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: str = 'mmdet3d',
                 palette: str = 'none') -> None:
        # A global counter tracking the number of frames processed, for
        # naming of the output results
        self.num_visualized_frames = 0
        super(LidarSeg3DInferencer, self).__init__(
            model=model,
            weights=weights,
            device=device,
            scope=scope,
            palette=palette)

    def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to its type:

        - list or tuple: return inputs
        - dict: the value with key 'points' is
            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to file, a url or other types of string
              according to the task.

        Args:
            inputs (Union[dict, list]): Inputs for the inferencer.

        Returns:
            list: List of input for the :meth:`preprocess`.
        """
        return super()._inputs_to_list(inputs, modality_key='points')

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline.

        Strips annotation-related transforms (not applicable at inference
        time) and replaces the point-loading transform with one that also
        accepts in-memory arrays.

        Args:
            cfg (ConfigType): Model config containing
                ``test_dataloader.dataset.pipeline``.

        Returns:
            Compose: The composed inference pipeline.

        Raises:
            ValueError: If ``LoadPointsFromFile`` is absent from the test
                pipeline, since its config is needed to decode raw inputs.
        """
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        # Ground-truth loading is not applicable at inference time.
        idx = self._get_transform_idx(pipeline_cfg, 'LoadAnnotations3D')
        if idx != -1:
            del pipeline_cfg[idx]

        # Segmentation class mapping operates on ground-truth labels only.
        idx = self._get_transform_idx(pipeline_cfg, 'PointSegClassMapping')
        if idx != -1:
            del pipeline_cfg[idx]

        load_point_idx = self._get_transform_idx(pipeline_cfg,
                                                 'LoadPointsFromFile')
        if load_point_idx == -1:
            raise ValueError(
                'LoadPointsFromFile is not found in the test pipeline')

        load_cfg = pipeline_cfg[load_point_idx]
        # Remember how points are laid out on disk so that `visualize` can
        # decode raw .bin files consistently with the pipeline.
        self.coord_type, self.load_dim = load_cfg['coord_type'], load_cfg[
            'load_dim']
        self.use_dim = list(range(load_cfg['use_dim'])) if isinstance(
            load_cfg['use_dim'], int) else load_cfg['use_dim']

        # The detection inferencer loader accepts both file paths and
        # ndarray inputs; it is reused for segmentation as well.
        pipeline_cfg[load_point_idx]['type'] = 'LidarDet3DInferencerLoader'
        return Compose(pipeline_cfg)

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  img_out_dir: str = '') -> Union[List[np.ndarray], None]:
        """Visualize predictions.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            preds (PredType): Predictions of the model.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (int): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted results.
                Defaults to True.
            pred_score_thr (float): Minimum score of predictions to draw.
                Defaults to 0.3.
            img_out_dir (str): Output directory of visualization results.
                If left as empty, no file will be saved. Defaults to ''.

        Returns:
            List[np.ndarray] or None: Returns visualization results only if
            applicable.
        """
        # Nothing to do when no visualizer is configured or no output
        # (popup window, saved image, or returned array) was requested.
        # NOTE(review): the original code additionally raised ValueError for
        # a missing visualizer, but that branch was unreachable after this
        # early return, so it has been removed.
        if self.visualizer is None or (not show and img_out_dir == ''
                                       and not return_vis):
            return None

        results = []

        for single_input, pred in zip(inputs, preds):
            single_input = single_input['points']
            if isinstance(single_input, str):
                # Raw .bin file: decode with the layout captured from the
                # pipeline config in `_init_pipeline`.
                pts_bytes = mmengine.fileio.get(single_input)
                points = np.frombuffer(pts_bytes, dtype=np.float32)
                points = points.reshape(-1, self.load_dim)
                points = points[:, self.use_dim]
                pc_name = osp.basename(single_input).split('.bin')[0]
                pc_name = f'{pc_name}.png'
            elif isinstance(single_input, np.ndarray):
                points = single_input.copy()
                pc_num = str(self.num_visualized_frames).zfill(8)
                pc_name = f'pc_{pc_num}.png'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            o3d_save_path = osp.join(img_out_dir, pc_name) \
                if img_out_dir != '' else None

            data_input = dict(points=points)
            self.visualizer.add_datasample(
                pc_name,
                data_input,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                o3d_save_path=o3d_save_path,
                vis_task='lidar_seg',
            )
            results.append(points)
            self.num_visualized_frames += 1

        return results
...@@ -11,7 +11,7 @@ from mmengine.structures import InstanceData ...@@ -11,7 +11,7 @@ from mmengine.structures import InstanceData
from mmdet3d.registry import INFERENCERS from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType from mmdet3d.utils import ConfigType
from .base_det3d_inferencer import BaseDet3DInferencer from .base_3d_inferencer import Base3DInferencer
InstanceList = List[InstanceData] InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray] InputType = Union[str, np.ndarray]
...@@ -23,7 +23,7 @@ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] ...@@ -23,7 +23,7 @@ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
@INFERENCERS.register_module(name='det3d-mono') @INFERENCERS.register_module(name='det3d-mono')
@INFERENCERS.register_module() @INFERENCERS.register_module()
class MonoDet3DInferencer(BaseDet3DInferencer): class MonoDet3DInferencer(Base3DInferencer):
"""MMDet3D Monocular 3D object detection inferencer. """MMDet3D Monocular 3D object detection inferencer.
Args: Args:
...@@ -39,7 +39,7 @@ class MonoDet3DInferencer(BaseDet3DInferencer): ...@@ -39,7 +39,7 @@ class MonoDet3DInferencer(BaseDet3DInferencer):
from metafile. Defaults to None. from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None. device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to mmdet3d. scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str): Color palette used for visualization. The order of palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'. priority is palette -> config -> checkpoint. Defaults to 'none'.
""" """
...@@ -58,7 +58,7 @@ class MonoDet3DInferencer(BaseDet3DInferencer): ...@@ -58,7 +58,7 @@ class MonoDet3DInferencer(BaseDet3DInferencer):
model: Union[ModelType, str, None] = None, model: Union[ModelType, str, None] = None,
weights: Optional[str] = None, weights: Optional[str] = None,
device: Optional[str] = None, device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d', scope: str = 'mmdet3d',
palette: str = 'none') -> None: palette: str = 'none') -> None:
# A global counter tracking the number of images processed, for # A global counter tracking the number of images processed, for
# naming of the output images # naming of the output images
...@@ -127,6 +127,7 @@ class MonoDet3DInferencer(BaseDet3DInferencer): ...@@ -127,6 +127,7 @@ class MonoDet3DInferencer(BaseDet3DInferencer):
Defaults to 0.3. Defaults to 0.3.
img_out_dir (str): Output directory of visualization results. img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''. If left as empty, no file will be saved. Defaults to ''.
Returns: Returns:
List[np.ndarray] or None: Returns visualization results only if List[np.ndarray] or None: Returns visualization results only if
applicable. applicable.
......
...@@ -701,10 +701,37 @@ class LoadPointsFromDict(LoadPointsFromFile): ...@@ -701,10 +701,37 @@ class LoadPointsFromDict(LoadPointsFromFile):
dict: The processed results. dict: The processed results.
""" """
assert 'points' in results assert 'points' in results
points_class = get_points_type(self.coord_type)
points = results['points'] points = results['points']
results['points'] = points_class(
points, points_dim=points.shape[-1], attribute_dims=None) if self.norm_intensity:
assert len(self.use_dim) >= 4, \
f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}' # noqa: E501
points[:, 3] = np.tanh(points[:, 3])
attribute_dims = None
if self.shift_height:
floor_height = np.percentile(points[:, 2], 0.99)
height = points[:, 2] - floor_height
points = np.concatenate(
[points[:, :3],
np.expand_dims(height, 1), points[:, 3:]], 1)
attribute_dims = dict(height=3)
if self.use_color:
assert len(self.use_dim) >= 6
if attribute_dims is None:
attribute_dims = dict()
attribute_dims.update(
dict(color=[
points.shape[1] - 3,
points.shape[1] - 2,
points.shape[1] - 1,
]))
points_class = get_points_type(self.coord_type)
points = points_class(
points, points_dim=points.shape[-1], attribute_dims=attribute_dims)
results['points'] = points
return results return results
......
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
import mmengine
import numpy as np
import torch
from mmengine.utils import is_list_of
from mmdet3d.apis import LidarSeg3DInferencer
from mmdet3d.structures import Det3DDataSample
class TestLiDARSeg3DInferencer(TestCase):
    """Tests for ``LidarSeg3DInferencer``.

    NOTE(review): these tests download checkpoints from openmmlab and the
    call/visualizer/post-processor tests silently skip without CUDA.
    """

    def setUp(self):
        # init from alias
        # The alias 'pointnet2-ssg_s3dis-seg' is declared in the PointNet++
        # metafile; resolving it downloads the matching checkpoint.
        self.inferencer = LidarSeg3DInferencer('pointnet2-ssg_s3dis-seg')

    def test_init(self):
        """Inferencer can be built from a metafile alias or a config path."""
        # init from metafile
        LidarSeg3DInferencer('pointnet2-ssg_s3dis-seg')
        # init from cfg
        LidarSeg3DInferencer(
            'configs/pointnet2/pointnet2_ssg_2xb16-cosine-50e_s3dis-seg.py',
            'https://download.openmmlab.com/mmdetection3d/v0.1.0_models/pointnet2/pointnet2_ssg_16x2_cosine_50e_s3dis_seg-3d-13class/pointnet2_ssg_16x2_cosine_50e_s3dis_seg-3d-13class_20210514_144205-995d0119.pth'  # noqa
        )

    def assert_predictions_equal(self, preds1, preds2):
        """Assert two prediction lists carry the same semantic masks."""
        for pred1, pred2 in zip(preds1, preds2):
            self.assertTrue(
                np.allclose(pred1['pts_semantic_mask'],
                            pred2['pts_semantic_mask']))

    def test_call(self):
        """File-path and ndarray inputs must yield identical predictions."""
        if not torch.cuda.is_available():
            return
        # single point cloud
        inputs = dict(points='tests/data/s3dis/points/Area_1_office_2.bin')
        # Seed before each call so any stochastic steps in the pipeline
        # produce comparable results across input types.
        torch.manual_seed(0)
        res_path = self.inferencer(inputs, return_vis=True)

        # ndarray
        pts_bytes = mmengine.fileio.get(inputs['points'])
        points = np.frombuffer(pts_bytes, dtype=np.float32)
        points = points.reshape(-1, 6)
        inputs = dict(points=points)
        torch.manual_seed(0)
        res_ndarray = self.inferencer(inputs, return_vis=True)
        self.assert_predictions_equal(res_path['predictions'],
                                      res_ndarray['predictions'])
        self.assertIn('visualization', res_path)
        self.assertIn('visualization', res_ndarray)

        # multiple point clouds
        inputs = [
            dict(points='tests/data/s3dis/points/Area_1_office_2.bin'),
            dict(points='tests/data/s3dis/points/Area_1_office_2.bin')
        ]
        torch.manual_seed(0)
        res_path = self.inferencer(inputs, return_vis=True)
        # list of ndarray
        all_points = []
        for p in inputs:
            pts_bytes = mmengine.fileio.get(p['points'])
            points = np.frombuffer(pts_bytes, dtype=np.float32)
            points = points.reshape(-1, 6)
            all_points.append(dict(points=points))
        torch.manual_seed(0)
        res_ndarray = self.inferencer(all_points, return_vis=True)
        self.assert_predictions_equal(res_path['predictions'],
                                      res_ndarray['predictions'])
        self.assertIn('visualization', res_path)
        self.assertIn('visualization', res_ndarray)

        # point cloud dir, test different batch sizes
        pc_dir = dict(points='tests/data/s3dis/points/')
        res_bs2 = self.inferencer(pc_dir, batch_size=2, return_vis=True)
        self.assertIn('visualization', res_bs2)
        self.assertIn('predictions', res_bs2)

    def test_visualizer(self):
        """Visualization output can be written to a directory."""
        if not torch.cuda.is_available():
            return
        inputs = dict(points='tests/data/s3dis/points/Area_1_office_2.bin')
        # img_out_dir
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.inferencer(inputs, img_out_dir=tmp_dir)

    def test_post_processor(self):
        """Datasample returns and JSON dumping round-trip correctly."""
        if not torch.cuda.is_available():
            return
        # return_datasample
        inputs = dict(points='tests/data/s3dis/points/Area_1_office_2.bin')
        res = self.inferencer(inputs, return_datasamples=True)
        self.assertTrue(is_list_of(res['predictions'], Det3DDataSample))

        # pred_out_file
        with tempfile.TemporaryDirectory() as tmp_dir:
            pred_out_file = osp.join(tmp_dir, 'tmp.json')
            res = self.inferencer(
                inputs, print_result=True, pred_out_file=pred_out_file)
            dumped_res = mmengine.load(pred_out_file)
            self.assert_predictions_equal(res['predictions'],
                                          dumped_res['predictions'])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment