Commit edb6b369 authored by ChaimZhu

fix inference

parent 27a546e9
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from copy import deepcopy
from os import path as osp
from typing import Sequence, Union

import mmcv
import mmengine
import numpy as np
import torch
import torch.nn as nn
from mmengine.dataset import Compose
from mmengine.runner import load_checkpoint

from mmdet3d.core import Box3DMode, Det3DDataSample, SampleList
from mmdet3d.core.bbox import get_box_type
from mmdet3d.models import build_model


def convert_SyncBN(config):
@@ -37,7 +38,7 @@ def init_model(config, checkpoint=None, device='cuda:0'):
    3D segmentor.

    Args:
        config (str or :obj:`mmengine.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
@@ -47,14 +48,14 @@ def init_model(config, checkpoint=None, device='cuda:0'):
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
        config = mmengine.Config.fromfile(config)
    elif not isinstance(config, mmengine.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    config.model.pretrained = None
    convert_SyncBN(config.model)
    config.model.train_cfg = None
    model = build_model(config.model)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint['meta']:
@@ -67,256 +68,249 @@ def init_model(config, checkpoint=None, device='cuda:0'):
    if device != 'cpu':
        torch.cuda.set_device(device)
    else:
        warnings.warn('We don\'t recommend using the CPU device. '
                      'Some functions are not supported for now.')
    model.to(device)
    model.eval()
    return model
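

# Illustrative usage of `init_model` (not part of this commit; the config
# and checkpoint paths below are hypothetical placeholders). Any mmdet3d
# config with a `model` section should work the same way:
#
#   from mmdet3d.apis import init_model
#
#   model = init_model(
#       'configs/votenet/votenet_scannet.py',       # hypothetical path
#       checkpoint='checkpoints/votenet_scannet.pth',  # hypothetical path
#       device='cuda:0')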


PointsType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def inference_detector(model: nn.Module,
                       pcds: PointsType) -> Union[Det3DDataSample, SampleList]:
"""Inference point cloud with the detector. """Inference point cloud with the detector.
Args: Args:
model (nn.Module): The loaded detector. model (nn.Module): The loaded detector.
pcd (str): Point cloud files. pcds (str, ndarray, Sequence[str/ndarray]):
Either point cloud files or loaded point cloud.
Returns: Returns:
tuple: Predicted results and data from pipeline. :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
If pcds is a list or tuple, the same length list type results
will be returned, otherwise return the detection results directly.
""" """
    if isinstance(pcds, (list, tuple)):
        is_batch = True
    else:
        pcds = [pcds]
        is_batch = False

    cfg = model.cfg

    if not isinstance(pcds[0], str):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.test_dataloader.dataset.pipeline[0].type = 'LoadPointsFromDict'

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    # box_type_3d, box_mode_3d = get_box_type(
    #     cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for pcd in pcds:
        # prepare data
        if isinstance(pcd, str):
            # load from point cloud file
            data_ = dict(
                lidar_points=dict(lidar_path=pcd),
                # for ScanNet demo we need axis_align_matrix
                ann_info=dict(axis_align_matrix=np.eye(4)),
                sweeps=[],
                # set timestamp = 0
                timestamp=[0])
        else:
            # directly use loaded point cloud
            data_ = dict(
                points=pcd,
                # for ScanNet demo we need axis_align_matrix
                ann_info=dict(axis_align_matrix=np.eye(4)),
                sweeps=[],
                # set timestamp = 0
                timestamp=[0])
        data_ = test_pipeline(data_)
        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if not is_batch:
        return results[0]
    else:
        return results
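

# Illustrative usage of `inference_detector` (not part of this commit; the
# file names are hypothetical placeholders). A single input yields one
# `Det3DDataSample`, a list yields a list of the same length:
#
#   result = inference_detector(model, 'demo/data/kitti/000008.bin')
#   results = inference_detector(model, ['a.bin', 'b.bin'])  # len == 2
#
# Loaded arrays are accepted too; the first pipeline transform is then
# swapped to `LoadPointsFromDict`:
#
#   points = np.fromfile('a.bin', dtype=np.float32).reshape(-1, 4)
#   result = inference_detector(model, points)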


def inference_multi_modality_detector(model: nn.Module,
                                      pcds: Union[str, Sequence[str]],
                                      imgs: Union[str, Sequence[str]],
                                      ann_files: Union[str, Sequence[str]]):
    """Inference point cloud with the multi-modality detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point clouds.
        imgs (str, Sequence[str]):
            Either image files or loaded images.
        ann_files (str, Sequence[str]): Annotation files.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If ``pcds`` is a list or tuple, a list of results with the same
        length is returned; otherwise the detection result is returned
        directly.
    """
    # TODO: We will support
    if isinstance(pcds, (list, tuple)):
        is_batch = True
        assert isinstance(imgs, (list, tuple))
        assert isinstance(ann_files, (list, tuple))
        assert len(pcds) == len(imgs) == len(ann_files)
    else:
        pcds = [pcds]
        imgs = [imgs]
        ann_files = [ann_files]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for index, pcd in enumerate(pcds):
        # get data info containing calib
        img = imgs[index]
        ann_file = ann_files[index]
        data_info = mmcv.load(ann_file)[0]
        # TODO: check the name consistency of
        # image file and point cloud file
        data_ = dict(
            lidar_points=dict(lidar_path=pcd),
            img_path=imgs[index],
            img_prefix=osp.dirname(img),
            img_info=dict(filename=osp.basename(img)),
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)
        data_ = test_pipeline(data_)

        # LiDAR to image conversion for KITTI dataset
        if box_mode_3d == Box3DMode.LIDAR:
            data_['lidar2img'] = data_info['images']['CAM2']['lidar2img']
        # Depth to image conversion for SUNRGBD dataset
        elif box_mode_3d == Box3DMode.DEPTH:
            data_['depth2img'] = data_info['images']['CAM0']['depth2img']

        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if not is_batch:
        return results[0]
    else:
        return results
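

# Illustrative usage of `inference_multi_modality_detector` (not part of
# this commit; the KITTI-style paths are hypothetical placeholders). Point
# clouds, images and annotation files must line up one-to-one, and each
# annotation file is expected to hold the calibration matrices read above
# (`lidar2img` for KITTI, `depth2img` for SUNRGBD):
#
#   result = inference_multi_modality_detector(
#       model,
#       pcds='demo/data/kitti/000008.bin',
#       imgs='demo/data/kitti/000008.png',
#       ann_files='demo/data/kitti/000008.pkl')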


def inference_mono_3d_detector(model: nn.Module, imgs: ImagesType,
                               ann_files: Union[str, Sequence[str]]):
    """Inference image with the monocular 3D detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str, Sequence[str]):
            Either image files or loaded images.
        ann_files (str, Sequence[str]): Annotation files.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If ``imgs`` is a list or tuple, a list of results with the same
        length is returned; otherwise the detection result is returned
        directly.
    """
    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for index, img in enumerate(imgs):
        ann_file = ann_files[index]
        # get data info containing calib
        data_info = mmcv.load(ann_file)[0]
        data_ = dict(
            img_path=img,
            images=data_info['images'],
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)

        data_ = test_pipeline(data_)
        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if not is_batch:
        return results[0]
    else:
        return results
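

# Illustrative usage of `inference_mono_3d_detector` (not part of this
# commit; the paths are hypothetical placeholders). The annotation file
# must provide the `images` dict with the camera information consumed
# above:
#
#   result = inference_mono_3d_detector(
#       model,
#       imgs='demo/data/nuscenes/sample.jpg',
#       ann_files='demo/data/nuscenes/sample.pkl')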


def inference_segmentor(model: nn.Module, pcds: PointsType):
    """Inference point cloud with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point clouds.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If ``pcds`` is a list or tuple, a list of results with the same
        length is returned; otherwise the segmentation result is returned
        directly.
    """
    if isinstance(pcds, (list, tuple)):
        is_batch = True
    else:
        pcds = [pcds]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)

    data = []
    for pcd in pcds:
        data_ = dict(lidar_points=dict(lidar_path=pcd))
        data_ = test_pipeline(data_)
        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if not is_batch:
        return results[0]
    else:
        return results
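

# Illustrative usage of `inference_segmentor` (not part of this commit; the
# path is a hypothetical placeholder, and the result field access assumes
# the dev-1.x `Det3DDataSample` layout):
#
#   result = inference_segmentor(model, 'demo/data/scannet/scene0000_00.bin')
#   sem_mask = result.pred_pts_seg.pts_semantic_mask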
@@ -513,7 +513,7 @@ class LoadPointsFromFile(BaseTransform):
class LoadPointsFromDict(LoadPointsFromFile):
    """Load Points From Dict."""

    def transform(self, results: dict) -> dict:
        assert 'points' in results
        return results
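

# Sketch of the contract `LoadPointsFromDict` enforces for the in-memory
# branch of `inference_detector` above (illustration only; the import path
# and constructor arguments are assumptions, not shown in this commit):
#
#   from mmdet3d.core.points import LiDARPoints
#
#   points = LiDARPoints(torch.rand(1000, 4), points_dim=4)
#   results = LoadPointsFromDict(coord_type='LIDAR').transform(
#       dict(points=points))
#   assert 'points' in results  # the transform passes them through untouched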