Unverified commit 278df1eb authored by ChaimZhu, committed by GitHub

[Feature] add mono3d inferencer (#2190)

* add mono3d inferencer

* update mono3d inferencer

* update init file

* update unit test

* fix name

* add base_det3d_inferencer

* fix comments

* fix comments

* fix comments

* rename pgd-kitti to pgd_kitti

* add parameterized in tests.txt

* add txt file

* update LoadImageFromFileMono3D to fit latest mmcv
parent 4bd7aa18
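For context, a minimal usage sketch of the feature this commit adds. The model alias and demo paths come from the metafile change and the tests below; the output path is hypothetical:

```python
from mmdet3d.apis import MonoDet3DInferencer

# 'pgd_kitti' is the metafile alias added below; weights resolve from the
# metafile when only the model name is given.
inferencer = MonoDet3DInferencer('pgd_kitti')

# Each input pairs an image with a calib file whose first line holds the
# flattened 4x4 cam2img matrix (see Mono3DInferencerLoader below).
results = inferencer(
    dict(img='demo/data/kitti/000008.png',
         calib='demo/data/kitti/000008.txt'),
    pred_out_file='preds.json')  # hypothetical output path

print(results['predictions'][0]['scores_3d'])
```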
@@ -17,6 +17,8 @@ Collections:
 Models:
   - Name: pgd_r101-caffe_fpn_head-gn_4xb3-4x_kitti-mono3d
+    Alias:
+      - pgd_kitti
     In Collection: PGD
     Config: configs/pgd/pgd_r101-caffe_fpn_head-gn_4xb3-4x_kitti-mono3d.py
     Metadata:
......
721.5377 0.0 609.5593 44.85728 0.0 721.5377 172.854 0.2163791 0.0 0.0 1.0 0.002745884 0.0 0.0 0.0 1.0
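If the sixteen values above are the demo calibration file referenced by the tests (presumably `demo/data/kitti/000008.txt`), they reshape into the 4x4 `cam2img` matrix exactly as `Mono3DInferencerLoader` parses them later in this diff:

```python
import numpy as np

line = ('721.5377 0.0 609.5593 44.85728 0.0 721.5377 172.854 0.2163791 '
        '0.0 0.0 1.0 0.002745884 0.0 0.0 0.0 1.0')
# Mirrors Mono3DInferencerLoader: the first 16 whitespace-separated floats
# become a 4x4 camera projection (cam2img) matrix.
cam2img = np.array([float(v) for v in line.split(' ')[:16]]).reshape(4, 4)
```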
@@ -3,12 +3,10 @@ from .inference import (convert_SyncBN, inference_detector,
                         inference_mono_3d_detector,
                         inference_multi_modality_detector, inference_segmentor,
                         init_model)
+from .inferencers import BaseDet3DInferencer, MonoDet3DInferencer
 
 __all__ = [
-    'inference_detector',
-    'init_model',
-    'inference_mono_3d_detector',
-    'convert_SyncBN',
-    'inference_multi_modality_detector',
-    'inference_segmentor',
+    'inference_detector', 'init_model', 'inference_mono_3d_detector',
+    'convert_SyncBN', 'inference_multi_modality_detector',
+    'inference_segmentor', 'BaseDet3DInferencer', 'MonoDet3DInferencer'
 ]
# Copyright (c) OpenMMLab. All rights reserved.
from .base_det3d_inferencer import BaseDet3DInferencer
from .mono_det3d_inferencer import MonoDet3DInferencer
__all__ = ['BaseDet3DInferencer', 'MonoDet3DInferencer']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union
import mmengine
import numpy as np
import torch.nn as nn
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.runner import load_checkpoint
from mmengine.structures import InstanceData
from mmengine.visualization import Visualizer
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, register_all_modules
InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
class BaseDet3DInferencer(BaseInferencer):
"""Base 3D object detection inferencer.
Args:
model (str, optional): Path to the config file or the model name
defined in metafile. For example, it could be
"pgd-kitti" or
"configs/pgd/pgd_r101-caffe_fpn_head-gn_4xb3-4x_kitti-mono3d.py".
If model is not specified, user must provide the
`weights` saved by MMEngine which contains the config string.
Defaults to None.
weights (str, optional): Path to the checkpoint. If it is not specified
and model is a model name of metafile, the weights will be loaded
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to mmdet3d.
palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
"""
preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir'
}
postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample'
}
def __init__(self,
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d',
palette: str = 'none') -> None:
self.palette = palette
register_all_modules()
super().__init__(
model=model, weights=weights, device=device, scope=scope)
def _convert_syncbn(self, cfg: ConfigType):
"""Convert config's naiveSyncBN to BN.
Args:
config (str or :obj:`mmengine.Config`): Config file path
or the config object.
"""
if isinstance(cfg, dict):
for item in cfg:
if item == 'norm_cfg':
cfg[item]['type'] = cfg[item]['type']. \
replace('naiveSyncBN', 'BN')
else:
self._convert_syncbn(cfg[item])
def _init_model(
self,
cfg: ConfigType,
weights: str,
device: str = 'cpu',
) -> nn.Module:
self._convert_syncbn(cfg.model)
cfg.model.train_cfg = None
model = MODELS.build(cfg.model)
checkpoint = load_checkpoint(model, weights, map_location='cpu')
if 'dataset_meta' in checkpoint.get('meta', {}):
# mmdet3d 1.x
model.dataset_meta = checkpoint['meta']['dataset_meta']
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES']
model.dataset_meta = {'CLASSES': classes}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
else:
# < mmdet3d 1.x
model.dataset_meta = {'CLASSES': cfg.class_names}
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.cfg = cfg # save the config in the model for convenience
model.to(device)
model.eval()
return model
def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
"""Returns the index of the transform in a pipeline.
If the transform is not found, returns -1.
"""
for i, transform in enumerate(pipeline_cfg):
if transform['type'] == name:
return i
return -1
    def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]:
        visualizer = super()._init_visualizer(cfg)
        # The base class may return None when no visualizer is configured;
        # only attach dataset_meta when one exists.
        if visualizer is not None:
            visualizer.dataset_meta = self.model.dataset_meta
        return visualizer
def __call__(self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '',
print_result: bool = False,
pred_out_file: str = '',
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (InputsType): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
:obj:`BaseDataElement`. Defaults to False.
batch_size (int): Inference batch size. Defaults to 1.
return_vis (bool): Whether to return the visualization result.
Defaults to False.
show (bool): Whether to display the visualization results in a
popup window. Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
            pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.
Returns:
dict: Inference and visualization results.
"""
return super().__call__(
inputs,
return_datasamples,
batch_size,
return_vis=return_vis,
show=show,
wait_time=wait_time,
draw_pred=draw_pred,
pred_score_thr=pred_score_thr,
img_out_dir=img_out_dir,
print_result=print_result,
pred_out_file=pred_out_file,
**kwargs)
def postprocess(
self,
preds: PredType,
visualization: Optional[List[np.ndarray]] = None,
return_datasample: bool = False,
print_result: bool = False,
pred_out_file: str = '',
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.
This method should be responsible for the following tasks:
1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.
Args:
preds (List[Dict]): Predictions of the model.
visualization (Optional[np.ndarray]): Visualized predictions.
return_datasample (bool): Whether to use Datasample to store
inference results. If False, dict will be used.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
            pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
Returns:
dict: Inference and visualization results with key ``predictions``
and ``visualization``.
- ``visualization`` (Any): Returned by :meth:`visualize`.
- ``predictions`` (dict or DataSample): Returned by
:meth:`forward` and processed in :meth:`postprocess`.
If ``return_datasample=False``, it usually should be a
json-serializable dict containing only basic data elements such
as strings and numbers.
"""
result_dict = {}
results = preds
if not return_datasample:
results = []
for pred in preds:
result = self.pred2dict(pred)
results.append(result)
result_dict['predictions'] = results
if print_result:
print(result_dict)
if pred_out_file != '':
mmengine.dump(result_dict, pred_out_file)
result_dict['visualization'] = visualization
return result_dict
def pred2dict(self, data_sample: InstanceData) -> Dict:
"""Extract elements necessary to represent a prediction into a
dictionary.
It's better to contain only basic data elements such as strings and
numbers in order to guarantee it's json-serializable.
"""
pred_instances = data_sample.pred_instances_3d.numpy()
result = {
'bboxes_3d': pred_instances.bboxes_3d.tensor.numpy().tolist(),
'labels_3d': pred_instances.labels_3d.tolist(),
'scores_3d': pred_instances.scores_3d.tolist()
}
return result
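For orientation, a sketch of the structure returned by ``postprocess`` (and hence ``__call__``) when ``return_datasample=False``. The values are illustrative, not real predictions; each detected instance contributes one 7-number box via ``pred2dict`` above:

```python
result = {
    'predictions': [{
        # one 7-number box per detected instance (center, size, yaw)
        'bboxes_3d': [[1.2, 1.6, 8.0, 1.6, 1.5, 3.9, 0.1]],
        'labels_3d': [0],
        'scores_3d': [0.62],
    }],
    # list of ndarray images when return_vis=True, otherwise None
    'visualization': None,
}
```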
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence, Union
import mmcv
import mmengine
import numpy as np
from mmengine.dataset import Compose
from mmengine.fileio import (get_file_backend, isdir, join_path,
list_dir_or_file)
from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData
from mmdet3d.utils import ConfigType
from .base_det3d_inferencer import BaseDet3DInferencer
InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
class MonoDet3DInferencer(BaseDet3DInferencer):
"""MMDet3D Monocular 3D object detection inferencer.
Args:
model (str, optional): Path to the config file or the model name
defined in metafile. For example, it could be
"pgd_kitti" or
"configs/pgd/pgd_r101-caffe_fpn_head-gn_4xb3-4x_kitti-mono3d.py".
If model is not specified, user must provide the
`weights` saved by MMEngine which contains the config string.
Defaults to None.
weights (str, optional): Path to the checkpoint. If it is not specified
and model is a model name of metafile, the weights will be loaded
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to mmdet3d.
palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
"""
preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir'
}
postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample'
}
def __init__(self,
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d',
palette: str = 'none') -> None:
# A global counter tracking the number of images processed, for
# naming of the output images
self.num_visualized_imgs = 0
super(MonoDet3DInferencer, self).__init__(
model=model,
weights=weights,
device=device,
scope=scope,
palette=palette)
def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
"""Preprocess the inputs to a list.
Preprocess inputs to a list according to its type:
- list or tuple: return inputs
- dict:
- Directory path: return all files in the directory
- other cases: return a list containing the string. The string
could be a path to file, a url or other types of string according
to the task.
Args:
inputs (Union[dict, list]): Inputs for the inferencer.
Returns:
list: List of input for the :meth:`preprocess`.
"""
if isinstance(inputs, dict) and isinstance(inputs['img'], str):
img = inputs['img']
backend = get_file_backend(img)
if hasattr(backend, 'isdir') and isdir(img):
# Backends like HttpsBackend do not implement `isdir`, so only
# those backends that implement `isdir` could accept the inputs
# as a directory
filename_list = list_dir_or_file(img, list_dir=False)
img = [join_path(img, filename) for filename in filename_list]
inputs['img'] = img
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
return list(inputs)
def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline."""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
load_img_idx = self._get_transform_idx(pipeline_cfg,
'LoadImageFromFileMono3D')
if load_img_idx == -1:
raise ValueError(
'LoadImageFromFileMono3D is not found in the test pipeline')
pipeline_cfg[load_img_idx]['type'] = 'Mono3DInferencerLoader'
return Compose(pipeline_cfg)
def visualize(self,
inputs: InputsType,
preds: PredType,
return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '') -> Union[List[np.ndarray], None]:
"""Visualize predictions.
Args:
inputs (List[Dict]): Inputs for the inferencer.
preds (List[Dict]): Predictions of the model.
return_vis (bool): Whether to return the visualization result.
Defaults to False.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
Returns:
List[np.ndarray] or None: Returns visualization results only if
applicable.
"""
        if not show and img_out_dir == '' and not return_vis:
            return None

        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')
results = []
for single_input, pred in zip(inputs, preds):
if isinstance(single_input['img'], str):
img_bytes = mmengine.fileio.get(single_input['img'])
img = mmcv.imfrombytes(img_bytes)
img = img[:, :, ::-1]
img_name = osp.basename(single_input['img'])
elif isinstance(single_input['img'], np.ndarray):
img = single_input['img'].copy()
img_num = str(self.num_visualized_imgs).zfill(8)
img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f"{type(single_input['img'])}")
out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
else None
data_input = dict(img=img)
self.visualizer.add_datasample(
img_name,
data_input,
pred,
show=show,
wait_time=wait_time,
draw_gt=False,
draw_pred=draw_pred,
pred_score_thr=pred_score_thr,
out_file=out_file,
vis_task='mono_det',
)
results.append(img)
self.num_visualized_imgs += 1
return results
@@ -4,7 +4,8 @@ from .formating import Pack3DDetInputs
 from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
                       LoadMultiViewImageFromFiles, LoadPointsFromDict,
                       LoadPointsFromFile, LoadPointsFromMultiSweeps,
-                      NormalizePointsColor, PointSegClassMapping)
+                      Mono3DInferencerLoader, NormalizePointsColor,
+                      PointSegClassMapping)
 from .test_time_aug import MultiScaleFlipAug3D
 # yapf: disable
 from .transforms_3d import (AffineResize, BackgroundPointsFilter,
@@ -28,5 +29,5 @@ __all__ = [
     'IndoorPatchPointSample', 'LoadImageFromFileMono3D', 'ObjectNameFilter',
     'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize',
     'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D',
-    'MultiViewWrapper', 'PhotoMetricDistortion3D'
+    'MultiViewWrapper', 'PhotoMetricDistortion3D', 'Mono3DInferencerLoader'
 ]
@@ -4,12 +4,14 @@ from typing import List, Optional, Union
 
 import mmcv
 import mmengine
+import mmengine.fileio as fileio
 import numpy as np
 from mmcv.transforms import LoadImageFromFile
 from mmcv.transforms.base import BaseTransform
 from mmdet.datasets.transforms import LoadAnnotations
 
 from mmdet3d.registry import TRANSFORMS
+from mmdet3d.structures.bbox_3d import get_box_type
 from mmdet3d.structures.points import BasePoints, get_points_type
@@ -254,9 +256,21 @@ class LoadImageFromFileMono3D(LoadImageFromFile):
                 'Currently we only support load image from kitti and'
                 'nuscenes datasets')
 
-        img_bytes = self.file_client.get(filename)
-        img = mmcv.imfrombytes(
-            img_bytes, flag=self.color_type, backend=self.imdecode_backend)
+        try:
+            if self.file_client_args is not None:
+                file_client = fileio.FileClient.infer_client(
+                    self.file_client_args, filename)
+                img_bytes = file_client.get(filename)
+            else:
+                img_bytes = fileio.get(
+                    filename, backend_args=self.backend_args)
+            img = mmcv.imfrombytes(
+                img_bytes, flag=self.color_type, backend=self.imdecode_backend)
+        except Exception as e:
+            if self.ignore_empty:
+                return None
+            else:
+                raise e
 
         if self.to_float32:
             img = img.astype(np.float32)
@@ -267,6 +281,46 @@ class LoadImageFromFileMono3D(LoadImageFromFile):
        return results

@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
"""Load an image from ``results['img']``.
Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
:obj:`np.ndarray` in ``results['img']``. Can be used when loading image
from webcam.
Required Keys:
- img
Modified Keys:
- img
- img_path
- img_shape
- ori_shape
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
"""
def transform(self, results: dict) -> dict:
"""Transform function to add image meta information.
Args:
results (dict): Result dict with Webcam read image in
``results['img']``.
Returns:
dict: The dict contains loaded image and meta information.
"""
img = results['img']
if self.to_float32:
img = img.astype(np.float32)
results['img_path'] = None
results['img'] = img
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
return results
@TRANSFORMS.register_module()
class LoadPointsFromMultiSweeps(BaseTransform):
    """Load points from multiple sweeps.
@@ -944,3 +998,67 @@ class LoadAnnotations3D(LoadAnnotations):
        repr_str += f'{indent_str}with_bbox_depth={self.with_bbox_depth}, '
        repr_str += f'{indent_str}poly2mask={self.poly2mask})'
        return repr_str
@TRANSFORMS.register_module()
class Mono3DInferencerLoader(BaseTransform):
"""Load an image from ``results['images']['CAMX']['img']``. Similar with
:obj:`LoadImageFromFileMono3D`, but the image has been loaded as
:obj:`np.ndarray` in ``results['images']['CAMX']['img']``.
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
"""
def __init__(self, **kwargs) -> None:
super().__init__()
self.from_file = TRANSFORMS.build(
dict(type='LoadImageFromFileMono3D', **kwargs))
self.from_ndarray = TRANSFORMS.build(
dict(type='LoadImageFromNDArray', **kwargs))
def transform(self, single_input: dict) -> dict:
"""Transform function to add image meta information.
Args:
single_input (dict): Result dict with Webcam read image in
``results['images']['CAMX']['img']``.
Returns:
dict: The dict contains loaded image and meta information.
"""
box_type_3d, box_mode_3d = get_box_type('camera')
if isinstance(single_input['calib'], str):
calib_path = single_input['calib']
with open(calib_path, 'r') as f:
lines = f.readlines()
cam2img = np.array([
float(info) for info in lines[0].split(' ')[0:16]
]).reshape([4, 4])
elif isinstance(single_input['calib'], np.ndarray):
cam2img = single_input['calib']
        else:
            raise ValueError('Unsupported calibration input type: '
                             f"{type(single_input['calib'])}")
if isinstance(single_input['img'], str):
inputs = dict(
images=dict(
CAM_FRONT=dict(
img_path=single_input['img'], cam2img=cam2img)),
box_mode_3d=box_mode_3d,
box_type_3d=box_type_3d)
elif isinstance(single_input['img'], np.ndarray):
inputs = dict(
img=single_input['img'],
cam2img=cam2img,
box_type_3d=box_type_3d,
box_mode_3d=box_mode_3d)
        else:
            raise ValueError('Unsupported image input type: '
                             f"{type(single_input['img'])}")
if 'img' in inputs:
return self.from_ndarray(inputs)
return self.from_file(inputs)
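A small sketch of the two input forms ``Mono3DInferencerLoader`` accepts, assuming the mmdet3d registries have been populated via ``register_all_modules()``; the ndarray shape and identity intrinsics are illustrative only:

```python
import numpy as np

from mmdet3d.registry import TRANSFORMS
from mmdet3d.utils import register_all_modules

register_all_modules()
loader = TRANSFORMS.build(dict(type='Mono3DInferencerLoader'))

# Path form: the image is read from disk and the calib file's first line
# is parsed into a 4x4 cam2img matrix.
out = loader(dict(img='demo/data/kitti/000008.png',
                  calib='demo/data/kitti/000008.txt'))

# Array form: pass the decoded image and the intrinsic matrix directly.
out = loader(dict(img=np.zeros((375, 1242, 3), dtype=np.uint8),
                  calib=np.eye(4)))
```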
@@ -729,6 +729,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
         if 'gt_instances' in data_sample:
             if len(data_sample.gt_instances) > 0:
                 assert 'img' in data_input
+                img = data_input['img']
                 if isinstance(data_input['img'], Tensor):
                     img = data_input['img'].permute(1, 2, 0).numpy()
                     img = img[..., [2, 1, 0]]  # bgr to rgb
@@ -760,6 +761,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
             pred_instances = data_sample.pred_instances
             pred_instances = pred_instances_3d[
                 pred_instances.scores > pred_score_thr].cpu()
+            img = data_input['img']
             if isinstance(data_input['img'], Tensor):
                 img = data_input['img'].permute(1, 2, 0).numpy()
                 img = img[..., [2, 1, 0]]  # bgr to rgb
......
@@ -5,6 +5,7 @@ interrogate
 isort
 # Note: used for kwarray.group_items, this may be ported to mmcv in the future.
 kwarray
+parameterized
 pytest
 pytest-cov
 pytest-runner
......
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
import mmcv
import mmengine
import numpy as np
from mmengine.utils import is_list_of
from parameterized import parameterized
from mmdet3d.apis import MonoDet3DInferencer
from mmdet3d.structures import Det3DDataSample
class TestMonoDet3DInferencer(TestCase):
def test_init(self):
# init from metafile
MonoDet3DInferencer('pgd_kitti')
# init from cfg
MonoDet3DInferencer(
'configs/pgd/pgd_r101-caffe_fpn_head-gn_4xb3-4x_kitti-mono3d.py',
'https://download.openmmlab.com/mmdetection3d/v1.0.0_models/pgd/'
'pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d/'
'pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d_'
'20211022_102608-8a97533b.pth')
def assert_predictions_equal(self, preds1, preds2):
for pred1, pred2 in zip(preds1, preds2):
if 'bboxes_3d' in pred1:
self.assertTrue(
np.allclose(pred1['bboxes_3d'], pred2['bboxes_3d'], 0.1))
if 'scores_3d' in pred1:
self.assertTrue(
np.allclose(pred1['scores_3d'], pred2['scores_3d'], 0.1))
if 'labels_3d' in pred1:
self.assertTrue(
np.allclose(pred1['labels_3d'], pred2['labels_3d']))
@parameterized.expand(['pgd_kitti'])
def test_call(self, model):
# single img
img_path = 'demo/data/kitti/000008.png'
calib_path = 'demo/data/kitti/000008.txt'
inferencer = MonoDet3DInferencer(model)
inputs = dict(img=img_path, calib=calib_path)
res_path = inferencer(inputs, return_vis=True)
# ndarray
img = mmcv.imread(img_path)
inputs = dict(img=img, calib=calib_path)
res_ndarray = inferencer(inputs, return_vis=True)
self.assert_predictions_equal(res_path['predictions'],
res_ndarray['predictions'])
self.assertIn('visualization', res_path)
self.assertIn('visualization', res_ndarray)
# multiple images
inputs = [
dict(
img='demo/data/kitti/000008.png',
calib='demo/data/kitti/000008.txt'),
dict(
img='demo/data/kitti/000008.png',
calib='demo/data/kitti/000008.txt')
]
res_path = inferencer(inputs, return_vis=True)
# list of ndarray
imgs = [mmcv.imread(p['img']) for p in inputs]
inputs[0]['img'] = imgs[0]
inputs[1]['img'] = imgs[1]
res_ndarray = inferencer(inputs, return_vis=True)
self.assert_predictions_equal(res_path['predictions'],
res_ndarray['predictions'])
self.assertIn('visualization', res_path)
self.assertIn('visualization', res_ndarray)
@parameterized.expand(['pgd_kitti'])
def test_visualize(self, model):
inputs = [
dict(
img='demo/data/kitti/000008.png',
calib='demo/data/kitti/000008.txt'),
dict(
img='demo/data/kitti/000008.png',
calib='demo/data/kitti/000008.txt')
]
inferencer = MonoDet3DInferencer(model)
# img_out_dir
with tempfile.TemporaryDirectory() as tmp_dir:
inferencer(inputs, img_out_dir=tmp_dir)
            for img_name in ['000008.png']:
                self.assertTrue(osp.exists(osp.join(tmp_dir, img_name)))
@parameterized.expand(['pgd_kitti'])
def test_postprocess(self, model):
# return_datasample
img_path = 'demo/data/kitti/000008.png'
calib_path = 'demo/data/kitti/000008.txt'
inputs = dict(img=img_path, calib=calib_path)
inferencer = MonoDet3DInferencer(model)
res = inferencer(inputs, return_datasamples=True)
self.assertTrue(is_list_of(res['predictions'], Det3DDataSample))
# pred_out_file
with tempfile.TemporaryDirectory() as tmp_dir:
pred_out_file = osp.join(tmp_dir, 'tmp.json')
res = inferencer(
inputs, print_result=True, pred_out_file=pred_out_file)
dumped_res = mmengine.load(pred_out_file)
self.assert_predictions_equal(res['predictions'],
dumped_res['predictions'])