# Copyright (c) OpenMMLab. All rights reserved.
import re
from copy import deepcopy
from os import path as osp

import mmcv
import numpy as np
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes, Coord3DMode,
                          DepthInstance3DBoxes, LiDARInstance3DBoxes,
                          show_multi_modality_result, show_result,
                          show_seg_result)
from mmdet3d.core.bbox import get_box_type
from mmdet3d.datasets.pipelines import Compose
from mmdet3d.models import build_model
from mmdet3d.utils import get_root_logger


def convert_SyncBN(config):
    """Convert config's naiveSyncBN to BN.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
    """
    if isinstance(config, dict):
        for item in config:
            if item == 'norm_cfg':
                config[item]['type'] = config[item]['type']. \
                    replace('naiveSyncBN', 'BN')
            else:
                convert_SyncBN(config[item])


def init_model(config, checkpoint=None, device='cuda:0'):
    """Initialize a model from a config file, which could be a 3D detector or
    a 3D segmentor.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.
        device (str): Device to use. Defaults to 'cuda:0'.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    config.model.pretrained = None
    convert_SyncBN(config.model)
    config.model.train_cfg = None
    model = build_model(config.model, test_cfg=config.get('test_cfg'))
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            model.CLASSES = config.class_names
        if 'PALETTE' in checkpoint['meta']:  # 3D Segmentor
            model.PALETTE = checkpoint['meta']['PALETTE']
    model.cfg = config  # save the config in the model for convenience
    if device != 'cpu':
        torch.cuda.set_device(device)
    else:
        logger = get_root_logger()
        logger.warning('Running on CPU is not recommended: '
                       'some functions are not supported for now.')
    model.to(device)
    model.eval()
    return model


def inference_detector(model, pcd):
    """Inference point cloud with the detector.

    Args:
        model (nn.Module): The loaded detector.
        pcd (str): Point cloud files.

    Returns:
        tuple: Predicted results and data from pipeline.
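
    Example:
        A minimal usage sketch, assuming these helpers are exposed via
        ``mmdet3d.apis``; the config, checkpoint and point cloud paths are
        placeholders for illustration only:

        >>> from mmdet3d.apis import init_model, inference_detector
        >>> model = init_model('configs/xxx_config.py',
        ...                    'checkpoints/xxx_checkpoint.pth')
        >>> result, data = inference_detector(model, 'demo/xxx.bin')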
""" cfg = model.cfg device = next(model.parameters()).device # model device if not isinstance(pcd, str): cfg = cfg.copy() # set loading pipeline type cfg.data.test.pipeline[0].type = 'LoadPointsFromDict' # build the data pipeline test_pipeline = deepcopy(cfg.data.test.pipeline) test_pipeline = Compose(test_pipeline) box_type_3d, box_mode_3d = get_box_type(cfg.data.test.box_type_3d) if isinstance(pcd, str): # load from point clouds file data = dict( pts_filename=pcd, box_type_3d=box_type_3d, box_mode_3d=box_mode_3d, # for ScanNet demo we need axis_align_matrix ann_info=dict(axis_align_matrix=np.eye(4)), sweeps=[], # set timestamp = 0 timestamp=[0], img_fields=[], bbox3d_fields=[], pts_mask_fields=[], pts_seg_fields=[], bbox_fields=[], mask_fields=[], seg_fields=[]) else: # load from http data = dict( points=pcd, box_type_3d=box_type_3d, box_mode_3d=box_mode_3d, # for ScanNet demo we need axis_align_matrix ann_info=dict(axis_align_matrix=np.eye(4)), sweeps=[], # set timestamp = 0 timestamp=[0], img_fields=[], bbox3d_fields=[], pts_mask_fields=[], pts_seg_fields=[], bbox_fields=[], mask_fields=[], seg_fields=[]) data = test_pipeline(data) data = collate([data], samples_per_gpu=1) if next(model.parameters()).is_cuda: # scatter to specified GPU data = scatter(data, [device.index])[0] else: # this is a workaround to avoid the bug of MMDataParallel data['img_metas'] = data['img_metas'][0].data data['points'] = data['points'][0].data # forward the model with torch.no_grad(): result = model(return_loss=False, rescale=True, **data) return result, data def inference_multi_modality_detector(model, pcd, image, ann_file): """Inference point cloud with the multi-modality detector. Args: model (nn.Module): The loaded detector. pcd (str): Point cloud files. image (str): Image files. ann_file (str): Annotation files. Returns: tuple: Predicted results and data from pipeline. """ cfg = model.cfg device = next(model.parameters()).device # model device # build the data pipeline test_pipeline = deepcopy(cfg.data.test.pipeline) test_pipeline = Compose(test_pipeline) box_type_3d, box_mode_3d = get_box_type(cfg.data.test.box_type_3d) # get data info containing calib data_infos = mmcv.load(ann_file) image_idx = int(re.findall(r'\d+', image)[-1]) # xxx/sunrgbd_000017.jpg for x in data_infos: if int(x['image']['image_idx']) != image_idx: continue info = x break data = dict( pts_filename=pcd, img_prefix=osp.dirname(image), img_info=dict(filename=osp.basename(image)), box_type_3d=box_type_3d, box_mode_3d=box_mode_3d, img_fields=[], bbox3d_fields=[], pts_mask_fields=[], pts_seg_fields=[], bbox_fields=[], mask_fields=[], seg_fields=[]) data = test_pipeline(data) # TODO: this code is dataset-specific. Move lidar2img and # depth2img to .pkl annotations in the future. 
    # LiDAR to image conversion
    if box_mode_3d == Box3DMode.LIDAR:
        rect = info['calib']['R0_rect'].astype(np.float32)
        Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32)
        P2 = info['calib']['P2'].astype(np.float32)
        lidar2img = P2 @ rect @ Trv2c
        data['img_metas'][0].data['lidar2img'] = lidar2img
    # Depth to image conversion
    elif box_mode_3d == Box3DMode.DEPTH:
        rt_mat = info['calib']['Rt']
        # follow Coord3DMode.convert_point
        rt_mat = np.array([[1, 0, 0], [0, 0, -1],
                           [0, 1, 0]]) @ rt_mat.transpose(1, 0)
        depth2img = info['calib']['K'] @ rt_mat
        data['img_metas'][0].data['depth2img'] = depth2img

    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device.index])[0]
    else:
        # this is a workaround to avoid the bug of MMDataParallel
        data['img_metas'] = data['img_metas'][0].data
        data['points'] = data['points'][0].data
        data['img'] = data['img'][0].data
    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result, data


def inference_mono_3d_detector(model, image, ann_file):
    """Inference image with the monocular 3D detector.

    Args:
        model (nn.Module): The loaded detector.
        image (str): Image files.
        ann_file (str): Annotation files.

    Returns:
        tuple: Predicted results and data from pipeline.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = deepcopy(cfg.data.test.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = get_box_type(cfg.data.test.box_type_3d)
    # get data info containing calib
    data_infos = mmcv.load(ann_file)
    # find the info corresponding to this image
    for x in data_infos['images']:
        if osp.basename(x['file_name']) != osp.basename(image):
            continue
        img_info = x
        break
    data = dict(
        img_prefix=osp.dirname(image),
        img_info=dict(filename=osp.basename(image)),
        box_type_3d=box_type_3d,
        box_mode_3d=box_mode_3d,
        img_fields=[],
        bbox3d_fields=[],
        pts_mask_fields=[],
        pts_seg_fields=[],
        bbox_fields=[],
        mask_fields=[],
        seg_fields=[])

    # camera points to image conversion
    if box_mode_3d == Box3DMode.CAM:
        data['img_info'].update(
            dict(cam_intrinsic=img_info['cam_intrinsic']))

    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device.index])[0]
    else:
        # this is a workaround to avoid the bug of MMDataParallel
        data['img_metas'] = data['img_metas'][0].data
        data['img'] = data['img'][0].data
    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result, data


def inference_segmentor(model, pcd):
    """Inference point cloud with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        pcd (str): Point cloud files.

    Returns:
        tuple: Predicted results and data from pipeline.
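
    Example:
        A minimal usage sketch, assuming the helper is exposed via
        ``mmdet3d.apis``; the paths are placeholders for illustration only:

        >>> from mmdet3d.apis import init_model, inference_segmentor
        >>> model = init_model('configs/xxx_config.py',
        ...                    'checkpoints/xxx_checkpoint.pth')
        >>> result, data = inference_segmentor(model, 'demo/xxx.bin')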
""" cfg = model.cfg device = next(model.parameters()).device # model device # build the data pipeline test_pipeline = deepcopy(cfg.data.test.pipeline) test_pipeline = Compose(test_pipeline) data = dict( pts_filename=pcd, img_fields=[], bbox3d_fields=[], pts_mask_fields=[], pts_seg_fields=[], bbox_fields=[], mask_fields=[], seg_fields=[]) data = test_pipeline(data) data = collate([data], samples_per_gpu=1) if next(model.parameters()).is_cuda: # scatter to specified GPU data = scatter(data, [device.index])[0] else: # this is a workaround to avoid the bug of MMDataParallel data['img_metas'] = data['img_metas'][0].data data['points'] = data['points'][0].data # forward the model with torch.no_grad(): result = model(return_loss=False, rescale=True, **data) return result, data def show_det_result_meshlab(data, result, out_dir, score_thr=0.0, show=False, snapshot=False): """Show 3D detection result by meshlab.""" points = data['points'][0][0].cpu().numpy() pts_filename = data['img_metas'][0][0]['pts_filename'] file_name = osp.split(pts_filename)[-1].split('.')[0] if 'pts_bbox' in result[0].keys(): pred_bboxes = result[0]['pts_bbox']['boxes_3d'].tensor.numpy() pred_scores = result[0]['pts_bbox']['scores_3d'].numpy() else: pred_bboxes = result[0]['boxes_3d'].tensor.numpy() pred_scores = result[0]['scores_3d'].numpy() # filter out low score bboxes for visualization if score_thr > 0: inds = pred_scores > score_thr pred_bboxes = pred_bboxes[inds] # for now we convert points into depth mode box_mode = data['img_metas'][0][0]['box_mode_3d'] if box_mode != Box3DMode.DEPTH: points = Coord3DMode.convert(points, box_mode, Coord3DMode.DEPTH) show_bboxes = Box3DMode.convert(pred_bboxes, box_mode, Box3DMode.DEPTH) else: show_bboxes = deepcopy(pred_bboxes) show_result( points, None, show_bboxes, out_dir, file_name, show=show, snapshot=snapshot) return file_name def show_seg_result_meshlab(data, result, out_dir, palette, show=False, snapshot=False): """Show 3D segmentation result by meshlab.""" points = data['points'][0][0].cpu().numpy() pts_filename = data['img_metas'][0][0]['pts_filename'] file_name = osp.split(pts_filename)[-1].split('.')[0] pred_seg = result[0]['semantic_mask'].numpy() if palette is None: # generate random color map max_idx = pred_seg.max() palette = np.random.randint(0, 256, size=(max_idx + 1, 3)) palette = np.array(palette).astype(np.int) show_seg_result( points, None, pred_seg, out_dir, file_name, palette=palette, show=show, snapshot=snapshot) return file_name def show_proj_det_result_meshlab(data, result, out_dir, score_thr=0.0, show=False, snapshot=False): """Show result of projecting 3D bbox to 2D image by meshlab.""" assert 'img' in data.keys(), 'image data is not provided for visualization' img_filename = data['img_metas'][0][0]['filename'] file_name = osp.split(img_filename)[-1].split('.')[0] # read from file because img in data_dict has undergone pipeline transform img = mmcv.imread(img_filename) if 'pts_bbox' in result[0].keys(): result[0] = result[0]['pts_bbox'] elif 'img_bbox' in result[0].keys(): result[0] = result[0]['img_bbox'] pred_bboxes = result[0]['boxes_3d'].tensor.numpy() pred_scores = result[0]['scores_3d'].numpy() # filter out low score bboxes for visualization if score_thr > 0: inds = pred_scores > score_thr pred_bboxes = pred_bboxes[inds] box_mode = data['img_metas'][0][0]['box_mode_3d'] if box_mode == Box3DMode.LIDAR: if 'lidar2img' not in data['img_metas'][0][0]: raise NotImplementedError( 'LiDAR to image transformation matrix is not provided') show_bboxes = 
def show_proj_det_result_meshlab(data,
                                 result,
                                 out_dir,
                                 score_thr=0.0,
                                 show=False,
                                 snapshot=False):
    """Show result of projecting 3D bbox to 2D image by meshlab."""
    assert 'img' in data.keys(), \
        'image data is not provided for visualization'

    img_filename = data['img_metas'][0][0]['filename']
    file_name = osp.split(img_filename)[-1].split('.')[0]

    # read from file because img in data_dict has undergone pipeline transform
    img = mmcv.imread(img_filename)

    if 'pts_bbox' in result[0].keys():
        result[0] = result[0]['pts_bbox']
    elif 'img_bbox' in result[0].keys():
        result[0] = result[0]['img_bbox']
    pred_bboxes = result[0]['boxes_3d'].tensor.numpy()
    pred_scores = result[0]['scores_3d'].numpy()

    # filter out low score bboxes for visualization
    if score_thr > 0:
        inds = pred_scores > score_thr
        pred_bboxes = pred_bboxes[inds]

    box_mode = data['img_metas'][0][0]['box_mode_3d']
    if box_mode == Box3DMode.LIDAR:
        if 'lidar2img' not in data['img_metas'][0][0]:
            raise NotImplementedError(
                'LiDAR to image transformation matrix is not provided')

        show_bboxes = LiDARInstance3DBoxes(pred_bboxes, origin=(0.5, 0.5, 0))

        show_multi_modality_result(
            img,
            None,
            show_bboxes,
            data['img_metas'][0][0]['lidar2img'],
            out_dir,
            file_name,
            box_mode='lidar',
            show=show)
    elif box_mode == Box3DMode.DEPTH:
        show_bboxes = DepthInstance3DBoxes(pred_bboxes, origin=(0.5, 0.5, 0))

        show_multi_modality_result(
            img,
            None,
            show_bboxes,
            None,
            out_dir,
            file_name,
            box_mode='depth',
            img_metas=data['img_metas'][0][0],
            show=show)
    elif box_mode == Box3DMode.CAM:
        if 'cam2img' not in data['img_metas'][0][0]:
            raise NotImplementedError(
                'camera intrinsic matrix is not provided')

        show_bboxes = CameraInstance3DBoxes(
            pred_bboxes, box_dim=pred_bboxes.shape[-1], origin=(0.5, 1.0, 0.5))

        show_multi_modality_result(
            img,
            None,
            show_bboxes,
            data['img_metas'][0][0]['cam2img'],
            out_dir,
            file_name,
            box_mode='camera',
            show=show)
    else:
        raise NotImplementedError(
            f'visualization of {box_mode} bbox is not supported')

    return file_name


def show_result_meshlab(data,
                        result,
                        out_dir,
                        score_thr=0.0,
                        show=False,
                        snapshot=False,
                        task='det',
                        palette=None):
    """Show result by meshlab.

    Args:
        data (dict): Contain data from pipeline.
        result (dict): Predicted result from model.
        out_dir (str): Directory to save visualized result.
        score_thr (float, optional): Minimum score of bboxes to be shown.
            Default: 0.0
        show (bool, optional): Visualize the results online.
            Defaults to False.
        snapshot (bool, optional): Whether to save the online results.
            Defaults to False.
        task (str, optional): Distinguish which task result to visualize.
            Currently we support 3D detection, multi-modality detection,
            monocular 3D detection and 3D segmentation. Defaults to 'det'.
        palette (list[list[int]] | np.ndarray, optional): The palette
            of segmentation map. If None is given, random palette will be
            generated. Defaults to None.
    """
    assert task in ['det', 'multi_modality-det', 'seg', 'mono-det'], \
        f'unsupported visualization task {task}'
    assert out_dir is not None, 'Expect out_dir, got none.'

    if task in ['det', 'multi_modality-det']:
        file_name = show_det_result_meshlab(data, result, out_dir, score_thr,
                                            show, snapshot)

    if task in ['seg']:
        file_name = show_seg_result_meshlab(data, result, out_dir, palette,
                                            show, snapshot)

    if task in ['multi_modality-det', 'mono-det']:
        file_name = show_proj_det_result_meshlab(data, result, out_dir,
                                                 score_thr, show, snapshot)

    return out_dir, file_name
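

# A minimal end-to-end usage sketch for the APIs above, assuming they are
# exposed via ``mmdet3d.apis``. The config, checkpoint and point cloud paths
# are placeholders for illustration only:
#
#     from mmdet3d.apis import (inference_detector, init_model,
#                               show_result_meshlab)
#
#     model = init_model('configs/xxx_config.py',
#                        'checkpoints/xxx_checkpoint.pth', device='cuda:0')
#     result, data = inference_detector(model, 'demo/xxx.bin')
#     show_result_meshlab(data, result, out_dir='demo_out', score_thr=0.1,
#                         task='det')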