Commit 3c5ff9fa authored by zhangwenwei's avatar zhangwenwei
Browse files

Support test time augmentation

parent f6e95edd
...@@ -3,12 +3,12 @@ from .custom_3d import Custom3DDataset ...@@ -3,12 +3,12 @@ from .custom_3d import Custom3DDataset
from .kitti2d_dataset import Kitti2DDataset from .kitti2d_dataset import Kitti2DDataset
from .kitti_dataset import KittiDataset from .kitti_dataset import KittiDataset
from .nuscenes_dataset import NuScenesDataset from .nuscenes_dataset import NuScenesDataset
from .pipelines import (GlobalRotScale, IndoorFlipData, IndoorGlobalRotScale, from .pipelines import (GlobalRotScaleTrans, IndoorFlipData,
IndoorPointSample, IndoorPointsColorJitter, IndoorGlobalRotScaleTrans, IndoorPointSample,
LoadAnnotations3D, LoadPointsFromFile, IndoorPointsColorJitter, LoadAnnotations3D,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter, LoadPointsFromFile, NormalizePointsColor, ObjectNoise,
ObjectSample, PointShuffle, PointsRangeFilter, ObjectRangeFilter, ObjectSample, PointShuffle,
RandomFlip3D) PointsRangeFilter, RandomFlip3D)
from .scannet_dataset import ScanNetDataset from .scannet_dataset import ScanNetDataset
from .sunrgbd_dataset import SUNRGBDDataset from .sunrgbd_dataset import SUNRGBDDataset
...@@ -17,9 +17,10 @@ __all__ = [ ...@@ -17,9 +17,10 @@ __all__ = [
'build_dataloader', 'RepeatFactorDataset', 'DATASETS', 'build_dataset', 'build_dataloader', 'RepeatFactorDataset', 'DATASETS', 'build_dataset',
'build_dataloader' 'build_dataloader'
'CocoDataset', 'Kitti2DDataset', 'NuScenesDataset', 'ObjectSample', 'CocoDataset', 'Kitti2DDataset', 'NuScenesDataset', 'ObjectSample',
'RandomFlip3D', 'ObjectNoise', 'GlobalRotScale', 'PointShuffle', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle',
'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample', 'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample',
'LoadAnnotations3D', 'IndoorPointsColorJitter', 'IndoorGlobalRotScale', 'LoadAnnotations3D', 'IndoorPointsColorJitter',
'IndoorFlipData', 'SUNRGBDDataset', 'ScanNetDataset', 'Custom3DDataset' 'IndoorGlobalRotScaleTrans', 'IndoorFlipData', 'SUNRGBDDataset',
'ScanNetDataset', 'Custom3DDataset'
] ]
...@@ -103,9 +103,13 @@ class Custom3DDataset(Dataset): ...@@ -103,9 +103,13 @@ class Custom3DDataset(Dataset):
return input_dict return input_dict
def pre_pipeline(self, results): def pre_pipeline(self, results):
results['img_fields'] = []
results['bbox3d_fields'] = [] results['bbox3d_fields'] = []
results['pts_mask_fields'] = [] results['pts_mask_fields'] = []
results['pts_seg_fields'] = [] results['pts_seg_fields'] = []
results['bbox_fields'] = []
results['mask_fields'] = []
results['seg_fields'] = []
results['box_type_3d'] = self.box_type_3d results['box_type_3d'] = self.box_type_3d
results['box_mode_3d'] = self.box_mode_3d results['box_mode_3d'] = self.box_mode_3d
......
from mmdet.datasets.pipelines import Compose from mmdet.datasets.pipelines import Compose
from .dbsampler import DataBaseSampler, MMDataBaseSampler from .dbsampler import DataBaseSampler, MMDataBaseSampler
from .formating import DefaultFormatBundle, DefaultFormatBundle3D from .formating import DefaultFormatBundle, DefaultFormatBundle3D
from .indoor_augment import (IndoorFlipData, IndoorGlobalRotScale, from .indoor_augment import (IndoorFlipData, IndoorGlobalRotScaleTrans,
IndoorPointsColorJitter) IndoorPointsColorJitter)
from .indoor_loading import (LoadAnnotations3D, LoadPointsFromFile, from .indoor_loading import (LoadAnnotations3D, LoadPointsFromFile,
NormalizePointsColor) NormalizePointsColor)
from .indoor_sample import IndoorPointSample from .indoor_sample import IndoorPointSample
from .loading import LoadMultiViewImageFromFiles from .loading import LoadMultiViewImageFromFiles
from .point_seg_class_mapping import PointSegClassMapping from .point_seg_class_mapping import PointSegClassMapping
from .train_aug import (GlobalRotScale, ObjectNoise, ObjectRangeFilter, from .test_time_aug import MultiScaleFlipAug3D
ObjectSample, PointShuffle, PointsRangeFilter, from .transforms_3d import (GlobalRotScaleTrans, ObjectNoise,
RandomFlip3D) ObjectRangeFilter, ObjectSample, PointShuffle,
PointsRangeFilter, RandomFlip3D)
__all__ = [ __all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScale', 'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D', 'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'Compose', 'LoadMultiViewImageFromFiles', 'LoadPointsFromFile', 'Compose', 'LoadMultiViewImageFromFiles', 'LoadPointsFromFile',
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler', 'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'IndoorGlobalRotScale', 'IndoorPointsColorJitter', 'IndoorFlipData', 'IndoorGlobalRotScaleTrans', 'IndoorPointsColorJitter', 'IndoorFlipData',
'MMDataBaseSampler', 'NormalizePointsColor', 'LoadAnnotations3D', 'MMDataBaseSampler', 'NormalizePointsColor', 'LoadAnnotations3D',
'IndoorPointSample', 'PointSegClassMapping' 'IndoorPointSample', 'PointSegClassMapping', 'MultiScaleFlipAug3D'
] ]
...@@ -83,12 +83,12 @@ class Collect3D(object): ...@@ -83,12 +83,12 @@ class Collect3D(object):
def __call__(self, results): def __call__(self, results):
data = {} data = {}
img_meta = {} img_metas = {}
for key in self.meta_keys: for key in self.meta_keys:
if key in results: if key in results:
img_meta[key] = results[key] img_metas[key] = results[key]
data['img_meta'] = DC(img_meta, cpu_only=True) data['img_metas'] = DC(img_metas, cpu_only=True)
for key in self.keys: for key in self.keys:
data[key] = results[key] data[key] = results[key]
return data return data
......
...@@ -117,7 +117,7 @@ class IndoorPointsColorJitter(object): ...@@ -117,7 +117,7 @@ class IndoorPointsColorJitter(object):
# TODO: merge outdoor indoor transform. # TODO: merge outdoor indoor transform.
# TODO: try transform noise. # TODO: try transform noise.
@PIPELINES.register_module() @PIPELINES.register_module()
class IndoorGlobalRotScale(object): class IndoorGlobalRotScaleTrans(object):
"""Indoor global rotate and scale. """Indoor global rotate and scale.
Augment sunrgbd and scannet data with global rotating and scaling. Augment sunrgbd and scannet data with global rotating and scaling.
......
...@@ -158,7 +158,7 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -158,7 +158,7 @@ class LoadAnnotations3D(LoadAnnotations):
def _load_bboxes_3d(self, results): def _load_bboxes_3d(self, results):
results['gt_bboxes_3d'] = results['ann_info']['gt_bboxes_3d'] results['gt_bboxes_3d'] = results['ann_info']['gt_bboxes_3d']
results['bbox3d_fields'].append(results['gt_bboxes_3d']) results['bbox3d_fields'].append('gt_bboxes_3d')
return results return results
def _load_labels_3d(self, results): def _load_labels_3d(self, results):
...@@ -179,7 +179,7 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -179,7 +179,7 @@ class LoadAnnotations3D(LoadAnnotations):
pts_instance_mask_path, dtype=np.long) pts_instance_mask_path, dtype=np.long)
results['pts_instance_mask'] = pts_instance_mask results['pts_instance_mask'] = pts_instance_mask
results['pts_mask_fields'].append(results['pts_instance_mask']) results['pts_mask_fields'].append('pts_instance_mask')
return results return results
def _load_semantic_seg_3d(self, results): def _load_semantic_seg_3d(self, results):
...@@ -197,7 +197,7 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -197,7 +197,7 @@ class LoadAnnotations3D(LoadAnnotations):
pts_semantic_mask_path, dtype=np.long) pts_semantic_mask_path, dtype=np.long)
results['pts_semantic_mask'] = pts_semantic_mask results['pts_semantic_mask'] = pts_semantic_mask
results['pts_seg_fields'].append(results['pts_semantic_mask']) results['pts_seg_fields'].append('pts_semantic_mask')
return results return results
def __call__(self, results): def __call__(self, results):
......
import warnings
from copy import deepcopy
import mmcv
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import Compose
@PIPELINES.register_module()
class MultiScaleFlipAug3D(object):
"""Test-time augmentation with multiple scales and flipping
Args:
transforms (list[dict]): Transforms to apply in each augmentation.
img_scale (tuple | list[tuple]: Images scales for resizing.
pts_scale_ratio (float | list[float]): Points scale ratios for
resizing.
flip (bool): Whether apply flip augmentation. Default: False.
flip_direction (str | list[str]): Flip augmentation directions,
options are "horizontal" and "vertical". If flip_direction is list,
multiple flip augmentations will be applied.
It has no effect when flip == False. Default: "horizontal".
"""
def __init__(self,
transforms,
img_scale,
pts_scale_ratio,
flip=False,
flip_direction='horizontal'):
self.transforms = Compose(transforms)
self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale]
self.pts_scale_ratio = pts_scale_ratio \
if isinstance(pts_scale_ratio, list) else[float(pts_scale_ratio)]
assert mmcv.is_list_of(self.img_scale, tuple)
assert mmcv.is_list_of(self.pts_scale_ratio, float)
self.flip = flip
self.flip_direction = flip_direction if isinstance(
flip_direction, list) else [flip_direction]
assert mmcv.is_list_of(self.flip_direction, str)
if not self.flip and self.flip_direction != ['horizontal']:
warnings.warn(
'flip_direction has no effect when flip is set to False')
if (self.flip
and not any([t['type'] == 'RandomFlip' for t in transforms])):
warnings.warn(
'flip has no effect when RandomFlip is not in transforms')
def __call__(self, results):
aug_data = []
flip_aug = [False, True] if self.flip else [False]
for scale in self.img_scale:
for pts_scale_ratio in self.pts_scale_ratio:
for flip in flip_aug:
for direction in self.flip_direction:
# results.copy will cause bug since it is shallow copy
_results = deepcopy(results)
_results['scale'] = scale
_results['flip'] = flip
_results['pcd_scale_factor'] = pts_scale_ratio
_results['flip_direction'] = direction
data = self.transforms(_results)
aug_data.append(data)
# list of dict to dict of list
aug_data_dict = {key: [] for key in aug_data[0]}
for data in aug_data:
for key, val in data.items():
aug_data_dict[key].append(val)
return aug_data_dict
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms}, '
repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
repr_str += f'pts_scale_ratio={self.pts_scale_raio}, '
repr_str += f'flip_direction={self.flip_direction})'
return repr_str
import mmcv
import numpy as np import numpy as np
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
...@@ -18,6 +17,10 @@ class RandomFlip3D(RandomFlip): ...@@ -18,6 +17,10 @@ class RandomFlip3D(RandomFlip):
method. method.
Args: Args:
sync_2d (bool, optional): Whether to apply flip according to the 2D
images. If True, it will apply the same flip as that to 2D images.
If False, it will decide whether to flip randomly and independently
to that of 2D images.
flip_ratio (float, optional): The flipping probability. flip_ratio (float, optional): The flipping probability.
""" """
...@@ -25,61 +28,23 @@ class RandomFlip3D(RandomFlip): ...@@ -25,61 +28,23 @@ class RandomFlip3D(RandomFlip):
super(RandomFlip3D, self).__init__(**kwargs) super(RandomFlip3D, self).__init__(**kwargs)
self.sync_2d = sync_2d self.sync_2d = sync_2d
def random_flip_points(self, gt_bboxes_3d, points): def random_flip_data_3d(self, input_dict):
gt_bboxes_3d.flip() input_dict['points'][:, 1] = -input_dict['points'][:, 1]
points[:, 1] = -points[:, 1] for key in input_dict['bbox3d_fields']:
return gt_bboxes_3d, points input_dict[key].flip()
def __call__(self, input_dict): def __call__(self, input_dict):
# filp 2D image and its annotations # filp 2D image and its annotations
if 'flip' not in input_dict: super(RandomFlip3D, self).__call__(input_dict)
flip = True if np.random.rand() < self.flip_ratio else False
input_dict['flip'] = flip
if 'flip_direction' not in input_dict:
input_dict['flip_direction'] = self.direction
if input_dict['flip']:
# flip image
if 'img' in input_dict:
if isinstance(input_dict['img'], list):
input_dict['img'] = [
mmcv.imflip(
img, direction=input_dict['flip_direction'])
for img in input_dict['img']
]
else:
input_dict['img'] = mmcv.imflip(
input_dict['img'],
direction=input_dict['flip_direction'])
# flip bboxes
for key in input_dict.get('bbox_fields', []):
input_dict[key] = self.bbox_flip(input_dict[key],
input_dict['img_shape'],
input_dict['flip_direction'])
# flip masks
for key in input_dict.get('mask_fields', []):
input_dict[key] = [
mmcv.imflip(mask, direction=input_dict['flip_direction'])
for mask in input_dict[key]
]
# flip segs
for key in input_dict.get('seg_fields', []):
input_dict[key] = mmcv.imflip(
input_dict[key], direction=input_dict['flip_direction'])
if self.sync_2d: if self.sync_2d:
input_dict['pcd_flip'] = input_dict['flip'] input_dict['pcd_flip'] = input_dict['flip']
else: else:
flip = True if np.random.rand() < self.flip_ratio else False flip = True if np.random.rand() < self.flip_ratio else False
input_dict['pcd_flip'] = flip input_dict['pcd_flip'] = flip
if input_dict['pcd_flip']: if input_dict['pcd_flip']:
# flip image self.random_flip_data_3d(input_dict)
gt_bboxes_3d = input_dict['gt_bboxes_3d']
points = input_dict['points']
gt_bboxes_3d, points = self.random_flip_points(
gt_bboxes_3d, points)
input_dict['gt_bboxes_3d'] = gt_bboxes_3d
input_dict['points'] = points
return input_dict return input_dict
def __repr__(self): def __repr__(self):
...@@ -89,6 +54,13 @@ class RandomFlip3D(RandomFlip): ...@@ -89,6 +54,13 @@ class RandomFlip3D(RandomFlip):
@PIPELINES.register_module() @PIPELINES.register_module()
class ObjectSample(object): class ObjectSample(object):
"""Sample GT objects to the data
Args:
db_sampler (dict): Config dict of the database sampler.
sample_2d (bool): Whether to also paste 2D image patch to the images
This should be true when applying multi-modality cut-and-paste.
"""
def __init__(self, db_sampler, sample_2d=False): def __init__(self, db_sampler, sample_2d=False):
self.sampler_cfg = db_sampler self.sampler_cfg = db_sampler
...@@ -109,9 +81,6 @@ class ObjectSample(object): ...@@ -109,9 +81,6 @@ class ObjectSample(object):
# change to float for blending operation # change to float for blending operation
points = input_dict['points'] points = input_dict['points']
# rect = input_dict['rect']
# Trv2c = input_dict['Trv2c']
# P2 = input_dict['P2']
if self.sample_2d: if self.sample_2d:
img = input_dict['img'] img = input_dict['img']
gt_bboxes_2d = input_dict['gt_bboxes'] gt_bboxes_2d = input_dict['gt_bboxes']
...@@ -162,15 +131,28 @@ class ObjectSample(object): ...@@ -162,15 +131,28 @@ class ObjectSample(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class ObjectNoise(object): class ObjectNoise(object):
"""Apply noise to each GT objects in the scene
Args:
translation_std (list, optional): Standard deviation of the
distribution where translation noise are sampled from.
Defaults to [0.25, 0.25, 0.25].
global_rot_range (list, optional): Global rotation to the scene.
Defaults to [0.0, 0.0].
rot_range (list, optional): Object rotation range.
Defaults to [-0.15707963267, 0.15707963267].
num_try (int, optional): Number of times to try if the noise applied is
invalid. Defaults to 100.
"""
def __init__(self, def __init__(self,
loc_noise_std=[0.25, 0.25, 0.25], translation_std=[0.25, 0.25, 0.25],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_uniform_noise=[-0.15707963267, 0.15707963267], rot_range=[-0.15707963267, 0.15707963267],
num_try=100): num_try=100):
self.loc_noise_std = loc_noise_std self.translation_std = translation_std
self.global_rot_range = global_rot_range self.global_rot_range = global_rot_range
self.rot_uniform_noise = rot_uniform_noise self.rot_range = rot_range
self.num_try = num_try self.num_try = num_try
def __call__(self, input_dict): def __call__(self, input_dict):
...@@ -182,8 +164,8 @@ class ObjectNoise(object): ...@@ -182,8 +164,8 @@ class ObjectNoise(object):
noise_per_object_v3_( noise_per_object_v3_(
numpy_box, numpy_box,
points, points,
rotation_perturb=self.rot_uniform_noise, rotation_perturb=self.rot_range,
center_noise_std=self.loc_noise_std, center_noise_std=self.translation_std,
global_random_rot_range=self.global_rot_range, global_random_rot_range=self.global_rot_range,
num_try=self.num_try) num_try=self.num_try)
...@@ -194,73 +176,92 @@ class ObjectNoise(object): ...@@ -194,73 +176,92 @@ class ObjectNoise(object):
def __repr__(self): def __repr__(self):
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(num_try={},'.format(self.num_try) repr_str += '(num_try={},'.format(self.num_try)
repr_str += ' loc_noise_std={},'.format(self.loc_noise_std) repr_str += ' translation_std={},'.format(self.translation_std)
repr_str += ' global_rot_range={},'.format(self.global_rot_range) repr_str += ' global_rot_range={},'.format(self.global_rot_range)
repr_str += ' rot_uniform_noise={})'.format(self.rot_uniform_noise) repr_str += ' rot_range={})'.format(self.rot_range)
return repr_str return repr_str
@PIPELINES.register_module() @PIPELINES.register_module()
class GlobalRotScale(object): class GlobalRotScaleTrans(object):
"""Apply global rotation, scaling and translation to a 3D scene
Args:
rot_range (list[float]): Range of rotation angle.
Default to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]).
scale_ratio_range (list[float]): Range of scale ratio.
Default to [0.95, 1.05].
translation_std (list[float]): The standard deviation of ranslation
noise. This apply random translation to a scene by a noise, which
is sampled from a gaussian distribution whose standard deviation
is set by ``translation_std``. Default to [0, 0, 0]
"""
def __init__(self, def __init__(self,
rot_uniform_noise=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
scaling_uniform_noise=[0.95, 1.05], scale_ratio_range=[0.95, 1.05],
trans_normal_noise=[0, 0, 0]): translation_std=[0, 0, 0]):
self.rot_uniform_noise = rot_uniform_noise self.rot_range = rot_range
self.scaling_uniform_noise = scaling_uniform_noise self.scale_ratio_range = scale_ratio_range
self.trans_normal_noise = trans_normal_noise self.translation_std = translation_std
def _trans_bbox_points(self, gt_boxes, points): def _trans_bbox_points(self, input_dict):
noise_trans = np.random.normal(0, self.trans_normal_noise[0], 3).T if not isinstance(self.translation_std, (list, tuple, np.ndarray)):
points[:, :3] += noise_trans translation_std = [
gt_boxes.translate(noise_trans) self.translation_std, self.translation_std,
return gt_boxes, points, noise_trans self.translation_std
]
def _rot_bbox_points(self, gt_boxes, points, rotation=np.pi / 4): else:
translation_std = self.translation_std
translation_std = np.array(translation_std, dtype=np.float32)
trans_factor = np.random.normal(scale=translation_std, size=3).T
input_dict['points'][:, :3] += trans_factor
input_dict['pcd_trans'] = trans_factor
for key in input_dict['bbox3d_fields']:
input_dict[key].translate(trans_factor)
def _rot_bbox_points(self, input_dict):
rotation = self.rot_range
if not isinstance(rotation, list): if not isinstance(rotation, list):
rotation = [-rotation, rotation] rotation = [-rotation, rotation]
noise_rotation = np.random.uniform(rotation[0], rotation[1]) noise_rotation = np.random.uniform(rotation[0], rotation[1])
points = input_dict['points']
points[:, :3], rot_mat_T = box_np_ops.rotation_points_single_angle( points[:, :3], rot_mat_T = box_np_ops.rotation_points_single_angle(
points[:, :3], noise_rotation, axis=2) points[:, :3], noise_rotation, axis=2)
gt_boxes.rotate(noise_rotation) input_dict['points'] = points
input_dict['pcd_rotation'] = rot_mat_T
for key in input_dict['bbox3d_fields']:
input_dict[key].rotate(noise_rotation)
return gt_boxes, points, rot_mat_T def _scale_bbox_points(self, input_dict):
scale = input_dict['pcd_scale_factor']
input_dict['points'][:, :3] *= scale
for key in input_dict['bbox3d_fields']:
input_dict[key].scale(scale)
def _scale_bbox_points(self, def _random_scale(self, input_dict):
gt_boxes, scale_factor = np.random.uniform(self.scale_ratio_range[0],
points, self.scale_ratio_range[1])
min_scale=0.95, input_dict['pcd_scale_factor'] = scale_factor
max_scale=1.05):
noise_scale = np.random.uniform(min_scale, max_scale)
points[:, :3] *= noise_scale
gt_boxes.scale(noise_scale)
return gt_boxes, points, noise_scale
def __call__(self, input_dict): def __call__(self, input_dict):
gt_bboxes_3d = input_dict['gt_bboxes_3d'] self._rot_bbox_points(input_dict)
points = input_dict['points']
gt_bboxes_3d, points, rotation_factor = self._rot_bbox_points( if 'pcd_scale_factor' not in input_dict:
gt_bboxes_3d, points, rotation=self.rot_uniform_noise) self._random_scale(input_dict)
gt_bboxes_3d, points, scale_factor = self._scale_bbox_points( self._scale_bbox_points(input_dict)
gt_bboxes_3d, points, *self.scaling_uniform_noise)
gt_bboxes_3d, points, trans_factor = self._trans_bbox_points(
gt_bboxes_3d, points)
input_dict['gt_bboxes_3d'] = gt_bboxes_3d self._trans_bbox_points(input_dict)
input_dict['points'] = points
input_dict['pcd_scale_factor'] = scale_factor
input_dict['pcd_rotation'] = rotation_factor
input_dict['pcd_trans'] = trans_factor
return input_dict return input_dict
def __repr__(self): def __repr__(self):
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(rot_uniform_noise={},'.format(self.rot_uniform_noise) repr_str += '(rot_range={},'.format(self.rot_range)
repr_str += ' scaling_uniform_noise={},'.format( repr_str += ' scale_ratio_range={},'.format(self.scale_ratio_range)
self.scaling_uniform_noise) repr_str += ' translation_std={})'.format(self.translation_std)
repr_str += ' trans_normal_noise={})'.format(self.trans_normal_noise)
return repr_str return repr_str
......
...@@ -181,7 +181,7 @@ class VoteHead(nn.Module): ...@@ -181,7 +181,7 @@ class VoteHead(nn.Module):
gt_labels_3d, gt_labels_3d,
pts_semantic_mask=None, pts_semantic_mask=None,
pts_instance_mask=None, pts_instance_mask=None,
input_meta=None, img_metas=None,
gt_bboxes_ignore=None): gt_bboxes_ignore=None):
"""Compute loss. """Compute loss.
...@@ -193,7 +193,7 @@ class VoteHead(nn.Module): ...@@ -193,7 +193,7 @@ class VoteHead(nn.Module):
gt_labels_3d (list[Tensor]): Gt labels of each sample. gt_labels_3d (list[Tensor]): Gt labels of each sample.
pts_semantic_mask (None | list[Tensor]): Point-wise semantic mask. pts_semantic_mask (None | list[Tensor]): Point-wise semantic mask.
pts_instance_mask (None | list[Tensor]): Point-wise instance mask. pts_instance_mask (None | list[Tensor]): Point-wise instance mask.
input_metas (list[dict]): Contain pcd and img's meta info. img_metas (list[dict]): Contain pcd and img's meta info.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding. gt_bboxes_ignore (None | list[Tensor]): Specify which bounding.
Returns: Returns:
......
from .base import BaseDetector from .base import Base3DDetector
from .mvx_faster_rcnn import (DynamicMVXFasterRCNN, DynamicMVXFasterRCNNV2, from .dynamic_voxelnet import DynamicVoxelNet
DynamicMVXFasterRCNNV3) from .mvx_faster_rcnn import DynamicMVXFasterRCNN, DynamicMVXFasterRCNNV2
from .mvx_single_stage import MVXSingleStageDetector from .mvx_single_stage import MVXSingleStageDetector
from .mvx_two_stage import MVXTwoStageDetector from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2 from .parta2 import PartA2
from .votenet import VoteNet from .votenet import VoteNet
from .voxelnet import DynamicVoxelNet, VoxelNet from .voxelnet import VoxelNet
__all__ = [ __all__ = [
'BaseDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXSingleStageDetector', 'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXSingleStageDetector',
'MVXTwoStageDetector', 'DynamicMVXFasterRCNN', 'DynamicMVXFasterRCNNV2', 'MVXTwoStageDetector', 'DynamicMVXFasterRCNN', 'DynamicMVXFasterRCNNV2',
'DynamicMVXFasterRCNNV3', 'PartA2', 'VoteNet' 'PartA2', 'VoteNet'
] ]
...@@ -3,27 +3,17 @@ from abc import ABCMeta, abstractmethod ...@@ -3,27 +3,17 @@ from abc import ABCMeta, abstractmethod
import torch.nn as nn import torch.nn as nn
class BaseDetector(nn.Module, metaclass=ABCMeta): class Base3DDetector(nn.Module, metaclass=ABCMeta):
"""Base class for detectors""" """Base class for detectors"""
def __init__(self): def __init__(self):
super(BaseDetector, self).__init__() super(Base3DDetector, self).__init__()
self.fp16_enabled = False self.fp16_enabled = False
@property @property
def with_neck(self): def with_neck(self):
return hasattr(self, 'neck') and self.neck is not None return hasattr(self, 'neck') and self.neck is not None
@property
def with_voxel_encoder(self):
return hasattr(self,
'voxel_encoder') and self.voxel_encoder is not None
@property
def with_middle_encoder(self):
return hasattr(self,
'middle_encoder') and self.middle_encoder is not None
@property @property
def with_shared_head(self): def with_shared_head(self):
return hasattr(self, 'shared_head') and self.shared_head is not None return hasattr(self, 'shared_head') and self.shared_head is not None
...@@ -63,48 +53,50 @@ class BaseDetector(nn.Module, metaclass=ABCMeta): ...@@ -63,48 +53,50 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
logger = get_root_logger() logger = get_root_logger()
logger.info('load model from: {}'.format(pretrained)) logger.info('load model from: {}'.format(pretrained))
def forward_test(self, imgs, img_metas, **kwargs): def forward_test(self, points, img_metas, imgs=None, **kwargs):
""" """
Args: Args:
imgs (List[Tensor]): the outer list indicates test-time points (List[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW, augmentations and inner Tensor should have a shape NxC,
which contains all images in the batch. which contains all points in the batch.
img_meta (List[List[dict]]): the outer list indicates test-time img_metas (List[List[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates augs (multiscale, flip, etc.) and the inner list indicates
images in a batch images in a batch
imgs (List[Tensor], optional): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch. Defaults to None.
""" """
for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
if not isinstance(var, list): if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format( raise TypeError('{} must be a list, but got {}'.format(
name, type(var))) name, type(var)))
num_augs = len(imgs) num_augs = len(points)
if num_augs != len(img_metas): if num_augs != len(img_metas):
raise ValueError( raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format( 'num of augmentations ({}) != num of image meta ({})'.format(
len(imgs), len(img_metas))) len(points), len(img_metas)))
# TODO: remove the restriction of imgs_per_gpu == 1 when prepared # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
imgs_per_gpu = imgs[0].size(0) samples_per_gpu = len(points[0])
assert imgs_per_gpu == 1 assert samples_per_gpu == 1
if num_augs == 1: if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], **kwargs) imgs = [imgs] if imgs is None else imgs
return self.simple_test(points[0], img_metas[0], imgs[0], **kwargs)
else: else:
return self.aug_test(imgs, img_metas, **kwargs) return self.aug_test(points, img_metas, imgs, **kwargs)
def forward(self, img, img_meta, return_loss=True, **kwargs): def forward(self, return_loss=True, **kwargs):
""" """
Calls either forward_train or forward_test depending on whether Calls either forward_train or forward_test depending on whether
return_loss=True. Note this setting will change the expected inputs. return_loss=True. Note this setting will change the expected inputs.
When `return_loss=True`, img and img_meta are single-nested (i.e. When `return_loss=True`, img and img_metas are single-nested (i.e.
Tensor and List[dict]), and when `resturn_loss=False`, img and img_meta Tensor and List[dict]), and when `resturn_loss=False`, img and
should be double nested (i.e. List[Tensor], List[List[dict]]), with img_metas should be double nested
the outer list indicating test time augmentations. (i.e. List[Tensor], List[List[dict]]), with the outer list
indicating test time augmentations.
""" """
# TODO: current version only support 2D detector now, find
# a better way to be compatible with both
if return_loss: if return_loss:
return self.forward_train(img, img_meta, **kwargs) return self.forward_train(**kwargs)
else: else:
return self.forward_test(img, img_meta, **kwargs) return self.forward_test(**kwargs)
import torch
import torch.nn.functional as F
from mmdet.models import DETECTORS
from .voxelnet import VoxelNet
@DETECTORS.register_module()
class DynamicVoxelNet(VoxelNet):
def __init__(self,
voxel_layer,
voxel_encoder,
middle_encoder,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(DynamicVoxelNet, self).__init__(
voxel_layer=voxel_layer,
voxel_encoder=voxel_encoder,
middle_encoder=middle_encoder,
backbone=backbone,
neck=neck,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
)
def extract_feat(self, points, img_metas):
voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.voxel_encoder(voxels, coors)
batch_size = coors[-1, 0].item() + 1
x = self.middle_encoder(voxel_features, feature_coors, batch_size)
x = self.backbone(x)
if self.with_neck:
x = self.neck(x)
return x
@torch.no_grad()
def voxelize(self, points):
coors = []
# dynamic voxelization only provide a coors mapping
for res in points:
res_coors = self.voxel_layer(res)
coors.append(res_coors)
points = torch.cat(points, dim=0)
coors_batch = []
for i, coor in enumerate(coors):
coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
coors_batch.append(coor_pad)
coors_batch = torch.cat(coors_batch, dim=0)
return points, coors_batch
...@@ -11,13 +11,14 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector): ...@@ -11,13 +11,14 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(DynamicMVXFasterRCNN, self).__init__(**kwargs) super(DynamicMVXFasterRCNN, self).__init__(**kwargs)
def extract_pts_feat(self, points, img_feats, img_meta): def extract_pts_feat(self, points, img_feats, img_metas):
if not self.with_pts_bbox: if not self.with_pts_bbox:
return None return None
voxels, coors = self.voxelize(points) voxels, coors = self.voxelize(points)
# adopt an early fusion strategy # adopt an early fusion strategy
if self.with_fusion: if self.with_fusion:
voxels = self.pts_fusion_layer(img_feats, points, voxels, img_meta) voxels = self.pts_fusion_layer(img_feats, points, voxels,
img_metas)
voxel_features, feature_coors = self.pts_voxel_encoder(voxels, coors) voxel_features, feature_coors = self.pts_voxel_encoder(voxels, coors)
batch_size = coors[-1, 0] + 1 batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size) x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size)
...@@ -48,12 +49,12 @@ class DynamicMVXFasterRCNNV2(DynamicMVXFasterRCNN): ...@@ -48,12 +49,12 @@ class DynamicMVXFasterRCNNV2(DynamicMVXFasterRCNN):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(DynamicMVXFasterRCNNV2, self).__init__(**kwargs) super(DynamicMVXFasterRCNNV2, self).__init__(**kwargs)
def extract_pts_feat(self, points, img_feats, img_meta): def extract_pts_feat(self, points, img_feats, img_metas):
if not self.with_pts_bbox: if not self.with_pts_bbox:
return None return None
voxels, coors = self.voxelize(points) voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.pts_voxel_encoder( voxel_features, feature_coors = self.pts_voxel_encoder(
voxels, coors, points, img_feats, img_meta) voxels, coors, points, img_feats, img_metas)
batch_size = coors[-1, 0] + 1 batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size) x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size)
x = self.pts_backbone(x) x = self.pts_backbone(x)
...@@ -68,12 +69,12 @@ class MVXFasterRCNNV2(MVXTwoStageDetector): ...@@ -68,12 +69,12 @@ class MVXFasterRCNNV2(MVXTwoStageDetector):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(MVXFasterRCNNV2, self).__init__(**kwargs) super(MVXFasterRCNNV2, self).__init__(**kwargs)
def extract_pts_feat(self, pts, img_feats, img_meta): def extract_pts_feat(self, pts, img_feats, img_metas):
if not self.with_pts_bbox: if not self.with_pts_bbox:
return None return None
voxels, num_points, coors = self.voxelize(pts) voxels, num_points, coors = self.voxelize(pts)
voxel_features = self.pts_voxel_encoder(voxels, num_points, coors, voxel_features = self.pts_voxel_encoder(voxels, num_points, coors,
img_feats, img_meta) img_feats, img_metas)
batch_size = coors[-1, 0] + 1 batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, coors, batch_size) x = self.pts_middle_encoder(voxel_features, coors, batch_size)
...@@ -82,22 +83,3 @@ class MVXFasterRCNNV2(MVXTwoStageDetector): ...@@ -82,22 +83,3 @@ class MVXFasterRCNNV2(MVXTwoStageDetector):
if self.with_pts_neck: if self.with_pts_neck:
x = self.pts_neck(x) x = self.pts_neck(x)
return x return x
@DETECTORS.register_module()
class DynamicMVXFasterRCNNV3(DynamicMVXFasterRCNN):
def __init__(self, **kwargs):
super(DynamicMVXFasterRCNNV3, self).__init__(**kwargs)
def extract_pts_feat(self, points, img_feats, img_meta):
if not self.with_pts_bbox:
return None
voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.pts_voxel_encoder(voxels, coors)
batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size)
x = self.pts_backbone(x)
if self.with_pts_neck:
x = self.pts_neck(x, coors, points, img_feats, img_meta)
return x
...@@ -6,11 +6,11 @@ from mmdet3d.core import bbox3d2result ...@@ -6,11 +6,11 @@ from mmdet3d.core import bbox3d2result
from mmdet3d.ops import Voxelization from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS from mmdet.models import DETECTORS
from .. import builder from .. import builder
from .base import BaseDetector from .single_stage import SingleStage3DDetector
@DETECTORS.register_module() @DETECTORS.register_module()
class MVXSingleStageDetector(BaseDetector): class MVXSingleStageDetector(SingleStage3DDetector):
def __init__(self, def __init__(self,
voxel_layer, voxel_layer,
...@@ -92,7 +92,7 @@ class MVXSingleStageDetector(BaseDetector): ...@@ -92,7 +92,7 @@ class MVXSingleStageDetector(BaseDetector):
def with_pts_neck(self): def with_pts_neck(self):
return hasattr(self, 'pts_neck') and self.pts_neck is not None return hasattr(self, 'pts_neck') and self.pts_neck is not None
def extract_feat(self, points, img, img_meta): def extract_feat(self, points, img, img_metas):
if self.with_img_backbone: if self.with_img_backbone:
img_feats = self.img_backbone(img) img_feats = self.img_backbone(img)
if self.with_img_neck: if self.with_img_neck:
...@@ -126,37 +126,28 @@ class MVXSingleStageDetector(BaseDetector): ...@@ -126,37 +126,28 @@ class MVXSingleStageDetector(BaseDetector):
def forward_train(self, def forward_train(self,
points, points,
img_meta, img_metas,
gt_bboxes_3d, gt_bboxes_3d,
gt_labels, gt_labels,
img=None, img=None,
gt_bboxes_ignore=None): gt_bboxes_ignore=None):
x = self.extract_feat(points, img=img, img_meta=img_meta) x = self.extract_feat(points, img=img, img_metas=img_metas)
outs = self.pts_bbox_head(x) outs = self.pts_bbox_head(x)
loss_inputs = outs + (gt_bboxes_3d, gt_labels, img_meta) loss_inputs = outs + (gt_bboxes_3d, gt_labels, img_metas)
losses = self.pts_bbox_head.loss( losses = self.pts_bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses return losses
def forward_test(self, **kwargs):
return self.simple_test(**kwargs)
def forward(self, return_loss=True, **kwargs):
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def simple_test(self, def simple_test(self,
points, points,
img_meta, img_metas,
img=None, img=None,
gt_bboxes_3d=None, gt_bboxes_3d=None,
rescale=False): rescale=False):
x = self.extract_feat(points, img, img_meta) x = self.extract_feat(points, img, img_metas)
outs = self.pts_bbox_head(x) outs = self.pts_bbox_head(x)
bbox_list = self.pts_bbox_head.get_bboxes( bbox_list = self.pts_bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale) *outs, img_metas, rescale=rescale)
bbox_results = [ bbox_results = [
bbox3d2result(bboxes, scores, labels) bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list for bboxes, scores, labels in bbox_list
...@@ -200,7 +191,7 @@ class DynamicMVXNet(MVXSingleStageDetector): ...@@ -200,7 +191,7 @@ class DynamicMVXNet(MVXSingleStageDetector):
pretrained=pretrained, pretrained=pretrained,
) )
def extract_feat(self, points, img, img_meta): def extract_feat(self, points, img, img_metas):
if self.with_img_backbone: if self.with_img_backbone:
img_feats = self.img_backbone(img) img_feats = self.img_backbone(img)
if self.with_img_neck: if self.with_img_neck:
...@@ -209,7 +200,7 @@ class DynamicMVXNet(MVXSingleStageDetector): ...@@ -209,7 +200,7 @@ class DynamicMVXNet(MVXSingleStageDetector):
voxels, coors = self.voxelize(points) voxels, coors = self.voxelize(points)
# adopt an early fusion strategy # adopt an early fusion strategy
if self.with_fusion: if self.with_fusion:
voxels = self.fusion_layer(img_feats, points, voxels, img_meta) voxels = self.fusion_layer(img_feats, points, voxels, img_metas)
voxel_features, feature_coors = self.voxel_encoder(voxels, coors) voxel_features, feature_coors = self.voxel_encoder(voxels, coors)
batch_size = coors[-1, 0] + 1 batch_size = coors[-1, 0] + 1
...@@ -268,7 +259,7 @@ class DynamicMVXNetV2(DynamicMVXNet): ...@@ -268,7 +259,7 @@ class DynamicMVXNetV2(DynamicMVXNet):
pretrained=pretrained, pretrained=pretrained,
) )
def extract_feat(self, points, img, img_meta): def extract_feat(self, points, img, img_metas):
if self.with_img_backbone: if self.with_img_backbone:
img_feats = self.img_backbone(img) img_feats = self.img_backbone(img)
if self.with_img_neck: if self.with_img_neck:
...@@ -277,7 +268,7 @@ class DynamicMVXNetV2(DynamicMVXNet): ...@@ -277,7 +268,7 @@ class DynamicMVXNetV2(DynamicMVXNet):
voxels, coors = self.voxelize(points) voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.voxel_encoder( voxel_features, feature_coors = self.voxel_encoder(
voxels, coors, points, img_feats, img_meta) voxels, coors, points, img_feats, img_metas)
batch_size = coors[-1, 0] + 1 batch_size = coors[-1, 0] + 1
x = self.middle_encoder(voxel_features, feature_coors, batch_size) x = self.middle_encoder(voxel_features, feature_coors, batch_size)
x = self.pts_backbone(x) x = self.pts_backbone(x)
...@@ -319,7 +310,7 @@ class DynamicMVXNetV3(DynamicMVXNet): ...@@ -319,7 +310,7 @@ class DynamicMVXNetV3(DynamicMVXNet):
pretrained=pretrained, pretrained=pretrained,
) )
def extract_feat(self, points, img, img_meta): def extract_feat(self, points, img, img_metas):
if self.with_img_backbone: if self.with_img_backbone:
img_feats = self.img_backbone(img) img_feats = self.img_backbone(img)
if self.with_img_neck: if self.with_img_neck:
...@@ -331,5 +322,5 @@ class DynamicMVXNetV3(DynamicMVXNet): ...@@ -331,5 +322,5 @@ class DynamicMVXNetV3(DynamicMVXNet):
x = self.middle_encoder(voxel_features, feature_coors, batch_size) x = self.middle_encoder(voxel_features, feature_coors, batch_size)
x = self.pts_backbone(x) x = self.pts_backbone(x)
if self.with_pts_neck: if self.with_pts_neck:
x = self.pts_neck(x, coors, points, img_feats, img_meta) x = self.pts_neck(x, coors, points, img_feats, img_metas)
return x return x
...@@ -2,15 +2,15 @@ import torch ...@@ -2,15 +2,15 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmdet3d.core import bbox3d2result from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d, multi_apply
from mmdet3d.ops import Voxelization from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS from mmdet.models import DETECTORS
from .. import builder from .. import builder
from .base import BaseDetector from .base import Base3DDetector
@DETECTORS.register_module() @DETECTORS.register_module()
class MVXTwoStageDetector(BaseDetector): class MVXTwoStageDetector(Base3DDetector):
def __init__(self, def __init__(self,
pts_voxel_layer=None, pts_voxel_layer=None,
...@@ -137,7 +137,17 @@ class MVXTwoStageDetector(BaseDetector): ...@@ -137,7 +137,17 @@ class MVXTwoStageDetector(BaseDetector):
def with_img_roi_head(self): def with_img_roi_head(self):
return hasattr(self, 'img_roi_head') and self.img_roi_head is not None return hasattr(self, 'img_roi_head') and self.img_roi_head is not None
def extract_img_feat(self, img, img_meta): @property
def with_voxel_encoder(self):
return hasattr(self,
'voxel_encoder') and self.voxel_encoder is not None
@property
def with_middle_encoder(self):
return hasattr(self,
'middle_encoder') and self.middle_encoder is not None
def extract_img_feat(self, img, img_metas):
if self.with_img_backbone: if self.with_img_backbone:
if img.dim() == 5 and img.size(0) == 1: if img.dim() == 5 and img.size(0) == 1:
img.squeeze_() img.squeeze_()
...@@ -151,7 +161,7 @@ class MVXTwoStageDetector(BaseDetector): ...@@ -151,7 +161,7 @@ class MVXTwoStageDetector(BaseDetector):
img_feats = self.img_neck(img_feats) img_feats = self.img_neck(img_feats)
return img_feats return img_feats
def extract_pts_feat(self, pts, img_feats, img_meta): def extract_pts_feat(self, pts, img_feats, img_metas):
if not self.with_pts_bbox: if not self.with_pts_bbox:
return None return None
voxels, num_points, coors = self.voxelize(pts) voxels, num_points, coors = self.voxelize(pts)
...@@ -163,9 +173,9 @@ class MVXTwoStageDetector(BaseDetector): ...@@ -163,9 +173,9 @@ class MVXTwoStageDetector(BaseDetector):
x = self.pts_neck(x) x = self.pts_neck(x)
return x return x
def extract_feat(self, points, img, img_meta): def extract_feat(self, points, img, img_metas):
img_feats = self.extract_img_feat(img, img_meta) img_feats = self.extract_img_feat(img, img_metas)
pts_feats = self.extract_pts_feat(points, img_feats, img_meta) pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
return (img_feats, pts_feats) return (img_feats, pts_feats)
@torch.no_grad() @torch.no_grad()
...@@ -187,30 +197,30 @@ class MVXTwoStageDetector(BaseDetector): ...@@ -187,30 +197,30 @@ class MVXTwoStageDetector(BaseDetector):
def forward_train(self, def forward_train(self,
points=None, points=None,
img_meta=None, img_metas=None,
gt_bboxes_3d=None, gt_bboxes_3d=None,
gt_labels_3d=None, gt_labels_3d=None,
gt_labels=None, gt_labels=None,
gt_bboxes=None, gt_bboxes=None,
img=None, img=None,
proposals=None, bboxes=None,
gt_bboxes_ignore=None): gt_bboxes_ignore=None):
img_feats, pts_feats = self.extract_feat( img_feats, pts_feats = self.extract_feat(
points, img=img, img_meta=img_meta) points, img=img, img_metas=img_metas)
losses = dict() losses = dict()
if pts_feats: if pts_feats:
losses_pts = self.forward_pts_train(pts_feats, gt_bboxes_3d, losses_pts = self.forward_pts_train(pts_feats, gt_bboxes_3d,
gt_labels_3d, img_meta, gt_labels_3d, img_metas,
gt_bboxes_ignore) gt_bboxes_ignore)
losses.update(losses_pts) losses.update(losses_pts)
if img_feats: if img_feats:
losses_img = self.forward_img_train( losses_img = self.forward_img_train(
img_feats, img_feats,
img_meta=img_meta, img_metas=img_metas,
gt_bboxes=gt_bboxes, gt_bboxes=gt_bboxes,
gt_labels=gt_labels, gt_labels=gt_labels,
gt_bboxes_ignore=gt_bboxes_ignore, gt_bboxes_ignore=gt_bboxes_ignore,
proposals=proposals, bboxes=bboxes,
) )
losses.update(losses_img) losses.update(losses_img)
return losses return losses
...@@ -219,17 +229,17 @@ class MVXTwoStageDetector(BaseDetector): ...@@ -219,17 +229,17 @@ class MVXTwoStageDetector(BaseDetector):
pts_feats, pts_feats,
gt_bboxes_3d, gt_bboxes_3d,
gt_labels_3d, gt_labels_3d,
img_meta, img_metas,
gt_bboxes_ignore=None): gt_bboxes_ignore=None):
outs = self.pts_bbox_head(pts_feats) outs = self.pts_bbox_head(pts_feats)
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_meta) loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
losses = self.pts_bbox_head.loss( losses = self.pts_bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses return losses
def forward_img_train(self, def forward_img_train(self,
x, x,
img_meta, img_metas,
gt_bboxes, gt_bboxes,
gt_labels, gt_labels,
gt_bboxes_ignore=None, gt_bboxes_ignore=None,
...@@ -239,7 +249,7 @@ class MVXTwoStageDetector(BaseDetector): ...@@ -239,7 +249,7 @@ class MVXTwoStageDetector(BaseDetector):
# RPN forward and loss # RPN forward and loss
if self.with_img_rpn: if self.with_img_rpn:
rpn_outs = self.img_rpn_head(x) rpn_outs = self.img_rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas,
self.train_cfg.img_rpn) self.train_cfg.img_rpn)
rpn_losses = self.img_rpn_head.loss( rpn_losses = self.img_rpn_head.loss(
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
...@@ -247,13 +257,13 @@ class MVXTwoStageDetector(BaseDetector): ...@@ -247,13 +257,13 @@ class MVXTwoStageDetector(BaseDetector):
proposal_cfg = self.train_cfg.get('img_rpn_proposal', proposal_cfg = self.train_cfg.get('img_rpn_proposal',
self.test_cfg.img_rpn) self.test_cfg.img_rpn)
proposal_inputs = rpn_outs + (img_meta, proposal_cfg) proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs) proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
else: else:
proposal_list = proposals proposal_list = proposals
# bbox head forward and loss # bbox head forward and loss
img_roi_losses = self.roi_head.forward_train(x, img_meta, img_roi_losses = self.roi_head.forward_train(x, img_metas,
proposal_list, gt_bboxes, proposal_list, gt_bboxes,
gt_labels, gt_labels,
gt_bboxes_ignore, gt_bboxes_ignore,
...@@ -262,61 +272,78 @@ class MVXTwoStageDetector(BaseDetector): ...@@ -262,61 +272,78 @@ class MVXTwoStageDetector(BaseDetector):
losses.update(img_roi_losses) losses.update(img_roi_losses)
return losses return losses
def forward_test(self, **kwargs): def simple_test_img(self, x, img_metas, proposals=None, rescale=False):
return self.simple_test(**kwargs)
def forward(self, return_loss=True, **kwargs):
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def simple_test_img(self, x, img_meta, proposals=None, rescale=False):
"""Test without augmentation.""" """Test without augmentation."""
if proposals is None: if proposals is None:
proposal_list = self.simple_test_rpn(x, img_meta, proposal_list = self.simple_test_rpn(x, img_metas,
self.test_cfg.img_rpn) self.test_cfg.img_rpn)
else: else:
proposal_list = proposals proposal_list = proposals
return self.img_roi_head.simple_test( return self.img_roi_head.simple_test(
x, proposal_list, img_meta, rescale=rescale) x, proposal_list, img_metas, rescale=rescale)
def simple_test_rpn(self, x, img_meta, rpn_test_cfg): def simple_test_rpn(self, x, img_metas, rpn_test_cfg):
rpn_outs = self.img_rpn_head(x) rpn_outs = self.img_rpn_head(x)
proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg) proposal_inputs = rpn_outs + (img_metas, rpn_test_cfg)
proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs) proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
return proposal_list return proposal_list
def simple_test_pts(self, x, img_meta, rescale=False): def simple_test_pts(self, x, img_metas, rescale=False):
outs = self.pts_bbox_head(x) outs = self.pts_bbox_head(x)
bbox_list = self.pts_bbox_head.get_bboxes( bbox_list = self.pts_bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale) *outs, img_metas, rescale=rescale)
bbox_results = [ bbox_results = [
bbox3d2result(bboxes, scores, labels) bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list for bboxes, scores, labels in bbox_list
] ]
return bbox_results[0] return bbox_results[0]
def simple_test(self, def simple_test(self, points, img_metas, img=None, rescale=False):
points,
img_meta,
img=None,
gt_bboxes_3d=None,
rescale=False):
img_feats, pts_feats = self.extract_feat( img_feats, pts_feats = self.extract_feat(
points, img=img, img_meta=img_meta) points, img=img, img_metas=img_metas)
bbox_list = dict() bbox_list = dict()
if pts_feats and self.with_pts_bbox: if pts_feats and self.with_pts_bbox:
bbox_pts = self.simple_test_pts( bbox_pts = self.simple_test_pts(
pts_feats, img_meta, rescale=rescale) pts_feats, img_metas, rescale=rescale)
bbox_list.update(pts_bbox=bbox_pts) bbox_list.update(pts_bbox=bbox_pts)
if img_feats and self.with_img_bbox: if img_feats and self.with_img_bbox:
bbox_img = self.simple_test_img( bbox_img = self.simple_test_img(
img_feats, img_meta, rescale=rescale) img_feats, img_metas, rescale=rescale)
bbox_list.update(img_bbox=bbox_img) bbox_list.update(img_bbox=bbox_img)
return bbox_list return bbox_list
def aug_test(self, points, imgs, img_metas, rescale=False): def aug_test(self, points, img_metas, imgs=None, rescale=False):
raise NotImplementedError img_feats, pts_feats = self.extract_feats(points, img_metas, imgs)
bbox_list = dict()
if pts_feats and self.with_pts_bbox:
bbox_pts = self.aug_test_pts(pts_feats, img_metas, rescale)
bbox_list.update(pts_bbox=bbox_pts)
return bbox_list
def extract_feats(self, points, img_metas, imgs=None):
if imgs is None:
imgs = [None] * len(img_metas)
img_feats, pts_feats = multi_apply(self.extract_feat, points, imgs,
img_metas)
return img_feats, pts_feats
def aug_test_pts(self, feats, img_metas, rescale=False):
# only support aug_test for one sample
aug_bboxes = []
for x, img_meta in zip(feats, img_metas):
outs = self.pts_bbox_head(x)
bbox_list = self.pts_bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
]
aug_bboxes.append(bbox_list[0])
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.pts_bbox_head.test_cfg)
return merged_bboxes
...@@ -33,7 +33,7 @@ class PartA2(TwoStageDetector): ...@@ -33,7 +33,7 @@ class PartA2(TwoStageDetector):
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder) self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder) self.middle_encoder = builder.build_middle_encoder(middle_encoder)
def extract_feat(self, points, img_meta): def extract_feat(self, points, img_metas):
voxel_dict = self.voxelize(points) voxel_dict = self.voxelize(points)
voxel_features = self.voxel_encoder(voxel_dict['voxels'], voxel_features = self.voxel_encoder(voxel_dict['voxels'],
voxel_dict['num_points'], voxel_dict['num_points'],
...@@ -79,39 +79,66 @@ class PartA2(TwoStageDetector): ...@@ -79,39 +79,66 @@ class PartA2(TwoStageDetector):
def forward_train(self, def forward_train(self,
points, points,
img_meta, img_metas,
gt_bboxes_3d, gt_bboxes_3d,
gt_labels_3d, gt_labels_3d,
gt_bboxes_ignore=None, gt_bboxes_ignore=None,
proposals=None): proposals=None):
feats_dict, voxels_dict = self.extract_feat(points, img_meta) feats_dict, voxels_dict = self.extract_feat(points, img_metas)
losses = dict() losses = dict()
if self.with_rpn: if self.with_rpn:
rpn_outs = self.rpn_head(feats_dict['neck_feats']) rpn_outs = self.rpn_head(feats_dict['neck_feats'])
rpn_loss_inputs = rpn_outs + (gt_bboxes_3d, gt_labels_3d, img_meta) rpn_loss_inputs = rpn_outs + (gt_bboxes_3d, gt_labels_3d,
img_metas)
rpn_losses = self.rpn_head.loss( rpn_losses = self.rpn_head.loss(
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
losses.update(rpn_losses) losses.update(rpn_losses)
proposal_cfg = self.train_cfg.get('rpn_proposal', proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn) self.test_cfg.rpn)
proposal_inputs = rpn_outs + (img_meta, proposal_cfg) proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
proposal_list = self.rpn_head.get_bboxes(*proposal_inputs) proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
else: else:
proposal_list = proposals proposal_list = proposals
roi_losses = self.roi_head.forward_train(feats_dict, voxels_dict, roi_losses = self.roi_head.forward_train(feats_dict, voxels_dict,
img_meta, proposal_list, img_metas, proposal_list,
gt_bboxes_3d, gt_labels_3d) gt_bboxes_3d, gt_labels_3d)
losses.update(roi_losses) losses.update(roi_losses)
return losses return losses
def forward_test(self, **kwargs): def forward_test(self, points, img_metas, imgs=None, **kwargs):
return self.simple_test(**kwargs) """
Args:
points (List[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxC,
which contains all points in the batch.
img_metas (List[List[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch
"""
for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(points)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(points), len(img_metas)))
# TODO: remove the restriction of imgs_per_gpu == 1 when prepared
samples_per_gpu = len(points[0])
assert samples_per_gpu == 1
if num_augs == 1:
return self.simple_test(points[0], img_metas[0], **kwargs)
else:
return self.aug_test(points, img_metas, **kwargs)
def forward(self, return_loss=True, **kwargs): def forward(self, return_loss=True, **kwargs):
if return_loss: if return_loss:
...@@ -119,16 +146,19 @@ class PartA2(TwoStageDetector): ...@@ -119,16 +146,19 @@ class PartA2(TwoStageDetector):
else: else:
return self.forward_test(**kwargs) return self.forward_test(**kwargs)
def simple_test(self, points, img_meta, proposals=None, rescale=False): def simple_test(self, points, img_metas, proposals=None, rescale=False):
feats_dict, voxels_dict = self.extract_feat(points, img_meta) feats_dict, voxels_dict = self.extract_feat(points, img_metas)
if self.with_rpn: if self.with_rpn:
rpn_outs = self.rpn_head(feats_dict['neck_feats']) rpn_outs = self.rpn_head(feats_dict['neck_feats'])
proposal_cfg = self.test_cfg.rpn proposal_cfg = self.test_cfg.rpn
bbox_inputs = rpn_outs + (img_meta, proposal_cfg) bbox_inputs = rpn_outs + (img_metas, proposal_cfg)
proposal_list = self.rpn_head.get_bboxes(*bbox_inputs) proposal_list = self.rpn_head.get_bboxes(*bbox_inputs)
else: else:
proposal_list = proposals proposal_list = proposals
return self.roi_head.simple_test(feats_dict, voxels_dict, img_meta, return self.roi_head.simple_test(feats_dict, voxels_dict, img_metas,
proposal_list) proposal_list)
def aug_test(self, **kwargs):
raise NotImplementedError
import torch.nn as nn
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from .base import Base3DDetector
@DETECTORS.register_module()
class SingleStage3DDetector(Base3DDetector):
def __init__(self,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(SingleStage3DDetector, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
super(SingleStage3DDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
if isinstance(self.neck, nn.Sequential):
for m in self.neck:
m.init_weights()
else:
self.neck.init_weights()
self.bbox_head.init_weights()
def extract_feat(self, points, img_metas=None):
"""Directly extract features from the backbone+neck
"""
x = self.backbone(points)
if self.with_neck:
x = self.neck(x)
return x
def extract_feats(self, points, img_metas):
return [
self.extract_feat(pts, img_meta)
for pts, img_meta in zip(points, img_metas)
]
import torch
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet.models import DETECTORS
from .single_stage import SingleStage3DDetector
@DETECTORS.register_module()
class VoteNet(SingleStage3DDetector):
"""VoteNet model.
https://arxiv.org/pdf/1904.09664.pdf
...@@ -24,15 +25,9 @@ class VoteNet(SingleStageDetector):
test_cfg=test_cfg,
pretrained=pretrained)
def extract_feat(self, points):
x = self.backbone(points)
if self.with_neck:
x = self.neck(x)
return x
def forward_train(self,
points,
img_metas,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
...@@ -42,7 +37,7 @@ class VoteNet(SingleStageDetector):
Args:
points (list[Tensor]): Points of each batch.
img_metas (list): Image metas.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
gt_labels_3d (list[Tensor]): gt class labels of each batch.
pts_semantic_mask (None | list[Tensor]): point-wise semantic
...@@ -54,57 +49,57 @@ class VoteNet(SingleStageDetector):
Returns:
dict: Losses.
"""
points_cat = torch.stack(points)
x = self.extract_feat(points_cat)
bbox_preds = self.bbox_head(x, self.train_cfg.sample_mod)
loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_metas)
losses = self.bbox_head.loss(
bbox_preds, *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
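forward_train returns a dict of named loss terms rather than a single scalar; the training loop sums them before backpropagation. A minimal sketch with hypothetical loss names (the real keys come from the bbox head):
import torch

losses = dict(
    vote_loss=torch.tensor(0.8, requires_grad=True),
    objectness_loss=torch.tensor(0.5, requires_grad=True),
    center_loss=torch.tensor(1.2, requires_grad=True))
total_loss = sum(loss for loss in losses.values())  # reduce the dict to one scalar
total_loss.backward()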
def simple_test(self, points, img_metas, imgs=None, rescale=False):
"""Forward of testing.
Args:
points (list[Tensor]): Points of each sample.
img_metas (list): Image metas.
rescale (bool): Whether to rescale results.
Returns:
list: Predicted 3d boxes.
"""
points_cat = torch.stack(points)
x = self.extract_feat(points_cat)
bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
bbox_list = self.bbox_head.get_bboxes(
points_cat, bbox_preds, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
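The per-sample result is a plain dict; below is a minimal stand-in with the same layout that bbox3d2result is expected to produce (keys inferred from the aug_test code that follows, not from the library implementation):
import torch

def toy_bbox3d2result(bboxes, scores, labels):
    # mirrors the result layout used in this file: one dict per sample
    return dict(
        boxes_3d=bboxes,    # (N, 7) boxes, e.g. x, y, z, dx, dy, dz, yaw
        scores_3d=scores,   # (N,) confidence scores
        labels_3d=labels)   # (N,) predicted class indices

result = toy_bbox3d2result(
    torch.rand(5, 7), torch.rand(5), torch.randint(0, 3, (5,)))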
def aug_test(self, points, img_metas, imgs=None, rescale=False):
points_cat = [torch.stack(pts) for pts in points]
feats = self.extract_feats(points_cat, img_metas)
# only support aug_test for one sample
aug_bboxes = []
for x, pts_cat, img_meta in zip(feats, points_cat, img_metas):
bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
bbox_list = self.bbox_head.get_bboxes(
pts_cat, bbox_preds, img_meta, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
]
aug_bboxes.append(bbox_list[0])
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
return merged_bboxes
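merge_aug_bboxes_3d is what turns the per-augmentation detections into a single prediction: boxes are mapped back into the original, un-augmented frame, pooled and filtered. A simplified, self-contained sketch of that idea, using plain tensors and a score-based selection instead of the real box classes and 3D NMS; the meta keys are assumptions:
import torch

def naive_merge_aug_bboxes_3d(aug_bboxes, aug_scores, aug_metas, max_num=100):
    # aug_bboxes: list of (N_i, 7) tensors, aug_scores: list of (N_i,) tensors,
    # aug_metas: list of dicts recording how each view was augmented
    recovered = []
    for bboxes, meta in zip(aug_bboxes, aug_metas):
        bboxes = bboxes.clone()
        # undo point-cloud scaling recorded by the pipeline (assumed key)
        bboxes[:, :6] /= meta.get('pcd_scale_factor', 1.0)
        # undo a horizontal flip along the y-axis (assumed convention)
        if meta.get('pcd_horizontal_flip', False):
            bboxes[:, 1] = -bboxes[:, 1]
            bboxes[:, 6] = -bboxes[:, 6]
        recovered.append(bboxes)
    bboxes = torch.cat(recovered, dim=0)
    scores = torch.cat(aug_scores, dim=0)
    # the real implementation applies 3D NMS here; we simply keep the top scores
    keep = scores.argsort(descending=True)[:max_num]
    return bboxes[keep], scores[keep]

boxes, scores = naive_merge_aug_bboxes_3d(
    [torch.rand(4, 7), torch.rand(3, 7)],
    [torch.rand(4), torch.rand(3)],
    [dict(pcd_horizontal_flip=False), dict(pcd_horizontal_flip=True)])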
import torch
import torch.nn.functional as F
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS
from .. import builder
from .single_stage import SingleStage3DDetector
@DETECTORS.register_module()
class VoxelNet(SingleStage3DDetector):
def __init__(self,
voxel_layer,
...@@ -32,7 +33,7 @@ class VoxelNet(SingleStageDetector):
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder)
def extract_feat(self, points, img_metas):
voxels, num_points, coors = self.voxelize(points)
voxel_features = self.voxel_encoder(voxels, num_points, coors)
batch_size = coors[-1, 0].item() + 1
...@@ -61,83 +62,45 @@ class VoxelNet(SingleStageDetector):
def forward_train(self,
points,
img_metas,
gt_bboxes_3d,
gt_labels_3d,
gt_bboxes_ignore=None):
x = self.extract_feat(points, img_metas)
outs = self.bbox_head(x)
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
losses = self.bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
def simple_test(self, points, img_metas, imgs=None, rescale=False):
x = self.extract_feat(points, img_metas)
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
def aug_test(self, points, img_metas, imgs=None, rescale=False):
feats = self.extract_feats(points, img_metas)
# only support aug_test for one sample
aug_bboxes = []
for x, img_meta in zip(feats, img_metas):
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
]
aug_bboxes.append(bbox_list[0])
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
return merged_bboxes
@DETECTORS.register_module()
class DynamicVoxelNet(VoxelNet):
def __init__(self,
voxel_layer,
voxel_encoder,
middle_encoder,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(DynamicVoxelNet, self).__init__(
voxel_layer=voxel_layer,
voxel_encoder=voxel_encoder,
middle_encoder=middle_encoder,
backbone=backbone,
neck=neck,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
)
def extract_feat(self, points, img_meta):
voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.voxel_encoder(voxels, coors)
batch_size = coors[-1, 0].item() + 1
x = self.middle_encoder(voxel_features, feature_coors, batch_size)
x = self.backbone(x)
if self.with_neck:
x = self.neck(x)
return x
@torch.no_grad()
def voxelize(self, points):
coors = []
# dynamic voxelization only provide a coors mapping
for res in points:
res_coors = self.voxel_layer(res)
coors.append(res_coors)
points = torch.cat(points, dim=0)
coors_batch = []
for i, coor in enumerate(coors):
coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
coors_batch.append(coor_pad)
coors_batch = torch.cat(coors_batch, dim=0)
return points, coors_batch
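A tiny standalone illustration of the F.pad trick used in voxelize(): padding one column on the left with value=i prepends the batch index to every coordinate row of sample i.
import torch
import torch.nn.functional as F

coor = torch.tensor([[5, 3, 2], [7, 1, 0]])                 # (z, y, x) voxel coords of one sample
coor_pad = F.pad(coor, (1, 0), mode='constant', value=1)    # sample index i = 1
print(coor_pad)  # tensor([[1, 5, 3, 2], [1, 7, 1, 0]]) -> (batch_idx, z, y, x)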
...@@ -23,18 +23,34 @@ def point_sample(
padding_mode='zeros',
align_corners=True,
):
"""Obtain image features using points
Args:
img_features (Tensor): 1xCxHxW image features
points (Tensor): Nx3 point cloud in LiDAR coordinates
lidar2img_rt (Tensor): 4x4 transformation matrix
pcd_rotate_mat (Tensor): 3x3 rotation matrix of points
during augmentation
img_scale_factor (Tensor): (w_scale, h_scale)
img_crop_offset (Tensor): (w_offset, h_offset) offset used to crop
image during data augmentation
pcd_trans_factor (Tensor): translation of points during augmentation
pcd_scale_factor (float): scale factor of points during
data augmentation
pcd_flip (bool): Whether the points are flipped.
img_flip (bool): Whether the image is flipped.
img_pad_shape (tuple[int]): int tuple indicates the h & w after
padding, this is necessary to obtain features in feature map
img_shape (tuple[int]): int tuple indicates the h & w before padding
after scaling, this is necessary for flipping coordinates
aligned (bool, optional): Whether to use bilinear interpolation when
sampling image features for each point. Defaults to True.
padding_mode (str, optional): Padding mode when padding values for
features of out-of-image points. Defaults to 'zeros'.
align_corners (bool, optional): Whether to align corners when
sampling image features for each point. Defaults to True.
Returns:
(Tensor): NxC image features sampled by point coordinates
"""
# aug order: flip -> trans -> scale -> rot
...@@ -97,7 +113,36 @@ def point_sample(
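The body of point_sample is elided here, but the core of the sampling step is: undo the point-cloud augmentation, project the points into the image with the lidar-to-image matrix, apply the image-side augmentation, normalise the pixel coordinates to [-1, 1] and bilinearly sample features. A simplified sketch of just the projection-and-sampling part (the helper name and row-vector convention are assumptions, not the actual implementation):
import torch
import torch.nn.functional as F

def sample_img_feats(img_features, points, lidar2img_rt, img_shape):
    # img_features: 1xCxHxW, points: Nx3 (LiDAR), lidar2img_rt: 4x4, img_shape: (h, w)
    num_points = points.shape[0]
    pts_hom = torch.cat([points, points.new_ones(num_points, 1)], dim=1)  # Nx4
    pts_2d = pts_hom @ lidar2img_rt.t()                                   # project to image plane
    pts_2d[:, :2] /= pts_2d[:, 2:3].clamp(min=1e-5)                       # perspective divide
    h, w = img_shape
    coor_x = pts_2d[:, 0] / w * 2 - 1   # normalise to [-1, 1] for grid_sample
    coor_y = pts_2d[:, 1] / h * 2 - 1
    grid = torch.stack([coor_x, coor_y], dim=-1).view(1, 1, num_points, 2)
    feats = F.grid_sample(img_features, grid, align_corners=True)         # 1xCx1xN
    return feats.squeeze(0).squeeze(1).t()                                # NxC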
@FUSION_LAYERS.register_module()
class PointFusion(nn.Module):
"""Fuse image features from multi-scale features
Args:
img_channels (list[int] | int): Channels of image features.
It could be a list if the input is multi-scale image features.
pts_channels (int): Channels of point features
mid_channels (int): Channels of middle layers
out_channels (int): Channels of output fused features
img_levels (int, optional): Number of image levels. Defaults to 3.
conv_cfg (dict, optional): Dict config of conv layers of middle
layers. Defaults to None.
norm_cfg (dict, optional): Dict config of norm layers of middle
layers. Defaults to None.
act_cfg (dict, optional): Dict config of activation layers.
Defaults to None.
activate_out (bool, optional): Whether to apply relu activation
to output features. Defaults to True.
fuse_out (bool, optional): Whether to apply a conv layer to the fused
features. Defaults to False.
dropout_ratio (float, optional): Dropout ratio of image
features to prevent overfitting. Defaults to 0.
aligned (bool, optional): Whether to apply aligned feature fusion.
Defaults to True.
align_corners (bool, optional): Whether to align corners when
sampling features according to points. Defaults to True.
padding_mode (str, optional): Mode used to pad the features of
points that do not have corresponding image features.
Defaults to 'zeros'.
lateral_conv (bool, optional): Whether to apply lateral convs
to image features. Defaults to True.
""" """
def __init__(self,
...@@ -179,15 +224,20 @@ class PointFusion(nn.Module):
if isinstance(m, (nn.Conv2d, nn.Linear)):
xavier_init(m, distribution='uniform')
def forward(self, img_feats, pts, pts_feats, img_metas):
"""Forward function
Args:
img_feats (list[Tensor]): img features
pts (list[Tensor]): a batch of points with shape Nx3
pts_feats (Tensor): a tensor consisting of point features of the
total batch
img_metas (list[dict]): meta information of images
Returns:
torch.Tensor: fused features of each point.
"""
img_pts = self.obtain_mlvl_feats(img_feats, pts, img_metas)
img_pre_fuse = self.img_transform(img_pts)
if self.training and self.dropout_ratio > 0:
img_pre_fuse = F.dropout(img_pre_fuse, self.dropout_ratio)
...@@ -201,7 +251,7 @@ class PointFusion(nn.Module):
return fuse_out
def obtain_mlvl_feats(self, img_feats, pts, img_metas):
if self.lateral_convs is not None:
img_ins = [
lateral_conv(img_feats[i])
...@@ -211,7 +261,7 @@ class PointFusion(nn.Module):
img_ins = img_feats
img_feats_per_point = []
# Sample multi-level features
for i in range(len(img_metas)):
mlvl_img_feats = []
for level in range(len(self.img_levels)):
if torch.isnan(img_ins[level][i:i + 1]).any():
...@@ -219,7 +269,7 @@ class PointFusion(nn.Module):
pdb.set_trace()
mlvl_img_feats.append(
self.sample_single(img_ins[level][i:i + 1], pts[i][:, :3],
img_metas[i]))
mlvl_img_feats = torch.cat(mlvl_img_feats, dim=-1)
img_feats_per_point.append(mlvl_img_feats)
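The PointFusion docstring above maps directly onto an mmdet-style config dict. A hypothetical fusion-layer config built through the FUSION_LAYERS registry (all values are illustrative, not taken from a shipped config):
fusion_layer = dict(
    type='PointFusion',
    img_channels=[256, 256, 256],  # one entry per image feature level
    pts_channels=64,
    mid_channels=128,
    out_channels=128,
    img_levels=3,
    activate_out=True,
    fuse_out=False,
    dropout_ratio=0.2,
    aligned=True,
    align_corners=True,
    padding_mode='zeros',
    lateral_conv=True)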
......