Unverified Commit ef13e5a2 authored by Xiang Xu, committed by GitHub

[Feature] Add inferencer for lidar-based segmentation (#2304)

* add lidar_seg_inferencer

* fix randomness caused in slide_inference

* Update semantickitti.py

* fix

* add BaseSeg3DInferencer

* refactor

* rename BaseDet3DInferencer to Base3DInferencer

* fix import error

* update doc
parent 06b56888
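
For context, a minimal usage sketch of the new inferencer. The model alias comes from the metafile change below and the point-cloud path from the new unit tests; both are illustrative, not the only supported inputs.

```python
from mmdet3d.apis import LidarSeg3DInferencer

# Alias registered in the metafile below; the matching weights are
# resolved from the metafile automatically.
inferencer = LidarSeg3DInferencer('pointnet2-ssg_s3dis-seg')

# Inputs are a dict with a 'points' key: a .bin path, a directory,
# or an ndarray. The path below is the one used by the new tests.
results = inferencer(
    dict(points='tests/data/s3dis/points/Area_1_office_2.bin'))
print(results['predictions'][0]['pts_semantic_mask'][:10])
```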
......@@ -68,6 +68,7 @@ Models:
Weights: https://download.openmmlab.com/mmdetection3d/v0.1.0_models/pointnet2/pointnet2_msg_16x2_cosine_250e_scannet_seg-3d-20class/pointnet2_msg_16x2_cosine_250e_scannet_seg-3d-20class_20210514_144009-24477ab1.pth
- Name: pointnet2_ssg_2xb16-cosine-50e_s3dis-seg
Alias: pointnet2-ssg_s3dis-seg
In Collection: PointNet++
Config: configs/pointnet2/pointnet2_ssg_2xb16-cosine-50e_s3dis-seg.py
Metadata:
......
......@@ -3,12 +3,12 @@ from .inference import (convert_SyncBN, inference_detector,
inference_mono_3d_detector,
inference_multi_modality_detector, inference_segmentor,
init_model)
from .inferencers import (BaseDet3DInferencer, LidarDet3DInferencer,
MonoDet3DInferencer)
from .inferencers import (Base3DInferencer, LidarDet3DInferencer,
LidarSeg3DInferencer, MonoDet3DInferencer)
__all__ = [
'inference_detector', 'init_model', 'inference_mono_3d_detector',
'convert_SyncBN', 'inference_multi_modality_detector',
'inference_segmentor', 'BaseDet3DInferencer', 'MonoDet3DInferencer',
'LidarDet3DInferencer'
'inference_segmentor', 'Base3DInferencer', 'MonoDet3DInferencer',
'LidarDet3DInferencer', 'LidarSeg3DInferencer'
]
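
After this change the segmentation inferencer is part of the public API alongside the renamed base class:

```python
from mmdet3d.apis import (Base3DInferencer, LidarDet3DInferencer,
                          LidarSeg3DInferencer, MonoDet3DInferencer)
```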
......@@ -76,16 +76,16 @@ def init_model(config: Union[str, Path, Config],
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES']
model.dataset_meta = {'CLASSES': classes}
model.dataset_meta = {'classes': classes}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
else:
# < mmdet3d 1.x
model.dataset_meta = {'CLASSES': config.class_names}
model.dataset_meta = {'classes': config.class_names}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
model.cfg = config # save the config in the model for convenience
if device != 'cpu':
......
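
This hunk and the matching one in `Base3DInferencer` below migrate `dataset_meta` to lower-case keys. A hypothetical helper sketching what downstream code can now rely on (the dict contents are illustrative):

```python
def get_classes_and_palette(model):
    """Read class names and palette via the new lower-case keys."""
    meta = model.dataset_meta  # e.g. {'classes': (...), 'palette': [...]}
    # 'palette' is only set when the checkpoint came from a 3D segmentor.
    return meta['classes'], meta.get('palette')
```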
# Copyright (c) OpenMMLab. All rights reserved.
from .base_det3d_inferencer import BaseDet3DInferencer
from .base_3d_inferencer import Base3DInferencer
from .lidar_det3d_inferencer import LidarDet3DInferencer
from .lidar_seg3d_inferencer import LidarSeg3DInferencer
from .mono_det3d_inferencer import MonoDet3DInferencer
__all__ = [
'BaseDet3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer'
'Base3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer',
'LidarSeg3DInferencer'
]
......@@ -23,8 +23,8 @@ ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
class BaseDet3DInferencer(BaseInferencer):
"""Base 3D object detection inferencer.
class Base3DInferencer(BaseInferencer):
"""Base 3D model inferencer.
Args:
model (str, optional): Path to the config file or the model name
......@@ -39,7 +39,7 @@ class BaseDet3DInferencer(BaseInferencer):
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to mmdet3d.
scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
"""
......@@ -58,7 +58,7 @@ class BaseDet3DInferencer(BaseInferencer):
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d',
scope: str = 'mmdet3d',
palette: str = 'none') -> None:
self.palette = palette
init_default_scope(scope)
......@@ -97,16 +97,16 @@ class BaseDet3DInferencer(BaseInferencer):
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES']
model.dataset_meta = {'CLASSES': classes}
model.dataset_meta = {'classes': classes}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
else:
# < mmdet3d 1.x
model.dataset_meta = {'CLASSES': cfg.class_names}
model.dataset_meta = {'classes': cfg.class_names}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
model.cfg = cfg # save the config in the model for convenience
model.to(device)
......@@ -130,8 +130,8 @@ class BaseDet3DInferencer(BaseInferencer):
Args:
inputs (Union[dict, list]): Inputs for the inferencer.
modality_key (Union[str, List[str]], optional): The key of the
modality. Defaults to 'points'.
modality_key (Union[str, List[str]]): The key of the modality.
Defaults to 'points'.
Returns:
list: List of input for the :meth:`preprocess`.
......@@ -187,6 +187,7 @@ class BaseDet3DInferencer(BaseInferencer):
pred_out_file: str = '',
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (InputsType): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
......@@ -205,7 +206,7 @@ class BaseDet3DInferencer(BaseInferencer):
If left as empty, no file will be saved. Defaults to ''.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o
pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
......@@ -213,6 +214,7 @@ class BaseDet3DInferencer(BaseInferencer):
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.
Returns:
dict: Inference and visualization results.
"""
......@@ -240,23 +242,30 @@ class BaseDet3DInferencer(BaseInferencer):
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.
This method should be responsible for the following tasks:
1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.
Args:
preds (List[Dict]): Predictions of the model.
visualization (Optional[np.ndarray]): Visualized predictions.
visualization (np.ndarray, optional): Visualized predictions.
Defaults to None.
return_datasample (bool): Whether to use Datasample to store
inference results. If False, dict will be used.
Defaults to False.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o
pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
Returns:
dict: Inference and visualization results with key ``predictions``
and ``visualization``.
- ``visualization`` (Any): Returned by :meth:`visualize`.
- ``predictions`` (dict or DataSample): Returned by
:meth:`forward` and processed in :meth:`postprocess`.
......@@ -286,11 +295,18 @@ class BaseDet3DInferencer(BaseInferencer):
It's better to contain only basic data elements such as strings and
numbers in order to guarantee it's json-serializable.
"""
pred_instances = data_sample.pred_instances_3d.numpy()
result = {}
if 'pred_instances_3d' in data_sample:
pred_instances_3d = data_sample.pred_instances_3d.numpy()
result = {
'bboxes_3d': pred_instances.bboxes_3d.tensor.cpu().tolist(),
'labels_3d': pred_instances.labels_3d.tolist(),
'scores_3d': pred_instances.scores_3d.tolist()
'bboxes_3d': pred_instances_3d.bboxes_3d.tensor.cpu().tolist(),
'labels_3d': pred_instances_3d.labels_3d.tolist(),
'scores_3d': pred_instances_3d.scores_3d.tolist()
}
if 'pred_pts_seg' in data_sample:
pred_pts_seg = data_sample.pred_pts_seg.numpy()
result['pts_semantic_mask'] = \
pred_pts_seg.pts_semantic_mask.tolist()
return result
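
A sketch of the dict shapes `pred2dict` now produces, depending on which prediction fields the data sample carries (all values illustrative):

```python
# Detection sample -> instance fields:
det_result = {
    'bboxes_3d': [[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0]],  # 7-dim box
    'labels_3d': [2],
    'scores_3d': [0.91],
}
# Segmentation sample -> one class index per point:
seg_result = {'pts_semantic_mask': [4, 4, 7, 0]}
# A sample carrying both fields would yield the union of the keys.
```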
......@@ -10,7 +10,7 @@ from mmengine.structures import InstanceData
from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType
from .base_det3d_inferencer import BaseDet3DInferencer
from .base_3d_inferencer import Base3DInferencer
InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
......@@ -22,7 +22,7 @@ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
@INFERENCERS.register_module(name='det3d-lidar')
@INFERENCERS.register_module()
class LidarDet3DInferencer(BaseDet3DInferencer):
class LidarDet3DInferencer(Base3DInferencer):
"""The inferencer of LiDAR-based detection.
Args:
......@@ -38,8 +38,9 @@ class LidarDet3DInferencer(BaseDet3DInferencer):
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of registry.
palette (str, optional): The palette of visualization.
scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
"""
preprocess_kwargs: set = set()
......@@ -56,14 +57,17 @@ class LidarDet3DInferencer(BaseDet3DInferencer):
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d',
scope: str = 'mmdet3d',
palette: str = 'none') -> None:
# A global counter tracking the number of frames processed, for
# naming of the output results
self.num_visualized_frames = 0
self.palette = palette
super().__init__(
model=model, weights=weights, device=device, scope=scope)
super(LidarDet3DInferencer, self).__init__(
model=model,
weights=weights,
device=device,
scope=scope,
palette=palette)
def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
"""Preprocess the inputs to a list.
......@@ -129,6 +133,7 @@ class LidarDet3DInferencer(BaseDet3DInferencer):
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
Returns:
List[np.ndarray] or None: Returns visualization results only if
applicable.
......
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence, Union
import mmengine
import numpy as np
from mmengine.dataset import Compose
from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData
from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType
from .base_3d_inferencer import Base3DInferencer
InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
@INFERENCERS.register_module(name='seg3d-lidar')
@INFERENCERS.register_module()
class LidarSeg3DInferencer(Base3DInferencer):
"""The inferencer of LiDAR-based segmentation.
Args:
model (str, optional): Path to the config file or the model name
defined in metafile. For example, it could be
"pointnet2-ssg_s3dis-seg" or
"configs/pointnet2/pointnet2_ssg_2xb16-cosine-50e_s3dis-seg.py".
If model is not specified, the user must provide the
`weights` saved by MMEngine, which contains the config string.
Defaults to None.
weights (str, optional): Path to the checkpoint. If it is not specified
and model is a model name of metafile, the weights will be loaded
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
"""
preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir'
}
postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample'
}
def __init__(self,
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: str = 'mmdet3d',
palette: str = 'none') -> None:
# A global counter tracking the number of frames processed, for
# naming of the output results
self.num_visualized_frames = 0
super(LidarSeg3DInferencer, self).__init__(
model=model,
weights=weights,
device=device,
scope=scope,
palette=palette)
def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
"""Preprocess the inputs to a list.
Preprocess inputs to a list according to its type:
- list or tuple: return inputs
- dict: the value with key 'points' is
- Directory path: return all files in the directory
- other cases: return a list containing the string. The string
could be a path to file, a url or other types of string according
to the task.
Args:
inputs (Union[dict, list]): Inputs for the inferencer.
Returns:
list: List of input for the :meth:`preprocess`.
"""
return super()._inputs_to_list(inputs, modality_key='points')
def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline."""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
# Loading annotations is not applicable at inference time
idx = self._get_transform_idx(pipeline_cfg, 'LoadAnnotations3D')
if idx != -1:
del pipeline_cfg[idx]
idx = self._get_transform_idx(pipeline_cfg, 'PointSegClassMapping')
if idx != -1:
del pipeline_cfg[idx]
load_point_idx = self._get_transform_idx(pipeline_cfg,
'LoadPointsFromFile')
if load_point_idx == -1:
raise ValueError(
'LoadPointsFromFile is not found in the test pipeline')
load_cfg = pipeline_cfg[load_point_idx]
self.coord_type, self.load_dim = load_cfg['coord_type'], load_cfg[
'load_dim']
self.use_dim = list(range(load_cfg['use_dim'])) if isinstance(
load_cfg['use_dim'], int) else load_cfg['use_dim']
pipeline_cfg[load_point_idx]['type'] = 'LidarDet3DInferencerLoader'
return Compose(pipeline_cfg)
def visualize(self,
inputs: InputsType,
preds: PredType,
return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '') -> Union[List[np.ndarray], None]:
"""Visualize predictions.
Args:
inputs (InputsType): Inputs for the inferencer.
preds (PredType): Predictions of the model.
return_vis (bool): Whether to return the visualization result.
Defaults to False.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
Returns:
List[np.ndarray] or None: Returns visualization results only if
applicable.
"""
if self.visualizer is None or (not show and img_out_dir == ''
and not return_vis):
return None
if getattr(self, 'visualizer') is None:
raise ValueError('Visualization needs the "visualizer" term '
'defined in the config, but got None.')
results = []
for single_input, pred in zip(inputs, preds):
single_input = single_input['points']
if isinstance(single_input, str):
pts_bytes = mmengine.fileio.get(single_input)
points = np.frombuffer(pts_bytes, dtype=np.float32)
points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim]
pc_name = osp.basename(single_input).split('.bin')[0]
pc_name = f'{pc_name}.png'
elif isinstance(single_input, np.ndarray):
points = single_input.copy()
pc_num = str(self.num_visualized_frames).zfill(8)
pc_name = f'pc_{pc_num}.png'
else:
raise ValueError('Unsupported input type: '
f'{type(single_input)}')
o3d_save_path = osp.join(img_out_dir, pc_name) \
if img_out_dir != '' else None
data_input = dict(points=points)
self.visualizer.add_datasample(
pc_name,
data_input,
pred,
show=show,
wait_time=wait_time,
draw_gt=False,
draw_pred=draw_pred,
pred_score_thr=pred_score_thr,
o3d_save_path=o3d_save_path,
vis_task='lidar_seg',
)
results.append(points)
self.num_visualized_frames += 1
return results
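
Putting the new class together, a sketch of the input forms `_inputs_to_list` accepts and of saving visualizations, mirroring the unit tests below (the paths come from the tests; the array is illustrative):

```python
import numpy as np
from mmdet3d.apis import LidarSeg3DInferencer

inferencer = LidarSeg3DInferencer('pointnet2-ssg_s3dis-seg')

# Single file, whole directory, or raw ndarrays (x, y, z, r, g, b).
inferencer(dict(points='tests/data/s3dis/points/Area_1_office_2.bin'))
inferencer(dict(points='tests/data/s3dis/points/'))  # expands to all files
pts = np.zeros((1024, 6), dtype=np.float32)
inferencer([dict(points=pts), dict(points=pts)])

# visualize() writes '<name>.png' per cloud when img_out_dir is set.
inferencer(
    dict(points='tests/data/s3dis/points/Area_1_office_2.bin'),
    img_out_dir='outputs/')
```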
......@@ -11,7 +11,7 @@ from mmengine.structures import InstanceData
from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType
from .base_det3d_inferencer import BaseDet3DInferencer
from .base_3d_inferencer import Base3DInferencer
InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
......@@ -23,7 +23,7 @@ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
@INFERENCERS.register_module(name='det3d-mono')
@INFERENCERS.register_module()
class MonoDet3DInferencer(BaseDet3DInferencer):
class MonoDet3DInferencer(Base3DInferencer):
"""MMDet3D Monocular 3D object detection inferencer.
Args:
......@@ -39,7 +39,7 @@ class MonoDet3DInferencer(BaseDet3DInferencer):
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to mmdet3d.
scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
"""
......@@ -58,7 +58,7 @@ class MonoDet3DInferencer(BaseDet3DInferencer):
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d',
scope: str = 'mmdet3d',
palette: str = 'none') -> None:
# A global counter tracking the number of images processed, for
# naming of the output images
......@@ -127,6 +127,7 @@ class MonoDet3DInferencer(BaseDet3DInferencer):
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
Returns:
List[np.ndarray] or None: Returns visualization results only if
applicable.
......
......@@ -701,10 +701,37 @@ class LoadPointsFromDict(LoadPointsFromFile):
dict: The processed results.
"""
assert 'points' in results
points_class = get_points_type(self.coord_type)
points = results['points']
results['points'] = points_class(
points, points_dim=points.shape[-1], attribute_dims=None)
if self.norm_intensity:
assert len(self.use_dim) >= 4, \
f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}' # noqa: E501
points[:, 3] = np.tanh(points[:, 3])
attribute_dims = None
if self.shift_height:
floor_height = np.percentile(points[:, 2], 0.99)
height = points[:, 2] - floor_height
points = np.concatenate(
[points[:, :3],
np.expand_dims(height, 1), points[:, 3:]], 1)
attribute_dims = dict(height=3)
if self.use_color:
assert len(self.use_dim) >= 6
if attribute_dims is None:
attribute_dims = dict()
attribute_dims.update(
dict(color=[
points.shape[1] - 3,
points.shape[1] - 2,
points.shape[1] - 1,
]))
points_class = get_points_type(self.coord_type)
points = points_class(
points, points_dim=points.shape[-1], attribute_dims=attribute_dims)
results['points'] = points
return results
......
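
A worked example of the attribute bookkeeping added above, for a 6-column xyzrgb cloud with `shift_height` enabled (array contents illustrative):

```python
import numpy as np

points = np.zeros((100, 6), dtype=np.float32)  # x, y, z, r, g, b
# shift_height inserts a height column at index 3, as in the transform:
height = points[:, 2] - np.percentile(points[:, 2], 0.99)
points = np.concatenate(
    [points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1)
# Columns are now x, y, z, height, r, g, b -> shape (100, 7), and the
# transform records attribute_dims == {'height': 3, 'color': [4, 5, 6]}.
assert points.shape[1] == 7
```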
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
import mmengine
import numpy as np
import torch
from mmengine.utils import is_list_of
from mmdet3d.apis import LidarSeg3DInferencer
from mmdet3d.structures import Det3DDataSample
class TestLiDARSeg3DInferencer(TestCase):
def setUp(self):
# init from alias
self.inferencer = LidarSeg3DInferencer('pointnet2-ssg_s3dis-seg')
def test_init(self):
# init from metafile
LidarSeg3DInferencer('pointnet2-ssg_s3dis-seg')
# init from cfg
LidarSeg3DInferencer(
'configs/pointnet2/pointnet2_ssg_2xb16-cosine-50e_s3dis-seg.py',
'https://download.openmmlab.com/mmdetection3d/v0.1.0_models/pointnet2/pointnet2_ssg_16x2_cosine_50e_s3dis_seg-3d-13class/pointnet2_ssg_16x2_cosine_50e_s3dis_seg-3d-13class_20210514_144205-995d0119.pth' # noqa
)
def assert_predictions_equal(self, preds1, preds2):
for pred1, pred2 in zip(preds1, preds2):
self.assertTrue(
np.allclose(pred1['pts_semantic_mask'],
pred2['pts_semantic_mask']))
def test_call(self):
if not torch.cuda.is_available():
return
# single point cloud
inputs = dict(points='tests/data/s3dis/points/Area_1_office_2.bin')
torch.manual_seed(0)
res_path = self.inferencer(inputs, return_vis=True)
# ndarray
pts_bytes = mmengine.fileio.get(inputs['points'])
points = np.frombuffer(pts_bytes, dtype=np.float32)
points = points.reshape(-1, 6)
inputs = dict(points=points)
torch.manual_seed(0)
res_ndarray = self.inferencer(inputs, return_vis=True)
self.assert_predictions_equal(res_path['predictions'],
res_ndarray['predictions'])
self.assertIn('visualization', res_path)
self.assertIn('visualization', res_ndarray)
# multiple point clouds
inputs = [
dict(points='tests/data/s3dis/points/Area_1_office_2.bin'),
dict(points='tests/data/s3dis/points/Area_1_office_2.bin')
]
torch.manual_seed(0)
res_path = self.inferencer(inputs, return_vis=True)
# list of ndarray
all_points = []
for p in inputs:
pts_bytes = mmengine.fileio.get(p['points'])
points = np.frombuffer(pts_bytes, dtype=np.float32)
points = points.reshape(-1, 6)
all_points.append(dict(points=points))
torch.manual_seed(0)
res_ndarray = self.inferencer(all_points, return_vis=True)
self.assert_predictions_equal(res_path['predictions'],
res_ndarray['predictions'])
self.assertIn('visualization', res_path)
self.assertIn('visualization', res_ndarray)
# point cloud dir, test different batch sizes
pc_dir = dict(points='tests/data/s3dis/points/')
res_bs2 = self.inferencer(pc_dir, batch_size=2, return_vis=True)
self.assertIn('visualization', res_bs2)
self.assertIn('predictions', res_bs2)
def test_visualizer(self):
if not torch.cuda.is_available():
return
inputs = dict(points='tests/data/s3dis/points/Area_1_office_2.bin')
# img_out_dir
with tempfile.TemporaryDirectory() as tmp_dir:
self.inferencer(inputs, img_out_dir=tmp_dir)
def test_post_processor(self):
if not torch.cuda.is_available():
return
# return_datasample
inputs = dict(points='tests/data/s3dis/points/Area_1_office_2.bin')
res = self.inferencer(inputs, return_datasamples=True)
self.assertTrue(is_list_of(res['predictions'], Det3DDataSample))
# pred_out_file
with tempfile.TemporaryDirectory() as tmp_dir:
pred_out_file = osp.join(tmp_dir, 'tmp.json')
res = self.inferencer(
inputs, print_result=True, pred_out_file=pred_out_file)
dumped_res = mmengine.load(pred_out_file)
self.assert_predictions_equal(res['predictions'],
dumped_res['predictions'])