Commit ce79da2e authored by zhangwenwei's avatar zhangwenwei

Merge branch 'add-tta' into 'master'

Support test time augmentation

See merge request open-mmlab/mmdet.3d!70
parents f6e95edd 3c5ff9fa
......@@ -3,12 +3,12 @@ from .custom_3d import Custom3DDataset
from .kitti2d_dataset import Kitti2DDataset
from .kitti_dataset import KittiDataset
from .nuscenes_dataset import NuScenesDataset
from .pipelines import (GlobalRotScale, IndoorFlipData, IndoorGlobalRotScale,
IndoorPointSample, IndoorPointsColorJitter,
LoadAnnotations3D, LoadPointsFromFile,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D)
from .pipelines import (GlobalRotScaleTrans, IndoorFlipData,
IndoorGlobalRotScaleTrans, IndoorPointSample,
IndoorPointsColorJitter, LoadAnnotations3D,
LoadPointsFromFile, NormalizePointsColor, ObjectNoise,
ObjectRangeFilter, ObjectSample, PointShuffle,
PointsRangeFilter, RandomFlip3D)
from .scannet_dataset import ScanNetDataset
from .sunrgbd_dataset import SUNRGBDDataset
......@@ -17,9 +17,10 @@ __all__ = [
'build_dataloader', 'RepeatFactorDataset', 'DATASETS', 'build_dataset',
'build_dataloader',
'CocoDataset', 'Kitti2DDataset', 'NuScenesDataset', 'ObjectSample',
'RandomFlip3D', 'ObjectNoise', 'GlobalRotScale', 'PointShuffle',
'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle',
'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample',
'LoadAnnotations3D', 'IndoorPointsColorJitter', 'IndoorGlobalRotScale',
'IndoorFlipData', 'SUNRGBDDataset', 'ScanNetDataset', 'Custom3DDataset'
'LoadAnnotations3D', 'IndoorPointsColorJitter',
'IndoorGlobalRotScaleTrans', 'IndoorFlipData', 'SUNRGBDDataset',
'ScanNetDataset', 'Custom3DDataset'
]
......@@ -103,9 +103,13 @@ class Custom3DDataset(Dataset):
return input_dict
def pre_pipeline(self, results):
results['img_fields'] = []
results['bbox3d_fields'] = []
results['pts_mask_fields'] = []
results['pts_seg_fields'] = []
results['bbox_fields'] = []
results['mask_fields'] = []
results['seg_fields'] = []
results['box_type_3d'] = self.box_type_3d
results['box_mode_3d'] = self.box_mode_3d
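These `*_fields` lists act as registries: the loading transforms below append the result-dict keys they populate, and geometric transforms later iterate over those keys so every registered annotation receives the same flip/rotation/scaling. A minimal, self-contained sketch of the pattern (the Box class here is a stand-in, not the real mmdet3d box type):

class Box:
    def flip(self):
        print('flipped')

results = dict(bbox3d_fields=[], gt_bboxes_3d=Box())
results['bbox3d_fields'].append('gt_bboxes_3d')   # what LoadAnnotations3D now registers
for key in results['bbox3d_fields']:              # what RandomFlip3D now iterates
    results[key].flip()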
......
from mmdet.datasets.pipelines import Compose
from .dbsampler import DataBaseSampler, MMDataBaseSampler
from .formating import DefaultFormatBundle, DefaultFormatBundle3D
from .indoor_augment import (IndoorFlipData, IndoorGlobalRotScale,
from .indoor_augment import (IndoorFlipData, IndoorGlobalRotScaleTrans,
IndoorPointsColorJitter)
from .indoor_loading import (LoadAnnotations3D, LoadPointsFromFile,
NormalizePointsColor)
from .indoor_sample import IndoorPointSample
from .loading import LoadMultiViewImageFromFiles
from .point_seg_class_mapping import PointSegClassMapping
from .train_aug import (GlobalRotScale, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D)
from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (GlobalRotScaleTrans, ObjectNoise,
ObjectRangeFilter, ObjectSample, PointShuffle,
PointsRangeFilter, RandomFlip3D)
__all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScale',
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'Compose', 'LoadMultiViewImageFromFiles', 'LoadPointsFromFile',
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'IndoorGlobalRotScale', 'IndoorPointsColorJitter', 'IndoorFlipData',
'IndoorGlobalRotScaleTrans', 'IndoorPointsColorJitter', 'IndoorFlipData',
'MMDataBaseSampler', 'NormalizePointsColor', 'LoadAnnotations3D',
'IndoorPointSample', 'PointSegClassMapping'
'IndoorPointSample', 'PointSegClassMapping', 'MultiScaleFlipAug3D'
]
......@@ -83,12 +83,12 @@ class Collect3D(object):
def __call__(self, results):
data = {}
img_meta = {}
img_metas = {}
for key in self.meta_keys:
if key in results:
img_meta[key] = results[key]
img_metas[key] = results[key]
data['img_meta'] = DC(img_meta, cpu_only=True)
data['img_metas'] = DC(img_metas, cpu_only=True)
for key in self.keys:
data[key] = results[key]
return data
......
......@@ -117,7 +117,7 @@ class IndoorPointsColorJitter(object):
# TODO: merge outdoor indoor transform.
# TODO: try transform noise.
@PIPELINES.register_module()
class IndoorGlobalRotScale(object):
class IndoorGlobalRotScaleTrans(object):
"""Indoor global rotate and scale.
Augment sunrgbd and scannet data with global rotating and scaling.
......
......@@ -158,7 +158,7 @@ class LoadAnnotations3D(LoadAnnotations):
def _load_bboxes_3d(self, results):
results['gt_bboxes_3d'] = results['ann_info']['gt_bboxes_3d']
results['bbox3d_fields'].append(results['gt_bboxes_3d'])
results['bbox3d_fields'].append('gt_bboxes_3d')
return results
def _load_labels_3d(self, results):
......@@ -179,7 +179,7 @@ class LoadAnnotations3D(LoadAnnotations):
pts_instance_mask_path, dtype=np.long)
results['pts_instance_mask'] = pts_instance_mask
results['pts_mask_fields'].append(results['pts_instance_mask'])
results['pts_mask_fields'].append('pts_instance_mask')
return results
def _load_semantic_seg_3d(self, results):
......@@ -197,7 +197,7 @@ class LoadAnnotations3D(LoadAnnotations):
pts_semantic_mask_path, dtype=np.long)
results['pts_semantic_mask'] = pts_semantic_mask
results['pts_seg_fields'].append(results['pts_semantic_mask'])
results['pts_seg_fields'].append('pts_semantic_mask')
return results
def __call__(self, results):
......
import warnings
from copy import deepcopy
import mmcv
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import Compose
@PIPELINES.register_module()
class MultiScaleFlipAug3D(object):
"""Test-time augmentation with multiple scales and flipping
Args:
transforms (list[dict]): Transforms to apply in each augmentation.
img_scale (tuple | list[tuple]): Image scales for resizing.
pts_scale_ratio (float | list[float]): Point scale ratios for
resizing.
flip (bool): Whether to apply flip augmentation. Default: False.
flip_direction (str | list[str]): Flip augmentation directions,
options are "horizontal" and "vertical". If flip_direction is a list,
multiple flip augmentations will be applied.
It has no effect when flip == False. Default: "horizontal".
"""
def __init__(self,
transforms,
img_scale,
pts_scale_ratio,
flip=False,
flip_direction='horizontal'):
self.transforms = Compose(transforms)
self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale]
self.pts_scale_ratio = pts_scale_ratio \
if isinstance(pts_scale_ratio, list) else [float(pts_scale_ratio)]
assert mmcv.is_list_of(self.img_scale, tuple)
assert mmcv.is_list_of(self.pts_scale_ratio, float)
self.flip = flip
self.flip_direction = flip_direction if isinstance(
flip_direction, list) else [flip_direction]
assert mmcv.is_list_of(self.flip_direction, str)
if not self.flip and self.flip_direction != ['horizontal']:
warnings.warn(
'flip_direction has no effect when flip is set to False')
if (self.flip
and not any([t['type'] == 'RandomFlip' for t in transforms])):
warnings.warn(
'flip has no effect when RandomFlip is not in transforms')
def __call__(self, results):
aug_data = []
flip_aug = [False, True] if self.flip else [False]
for scale in self.img_scale:
for pts_scale_ratio in self.pts_scale_ratio:
for flip in flip_aug:
for direction in self.flip_direction:
# results.copy() would cause a bug since it is a shallow copy
_results = deepcopy(results)
_results['scale'] = scale
_results['flip'] = flip
_results['pcd_scale_factor'] = pts_scale_ratio
_results['flip_direction'] = direction
data = self.transforms(_results)
aug_data.append(data)
# list of dict to dict of list
aug_data_dict = {key: [] for key in aug_data[0]}
for data in aug_data:
for key, val in data.items():
aug_data_dict[key].append(val)
return aug_data_dict
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms}, '
repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
repr_str += f'pts_scale_ratio={self.pts_scale_ratio}, '
repr_str += f'flip_direction={self.flip_direction})'
return repr_str
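For context, this wrapper is meant to sit at the top of a test pipeline in an mmdetection-style config and fan each sample out over the scale/flip grid before collection. The snippet below is an illustrative sketch, not part of this merge request; the nested transforms and values are assumptions.

test_pipeline = [
    dict(type='LoadPointsFromFile'),  # loading arguments omitted here
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),        # a single tuple is wrapped into a list internally
        pts_scale_ratio=1.0,          # likewise wrapped into [1.0]
        flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1.0, 1.0],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict(type='Collect3D', keys=['points'])
        ])
]

Because `__call__` injects `scale`, `pcd_scale_factor`, `flip` and `flip_direction` into each copy of the results, `GlobalRotScaleTrans` reuses the injected scale factor instead of sampling a new one, and the flip decision follows the injected value when `sync_2d` is enabled.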
import mmcv
import numpy as np
from mmcv.utils import build_from_cfg
......@@ -18,6 +17,10 @@ class RandomFlip3D(RandomFlip):
method.
Args:
sync_2d (bool, optional): Whether to apply flip according to the 2D
images. If True, it will apply the same flip as that to 2D images.
If False, it will decide whether to flip randomly and independently
to that of 2D images.
flip_ratio (float, optional): The flipping probability.
"""
......@@ -25,61 +28,23 @@ class RandomFlip3D(RandomFlip):
super(RandomFlip3D, self).__init__(**kwargs)
self.sync_2d = sync_2d
def random_flip_points(self, gt_bboxes_3d, points):
gt_bboxes_3d.flip()
points[:, 1] = -points[:, 1]
return gt_bboxes_3d, points
def random_flip_data_3d(self, input_dict):
input_dict['points'][:, 1] = -input_dict['points'][:, 1]
for key in input_dict['bbox3d_fields']:
input_dict[key].flip()
def __call__(self, input_dict):
# filp 2D image and its annotations
if 'flip' not in input_dict:
flip = True if np.random.rand() < self.flip_ratio else False
input_dict['flip'] = flip
if 'flip_direction' not in input_dict:
input_dict['flip_direction'] = self.direction
if input_dict['flip']:
# flip image
if 'img' in input_dict:
if isinstance(input_dict['img'], list):
input_dict['img'] = [
mmcv.imflip(
img, direction=input_dict['flip_direction'])
for img in input_dict['img']
]
else:
input_dict['img'] = mmcv.imflip(
input_dict['img'],
direction=input_dict['flip_direction'])
# flip bboxes
for key in input_dict.get('bbox_fields', []):
input_dict[key] = self.bbox_flip(input_dict[key],
input_dict['img_shape'],
input_dict['flip_direction'])
# flip masks
for key in input_dict.get('mask_fields', []):
input_dict[key] = [
mmcv.imflip(mask, direction=input_dict['flip_direction'])
for mask in input_dict[key]
]
# flip segs
for key in input_dict.get('seg_fields', []):
input_dict[key] = mmcv.imflip(
input_dict[key], direction=input_dict['flip_direction'])
super(RandomFlip3D, self).__call__(input_dict)
if self.sync_2d:
input_dict['pcd_flip'] = input_dict['flip']
else:
flip = True if np.random.rand() < self.flip_ratio else False
input_dict['pcd_flip'] = flip
if input_dict['pcd_flip']:
# flip image
gt_bboxes_3d = input_dict['gt_bboxes_3d']
points = input_dict['points']
gt_bboxes_3d, points = self.random_flip_points(
gt_bboxes_3d, points)
input_dict['gt_bboxes_3d'] = gt_bboxes_3d
input_dict['points'] = points
self.random_flip_data_3d(input_dict)
return input_dict
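The `sync_2d` flag only decides how the point-cloud flip decision is obtained; a hedged config sketch (the flip_ratio value is illustrative):

dict(type='RandomFlip3D', sync_2d=True, flip_ratio=0.5)
# sync_2d=True : pcd_flip copies the 2D image flip decision
# sync_2d=False: pcd_flip is drawn independently with probability flip_ratio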
def __repr__(self):
......@@ -89,6 +54,13 @@ class RandomFlip3D(RandomFlip):
@PIPELINES.register_module()
class ObjectSample(object):
"""Sample GT objects to the data
Args:
db_sampler (dict): Config dict of the database sampler.
sample_2d (bool): Whether to also paste the 2D image patch onto the images.
This should be True when applying multi-modality cut-and-paste.
"""
def __init__(self, db_sampler, sample_2d=False):
self.sampler_cfg = db_sampler
......@@ -109,9 +81,6 @@ class ObjectSample(object):
# change to float for blending operation
points = input_dict['points']
# rect = input_dict['rect']
# Trv2c = input_dict['Trv2c']
# P2 = input_dict['P2']
if self.sample_2d:
img = input_dict['img']
gt_bboxes_2d = input_dict['gt_bboxes']
......@@ -162,15 +131,28 @@ class ObjectSample(object):
@PIPELINES.register_module()
class ObjectNoise(object):
"""Apply noise to each GT objects in the scene
Args:
translation_std (list, optional): Standard deviation of the
distribution from which the translation noise is sampled.
Defaults to [0.25, 0.25, 0.25].
global_rot_range (list, optional): Global rotation to the scene.
Defaults to [0.0, 0.0].
rot_range (list, optional): Object rotation range.
Defaults to [-0.15707963267, 0.15707963267].
num_try (int, optional): Number of times to try if the noise applied is
invalid. Defaults to 100.
"""
def __init__(self,
loc_noise_std=[0.25, 0.25, 0.25],
translation_std=[0.25, 0.25, 0.25],
global_rot_range=[0.0, 0.0],
rot_uniform_noise=[-0.15707963267, 0.15707963267],
rot_range=[-0.15707963267, 0.15707963267],
num_try=100):
self.loc_noise_std = loc_noise_std
self.translation_std = translation_std
self.global_rot_range = global_rot_range
self.rot_uniform_noise = rot_uniform_noise
self.rot_range = rot_range
self.num_try = num_try
def __call__(self, input_dict):
......@@ -182,8 +164,8 @@ class ObjectNoise(object):
noise_per_object_v3_(
numpy_box,
points,
rotation_perturb=self.rot_uniform_noise,
center_noise_std=self.loc_noise_std,
rotation_perturb=self.rot_range,
center_noise_std=self.translation_std,
global_random_rot_range=self.global_rot_range,
num_try=self.num_try)
......@@ -194,73 +176,92 @@ class ObjectNoise(object):
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(num_try={},'.format(self.num_try)
repr_str += ' loc_noise_std={},'.format(self.loc_noise_std)
repr_str += ' translation_std={},'.format(self.translation_std)
repr_str += ' global_rot_range={},'.format(self.global_rot_range)
repr_str += ' rot_uniform_noise={})'.format(self.rot_uniform_noise)
repr_str += ' rot_range={})'.format(self.rot_range)
return repr_str
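Existing configs written against the old argument names have to be migrated; an illustrative before/after using the defaults listed above:

# before this merge request
dict(type='ObjectNoise',
     loc_noise_std=[0.25, 0.25, 0.25],
     global_rot_range=[0.0, 0.0],
     rot_uniform_noise=[-0.15707963267, 0.15707963267],
     num_try=100)
# after this merge request
dict(type='ObjectNoise',
     translation_std=[0.25, 0.25, 0.25],
     global_rot_range=[0.0, 0.0],
     rot_range=[-0.15707963267, 0.15707963267],
     num_try=100)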
@PIPELINES.register_module()
class GlobalRotScale(object):
class GlobalRotScaleTrans(object):
"""Apply global rotation, scaling and translation to a 3D scene
Args:
rot_range (list[float]): Range of rotation angle.
Defaults to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]).
scale_ratio_range (list[float]): Range of scale ratio.
Defaults to [0.95, 1.05].
translation_std (list[float]): The standard deviation of translation
noise. This applies random translation to the scene with a noise
sampled from a Gaussian distribution whose standard deviation
is set by ``translation_std``. Defaults to [0, 0, 0].
"""
def __init__(self,
rot_uniform_noise=[-0.78539816, 0.78539816],
scaling_uniform_noise=[0.95, 1.05],
trans_normal_noise=[0, 0, 0]):
self.rot_uniform_noise = rot_uniform_noise
self.scaling_uniform_noise = scaling_uniform_noise
self.trans_normal_noise = trans_normal_noise
def _trans_bbox_points(self, gt_boxes, points):
noise_trans = np.random.normal(0, self.trans_normal_noise[0], 3).T
points[:, :3] += noise_trans
gt_boxes.translate(noise_trans)
return gt_boxes, points, noise_trans
def _rot_bbox_points(self, gt_boxes, points, rotation=np.pi / 4):
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0, 0, 0]):
self.rot_range = rot_range
self.scale_ratio_range = scale_ratio_range
self.translation_std = translation_std
def _trans_bbox_points(self, input_dict):
if not isinstance(self.translation_std, (list, tuple, np.ndarray)):
translation_std = [
self.translation_std, self.translation_std,
self.translation_std
]
else:
translation_std = self.translation_std
translation_std = np.array(translation_std, dtype=np.float32)
trans_factor = np.random.normal(scale=translation_std, size=3).T
input_dict['points'][:, :3] += trans_factor
input_dict['pcd_trans'] = trans_factor
for key in input_dict['bbox3d_fields']:
input_dict[key].translate(trans_factor)
def _rot_bbox_points(self, input_dict):
rotation = self.rot_range
if not isinstance(rotation, list):
rotation = [-rotation, rotation]
noise_rotation = np.random.uniform(rotation[0], rotation[1])
points = input_dict['points']
points[:, :3], rot_mat_T = box_np_ops.rotation_points_single_angle(
points[:, :3], noise_rotation, axis=2)
gt_boxes.rotate(noise_rotation)
input_dict['points'] = points
input_dict['pcd_rotation'] = rot_mat_T
for key in input_dict['bbox3d_fields']:
input_dict[key].rotate(noise_rotation)
return gt_boxes, points, rot_mat_T
def _scale_bbox_points(self, input_dict):
scale = input_dict['pcd_scale_factor']
input_dict['points'][:, :3] *= scale
for key in input_dict['bbox3d_fields']:
input_dict[key].scale(scale)
def _scale_bbox_points(self,
gt_boxes,
points,
min_scale=0.95,
max_scale=1.05):
noise_scale = np.random.uniform(min_scale, max_scale)
points[:, :3] *= noise_scale
gt_boxes.scale(noise_scale)
return gt_boxes, points, noise_scale
def _random_scale(self, input_dict):
scale_factor = np.random.uniform(self.scale_ratio_range[0],
self.scale_ratio_range[1])
input_dict['pcd_scale_factor'] = scale_factor
def __call__(self, input_dict):
gt_bboxes_3d = input_dict['gt_bboxes_3d']
points = input_dict['points']
self._rot_bbox_points(input_dict)
gt_bboxes_3d, points, rotation_factor = self._rot_bbox_points(
gt_bboxes_3d, points, rotation=self.rot_uniform_noise)
gt_bboxes_3d, points, scale_factor = self._scale_bbox_points(
gt_bboxes_3d, points, *self.scaling_uniform_noise)
gt_bboxes_3d, points, trans_factor = self._trans_bbox_points(
gt_bboxes_3d, points)
if 'pcd_scale_factor' not in input_dict:
self._random_scale(input_dict)
self._scale_bbox_points(input_dict)
input_dict['gt_bboxes_3d'] = gt_bboxes_3d
input_dict['points'] = points
input_dict['pcd_scale_factor'] = scale_factor
input_dict['pcd_rotation'] = rotation_factor
input_dict['pcd_trans'] = trans_factor
self._trans_bbox_points(input_dict)
return input_dict
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(rot_uniform_noise={},'.format(self.rot_uniform_noise)
repr_str += ' scaling_uniform_noise={},'.format(
self.scaling_uniform_noise)
repr_str += ' trans_normal_noise={})'.format(self.trans_normal_noise)
repr_str += '(rot_range={},'.format(self.rot_range)
repr_str += ' scale_ratio_range={},'.format(self.scale_ratio_range)
repr_str += ' translation_std={})'.format(self.translation_std)
return repr_str
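The same kind of migration applies here, including the registered type name; an illustrative config update with the defaults listed above:

# before: dict(type='GlobalRotScale',
#              rot_uniform_noise=[-0.78539816, 0.78539816],
#              scaling_uniform_noise=[0.95, 1.05],
#              trans_normal_noise=[0, 0, 0])
dict(type='GlobalRotScaleTrans',
     rot_range=[-0.78539816, 0.78539816],
     scale_ratio_range=[0.95, 1.05],
     translation_std=[0, 0, 0])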
......
......@@ -181,7 +181,7 @@ class VoteHead(nn.Module):
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
input_meta=None,
img_metas=None,
gt_bboxes_ignore=None):
"""Compute loss.
......@@ -193,7 +193,7 @@ class VoteHead(nn.Module):
gt_labels_3d (list[Tensor]): Gt labels of each sample.
pts_semantic_mask (None | list[Tensor]): Point-wise semantic mask.
pts_instance_mask (None | list[Tensor]): Point-wise instance mask.
input_metas (list[dict]): Contain pcd and img's meta info.
img_metas (list[dict]): Contain pcd and img's meta info.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding boxes to ignore.
Returns:
......
from .base import BaseDetector
from .mvx_faster_rcnn import (DynamicMVXFasterRCNN, DynamicMVXFasterRCNNV2,
DynamicMVXFasterRCNNV3)
from .base import Base3DDetector
from .dynamic_voxelnet import DynamicVoxelNet
from .mvx_faster_rcnn import DynamicMVXFasterRCNN, DynamicMVXFasterRCNNV2
from .mvx_single_stage import MVXSingleStageDetector
from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2
from .votenet import VoteNet
from .voxelnet import DynamicVoxelNet, VoxelNet
from .voxelnet import VoxelNet
__all__ = [
'BaseDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXSingleStageDetector',
'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXSingleStageDetector',
'MVXTwoStageDetector', 'DynamicMVXFasterRCNN', 'DynamicMVXFasterRCNNV2',
'DynamicMVXFasterRCNNV3', 'PartA2', 'VoteNet'
'PartA2', 'VoteNet'
]
......@@ -3,27 +3,17 @@ from abc import ABCMeta, abstractmethod
import torch.nn as nn
class BaseDetector(nn.Module, metaclass=ABCMeta):
class Base3DDetector(nn.Module, metaclass=ABCMeta):
"""Base class for detectors"""
def __init__(self):
super(BaseDetector, self).__init__()
super(Base3DDetector, self).__init__()
self.fp16_enabled = False
@property
def with_neck(self):
return hasattr(self, 'neck') and self.neck is not None
@property
def with_voxel_encoder(self):
return hasattr(self,
'voxel_encoder') and self.voxel_encoder is not None
@property
def with_middle_encoder(self):
return hasattr(self,
'middle_encoder') and self.middle_encoder is not None
@property
def with_shared_head(self):
return hasattr(self, 'shared_head') and self.shared_head is not None
......@@ -63,48 +53,50 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
logger = get_root_logger()
logger.info('load model from: {}'.format(pretrained))
def forward_test(self, imgs, img_metas, **kwargs):
def forward_test(self, points, img_metas, imgs=None, **kwargs):
"""
Args:
imgs (List[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch.
img_meta (List[List[dict]]): the outer list indicates test-time
points (List[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxC,
which contains all points in the batch.
img_metas (List[List[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch
imgs (List[Tensor], optional): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch. Defaults to None.
"""
for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(imgs)
num_augs = len(points)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(imgs), len(img_metas)))
len(points), len(img_metas)))
# TODO: remove the restriction of imgs_per_gpu == 1 when prepared
imgs_per_gpu = imgs[0].size(0)
assert imgs_per_gpu == 1
samples_per_gpu = len(points[0])
assert samples_per_gpu == 1
if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], **kwargs)
imgs = [imgs] if imgs is None else imgs
return self.simple_test(points[0], img_metas[0], imgs[0], **kwargs)
else:
return self.aug_test(imgs, img_metas, **kwargs)
return self.aug_test(points, img_metas, imgs, **kwargs)
def forward(self, img, img_meta, return_loss=True, **kwargs):
def forward(self, return_loss=True, **kwargs):
"""
Calls either forward_train or forward_test depending on whether
return_loss=True. Note this setting will change the expected inputs.
When `return_loss=True`, img and img_meta are single-nested (i.e.
Tensor and List[dict]), and when `resturn_loss=False`, img and img_meta
should be double nested (i.e. List[Tensor], List[List[dict]]), with
the outer list indicating test time augmentations.
When `return_loss=True`, img and img_metas are single-nested (i.e.
Tensor and List[dict]), and when `return_loss=False`, img and
img_metas should be double nested
(i.e. List[Tensor], List[List[dict]]), with the outer list
indicating test time augmentations.
"""
# TODO: current version only support 2D detector now, find
# a better way to be compatible with both
if return_loss:
return self.forward_train(img, img_meta, **kwargs)
return self.forward_train(**kwargs)
else:
return self.forward_test(img, img_meta, **kwargs)
return self.forward_test(**kwargs)
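As a concrete illustration of the nesting described in the docstring above (container contents are placeholders; exact tensor shapes depend on the dataset):

# return_loss=False with two test-time augmentations and one sample per GPU:
# points    = [points_aug1, points_aug2]        # one entry per augmentation
# img_metas = [[meta_aug1], [meta_aug2]]        # outer list: augs, inner list: samples
# len(points) == len(img_metas) == num_augs; one aug dispatches to simple_test,
# several augs dispatch to aug_test.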
import torch
import torch.nn.functional as F
from mmdet.models import DETECTORS
from .voxelnet import VoxelNet
@DETECTORS.register_module()
class DynamicVoxelNet(VoxelNet):
def __init__(self,
voxel_layer,
voxel_encoder,
middle_encoder,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(DynamicVoxelNet, self).__init__(
voxel_layer=voxel_layer,
voxel_encoder=voxel_encoder,
middle_encoder=middle_encoder,
backbone=backbone,
neck=neck,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
)
def extract_feat(self, points, img_metas):
voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.voxel_encoder(voxels, coors)
batch_size = coors[-1, 0].item() + 1
x = self.middle_encoder(voxel_features, feature_coors, batch_size)
x = self.backbone(x)
if self.with_neck:
x = self.neck(x)
return x
@torch.no_grad()
def voxelize(self, points):
coors = []
# dynamic voxelization only provides a coors mapping
for res in points:
res_coors = self.voxel_layer(res)
coors.append(res_coors)
points = torch.cat(points, dim=0)
coors_batch = []
for i, coor in enumerate(coors):
coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
coors_batch.append(coor_pad)
coors_batch = torch.cat(coors_batch, dim=0)
return points, coors_batch
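The `F.pad(coor, (1, 0), ...)` call prepends the sample index as a new leading column, so points from different samples can share one flat coordinate tensor. A minimal self-contained sketch of that step:

import torch
import torch.nn.functional as F

coor = torch.tensor([[4, 7, 2], [5, 7, 3]])                # voxel coords of two points
coor_pad = F.pad(coor, (1, 0), mode='constant', value=1)   # prepend batch index 1
# coor_pad: tensor([[1, 4, 7, 2],
#                   [1, 5, 7, 3]])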
......@@ -11,13 +11,14 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector):
def __init__(self, **kwargs):
super(DynamicMVXFasterRCNN, self).__init__(**kwargs)
def extract_pts_feat(self, points, img_feats, img_meta):
def extract_pts_feat(self, points, img_feats, img_metas):
if not self.with_pts_bbox:
return None
voxels, coors = self.voxelize(points)
# adopt an early fusion strategy
if self.with_fusion:
voxels = self.pts_fusion_layer(img_feats, points, voxels, img_meta)
voxels = self.pts_fusion_layer(img_feats, points, voxels,
img_metas)
voxel_features, feature_coors = self.pts_voxel_encoder(voxels, coors)
batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size)
......@@ -48,12 +49,12 @@ class DynamicMVXFasterRCNNV2(DynamicMVXFasterRCNN):
def __init__(self, **kwargs):
super(DynamicMVXFasterRCNNV2, self).__init__(**kwargs)
def extract_pts_feat(self, points, img_feats, img_meta):
def extract_pts_feat(self, points, img_feats, img_metas):
if not self.with_pts_bbox:
return None
voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.pts_voxel_encoder(
voxels, coors, points, img_feats, img_meta)
voxels, coors, points, img_feats, img_metas)
batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size)
x = self.pts_backbone(x)
......@@ -68,12 +69,12 @@ class MVXFasterRCNNV2(MVXTwoStageDetector):
def __init__(self, **kwargs):
super(MVXFasterRCNNV2, self).__init__(**kwargs)
def extract_pts_feat(self, pts, img_feats, img_meta):
def extract_pts_feat(self, pts, img_feats, img_metas):
if not self.with_pts_bbox:
return None
voxels, num_points, coors = self.voxelize(pts)
voxel_features = self.pts_voxel_encoder(voxels, num_points, coors,
img_feats, img_meta)
img_feats, img_metas)
batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, coors, batch_size)
......@@ -82,22 +83,3 @@ class MVXFasterRCNNV2(MVXTwoStageDetector):
if self.with_pts_neck:
x = self.pts_neck(x)
return x
@DETECTORS.register_module()
class DynamicMVXFasterRCNNV3(DynamicMVXFasterRCNN):
def __init__(self, **kwargs):
super(DynamicMVXFasterRCNNV3, self).__init__(**kwargs)
def extract_pts_feat(self, points, img_feats, img_meta):
if not self.with_pts_bbox:
return None
voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.pts_voxel_encoder(voxels, coors)
batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size)
x = self.pts_backbone(x)
if self.with_pts_neck:
x = self.pts_neck(x, coors, points, img_feats, img_meta)
return x
......@@ -6,11 +6,11 @@ from mmdet3d.core import bbox3d2result
from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS
from .. import builder
from .base import BaseDetector
from .single_stage import SingleStage3DDetector
@DETECTORS.register_module()
class MVXSingleStageDetector(BaseDetector):
class MVXSingleStageDetector(SingleStage3DDetector):
def __init__(self,
voxel_layer,
......@@ -92,7 +92,7 @@ class MVXSingleStageDetector(BaseDetector):
def with_pts_neck(self):
return hasattr(self, 'pts_neck') and self.pts_neck is not None
def extract_feat(self, points, img, img_meta):
def extract_feat(self, points, img, img_metas):
if self.with_img_backbone:
img_feats = self.img_backbone(img)
if self.with_img_neck:
......@@ -126,37 +126,28 @@ class MVXSingleStageDetector(BaseDetector):
def forward_train(self,
points,
img_meta,
img_metas,
gt_bboxes_3d,
gt_labels,
img=None,
gt_bboxes_ignore=None):
x = self.extract_feat(points, img=img, img_meta=img_meta)
x = self.extract_feat(points, img=img, img_metas=img_metas)
outs = self.pts_bbox_head(x)
loss_inputs = outs + (gt_bboxes_3d, gt_labels, img_meta)
loss_inputs = outs + (gt_bboxes_3d, gt_labels, img_metas)
losses = self.pts_bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
def forward_test(self, **kwargs):
return self.simple_test(**kwargs)
def forward(self, return_loss=True, **kwargs):
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def simple_test(self,
points,
img_meta,
img_metas,
img=None,
gt_bboxes_3d=None,
rescale=False):
x = self.extract_feat(points, img, img_meta)
x = self.extract_feat(points, img, img_metas)
outs = self.pts_bbox_head(x)
bbox_list = self.pts_bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale)
*outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
......@@ -200,7 +191,7 @@ class DynamicMVXNet(MVXSingleStageDetector):
pretrained=pretrained,
)
def extract_feat(self, points, img, img_meta):
def extract_feat(self, points, img, img_metas):
if self.with_img_backbone:
img_feats = self.img_backbone(img)
if self.with_img_neck:
......@@ -209,7 +200,7 @@ class DynamicMVXNet(MVXSingleStageDetector):
voxels, coors = self.voxelize(points)
# adopt an early fusion strategy
if self.with_fusion:
voxels = self.fusion_layer(img_feats, points, voxels, img_meta)
voxels = self.fusion_layer(img_feats, points, voxels, img_metas)
voxel_features, feature_coors = self.voxel_encoder(voxels, coors)
batch_size = coors[-1, 0] + 1
......@@ -268,7 +259,7 @@ class DynamicMVXNetV2(DynamicMVXNet):
pretrained=pretrained,
)
def extract_feat(self, points, img, img_meta):
def extract_feat(self, points, img, img_metas):
if self.with_img_backbone:
img_feats = self.img_backbone(img)
if self.with_img_neck:
......@@ -277,7 +268,7 @@ class DynamicMVXNetV2(DynamicMVXNet):
voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.voxel_encoder(
voxels, coors, points, img_feats, img_meta)
voxels, coors, points, img_feats, img_metas)
batch_size = coors[-1, 0] + 1
x = self.middle_encoder(voxel_features, feature_coors, batch_size)
x = self.pts_backbone(x)
......@@ -319,7 +310,7 @@ class DynamicMVXNetV3(DynamicMVXNet):
pretrained=pretrained,
)
def extract_feat(self, points, img, img_meta):
def extract_feat(self, points, img, img_metas):
if self.with_img_backbone:
img_feats = self.img_backbone(img)
if self.with_img_neck:
......@@ -331,5 +322,5 @@ class DynamicMVXNetV3(DynamicMVXNet):
x = self.middle_encoder(voxel_features, feature_coors, batch_size)
x = self.pts_backbone(x)
if self.with_pts_neck:
x = self.pts_neck(x, coors, points, img_feats, img_meta)
x = self.pts_neck(x, coors, points, img_feats, img_metas)
return x
......@@ -2,15 +2,15 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet3d.core import bbox3d2result
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d, multi_apply
from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS
from .. import builder
from .base import BaseDetector
from .base import Base3DDetector
@DETECTORS.register_module()
class MVXTwoStageDetector(BaseDetector):
class MVXTwoStageDetector(Base3DDetector):
def __init__(self,
pts_voxel_layer=None,
......@@ -137,7 +137,17 @@ class MVXTwoStageDetector(BaseDetector):
def with_img_roi_head(self):
return hasattr(self, 'img_roi_head') and self.img_roi_head is not None
def extract_img_feat(self, img, img_meta):
@property
def with_voxel_encoder(self):
return hasattr(self,
'voxel_encoder') and self.voxel_encoder is not None
@property
def with_middle_encoder(self):
return hasattr(self,
'middle_encoder') and self.middle_encoder is not None
def extract_img_feat(self, img, img_metas):
if self.with_img_backbone:
if img.dim() == 5 and img.size(0) == 1:
img.squeeze_()
......@@ -151,7 +161,7 @@ class MVXTwoStageDetector(BaseDetector):
img_feats = self.img_neck(img_feats)
return img_feats
def extract_pts_feat(self, pts, img_feats, img_meta):
def extract_pts_feat(self, pts, img_feats, img_metas):
if not self.with_pts_bbox:
return None
voxels, num_points, coors = self.voxelize(pts)
......@@ -163,9 +173,9 @@ class MVXTwoStageDetector(BaseDetector):
x = self.pts_neck(x)
return x
def extract_feat(self, points, img, img_meta):
img_feats = self.extract_img_feat(img, img_meta)
pts_feats = self.extract_pts_feat(points, img_feats, img_meta)
def extract_feat(self, points, img, img_metas):
img_feats = self.extract_img_feat(img, img_metas)
pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
return (img_feats, pts_feats)
@torch.no_grad()
......@@ -187,30 +197,30 @@ class MVXTwoStageDetector(BaseDetector):
def forward_train(self,
points=None,
img_meta=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img=None,
proposals=None,
bboxes=None,
gt_bboxes_ignore=None):
img_feats, pts_feats = self.extract_feat(
points, img=img, img_meta=img_meta)
points, img=img, img_metas=img_metas)
losses = dict()
if pts_feats:
losses_pts = self.forward_pts_train(pts_feats, gt_bboxes_3d,
gt_labels_3d, img_meta,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_pts)
if img_feats:
losses_img = self.forward_img_train(
img_feats,
img_meta=img_meta,
img_metas=img_metas,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
gt_bboxes_ignore=gt_bboxes_ignore,
proposals=proposals,
bboxes=bboxes,
)
losses.update(losses_img)
return losses
......@@ -219,17 +229,17 @@ class MVXTwoStageDetector(BaseDetector):
pts_feats,
gt_bboxes_3d,
gt_labels_3d,
img_meta,
img_metas,
gt_bboxes_ignore=None):
outs = self.pts_bbox_head(pts_feats)
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_meta)
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
losses = self.pts_bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
def forward_img_train(self,
x,
img_meta,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
......@@ -239,7 +249,7 @@ class MVXTwoStageDetector(BaseDetector):
# RPN forward and loss
if self.with_img_rpn:
rpn_outs = self.img_rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas,
self.train_cfg.img_rpn)
rpn_losses = self.img_rpn_head.loss(
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
......@@ -247,13 +257,13 @@ class MVXTwoStageDetector(BaseDetector):
proposal_cfg = self.train_cfg.get('img_rpn_proposal',
self.test_cfg.img_rpn)
proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
else:
proposal_list = proposals
# bbox head forward and loss
img_roi_losses = self.roi_head.forward_train(x, img_meta,
img_roi_losses = self.roi_head.forward_train(x, img_metas,
proposal_list, gt_bboxes,
gt_labels,
gt_bboxes_ignore,
......@@ -262,61 +272,78 @@ class MVXTwoStageDetector(BaseDetector):
losses.update(img_roi_losses)
return losses
def forward_test(self, **kwargs):
return self.simple_test(**kwargs)
def forward(self, return_loss=True, **kwargs):
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def simple_test_img(self, x, img_meta, proposals=None, rescale=False):
def simple_test_img(self, x, img_metas, proposals=None, rescale=False):
"""Test without augmentation."""
if proposals is None:
proposal_list = self.simple_test_rpn(x, img_meta,
proposal_list = self.simple_test_rpn(x, img_metas,
self.test_cfg.img_rpn)
else:
proposal_list = proposals
return self.img_roi_head.simple_test(
x, proposal_list, img_meta, rescale=rescale)
x, proposal_list, img_metas, rescale=rescale)
def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
def simple_test_rpn(self, x, img_metas, rpn_test_cfg):
rpn_outs = self.img_rpn_head(x)
proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
proposal_inputs = rpn_outs + (img_metas, rpn_test_cfg)
proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
return proposal_list
def simple_test_pts(self, x, img_meta, rescale=False):
def simple_test_pts(self, x, img_metas, rescale=False):
outs = self.pts_bbox_head(x)
bbox_list = self.pts_bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale)
*outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
def simple_test(self,
points,
img_meta,
img=None,
gt_bboxes_3d=None,
rescale=False):
def simple_test(self, points, img_metas, img=None, rescale=False):
img_feats, pts_feats = self.extract_feat(
points, img=img, img_meta=img_meta)
points, img=img, img_metas=img_metas)
bbox_list = dict()
if pts_feats and self.with_pts_bbox:
bbox_pts = self.simple_test_pts(
pts_feats, img_meta, rescale=rescale)
pts_feats, img_metas, rescale=rescale)
bbox_list.update(pts_bbox=bbox_pts)
if img_feats and self.with_img_bbox:
bbox_img = self.simple_test_img(
img_feats, img_meta, rescale=rescale)
img_feats, img_metas, rescale=rescale)
bbox_list.update(img_bbox=bbox_img)
return bbox_list
def aug_test(self, points, imgs, img_metas, rescale=False):
raise NotImplementedError
def aug_test(self, points, img_metas, imgs=None, rescale=False):
img_feats, pts_feats = self.extract_feats(points, img_metas, imgs)
bbox_list = dict()
if pts_feats and self.with_pts_bbox:
bbox_pts = self.aug_test_pts(pts_feats, img_metas, rescale)
bbox_list.update(pts_bbox=bbox_pts)
return bbox_list
def extract_feats(self, points, img_metas, imgs=None):
if imgs is None:
imgs = [None] * len(img_metas)
img_feats, pts_feats = multi_apply(self.extract_feat, points, imgs,
img_metas)
return img_feats, pts_feats
def aug_test_pts(self, feats, img_metas, rescale=False):
# only support aug_test for one sample
aug_bboxes = []
for x, img_meta in zip(feats, img_metas):
outs = self.pts_bbox_head(x)
bbox_list = self.pts_bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
]
aug_bboxes.append(bbox_list[0])
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.pts_bbox_head.test_cfg)
return merged_bboxes
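For reference, the per-augmentation results handed to `merge_aug_bboxes_3d` follow the dict layout built in the loop above; an illustrative sketch of the call (contents are placeholders):

# aug_bboxes = [
#     dict(boxes_3d=boxes_aug1, scores_3d=scores_aug1, labels_3d=labels_aug1),
#     dict(boxes_3d=boxes_aug2, scores_3d=scores_aug2, labels_3d=labels_aug2),
# ]
# merged = merge_aug_bboxes_3d(aug_bboxes, img_metas, self.pts_bbox_head.test_cfg)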
......@@ -33,7 +33,7 @@ class PartA2(TwoStageDetector):
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder)
def extract_feat(self, points, img_meta):
def extract_feat(self, points, img_metas):
voxel_dict = self.voxelize(points)
voxel_features = self.voxel_encoder(voxel_dict['voxels'],
voxel_dict['num_points'],
......@@ -79,39 +79,66 @@ class PartA2(TwoStageDetector):
def forward_train(self,
points,
img_meta,
img_metas,
gt_bboxes_3d,
gt_labels_3d,
gt_bboxes_ignore=None,
proposals=None):
feats_dict, voxels_dict = self.extract_feat(points, img_meta)
feats_dict, voxels_dict = self.extract_feat(points, img_metas)
losses = dict()
if self.with_rpn:
rpn_outs = self.rpn_head(feats_dict['neck_feats'])
rpn_loss_inputs = rpn_outs + (gt_bboxes_3d, gt_labels_3d, img_meta)
rpn_loss_inputs = rpn_outs + (gt_bboxes_3d, gt_labels_3d,
img_metas)
rpn_losses = self.rpn_head.loss(
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
losses.update(rpn_losses)
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
else:
proposal_list = proposals
roi_losses = self.roi_head.forward_train(feats_dict, voxels_dict,
img_meta, proposal_list,
img_metas, proposal_list,
gt_bboxes_3d, gt_labels_3d)
losses.update(roi_losses)
return losses
def forward_test(self, **kwargs):
return self.simple_test(**kwargs)
def forward_test(self, points, img_metas, imgs=None, **kwargs):
"""
Args:
points (List[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxC,
which contains all points in the batch.
img_metas (List[List[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch
"""
for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(points)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(points), len(img_metas)))
# TODO: remove the restriction of imgs_per_gpu == 1 when prepared
samples_per_gpu = len(points[0])
assert samples_per_gpu == 1
if num_augs == 1:
return self.simple_test(points[0], img_metas[0], **kwargs)
else:
return self.aug_test(points, img_metas, **kwargs)
def forward(self, return_loss=True, **kwargs):
if return_loss:
......@@ -119,16 +146,19 @@ class PartA2(TwoStageDetector):
else:
return self.forward_test(**kwargs)
def simple_test(self, points, img_meta, proposals=None, rescale=False):
feats_dict, voxels_dict = self.extract_feat(points, img_meta)
def simple_test(self, points, img_metas, proposals=None, rescale=False):
feats_dict, voxels_dict = self.extract_feat(points, img_metas)
if self.with_rpn:
rpn_outs = self.rpn_head(feats_dict['neck_feats'])
proposal_cfg = self.test_cfg.rpn
bbox_inputs = rpn_outs + (img_meta, proposal_cfg)
bbox_inputs = rpn_outs + (img_metas, proposal_cfg)
proposal_list = self.rpn_head.get_bboxes(*bbox_inputs)
else:
proposal_list = proposals
return self.roi_head.simple_test(feats_dict, voxels_dict, img_meta,
return self.roi_head.simple_test(feats_dict, voxels_dict, img_metas,
proposal_list)
def aug_test(self, **kwargs):
raise NotImplementedError
import torch.nn as nn
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from .base import Base3DDetector
@DETECTORS.register_module()
class SingleStage3DDetector(Base3DDetector):
def __init__(self,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(SingleStage3DDetector, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
super(SingleStage3DDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
if isinstance(self.neck, nn.Sequential):
for m in self.neck:
m.init_weights()
else:
self.neck.init_weights()
self.bbox_head.init_weights()
def extract_feat(self, points, img_metas=None):
"""Directly extract features from the backbone+neck
"""
x = self.backbone(points)
if self.with_neck:
x = self.neck(x)
return x
def extract_feats(self, points, img_metas):
return [
self.extract_feat(pts, img_meta)
for pts, img_meta in zip(points, img_metas)
]
import torch
from mmdet3d.core import bbox3d2result
from mmdet.models import DETECTORS, SingleStageDetector
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet.models import DETECTORS
from .single_stage import SingleStage3DDetector
@DETECTORS.register_module()
class VoteNet(SingleStageDetector):
class VoteNet(SingleStage3DDetector):
"""VoteNet model.
https://arxiv.org/pdf/1904.09664.pdf
......@@ -24,15 +25,9 @@ class VoteNet(SingleStageDetector):
test_cfg=test_cfg,
pretrained=pretrained)
def extract_feat(self, points):
x = self.backbone(points)
if self.with_neck:
x = self.neck(x)
return x
def forward_train(self,
points,
img_meta,
img_metas,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
......@@ -42,7 +37,7 @@ class VoteNet(SingleStageDetector):
Args:
points (list[Tensor]): Points of each batch.
img_meta (list): Image metas.
img_metas (list): Image metas.
gt_bboxes_3d (:obj:BaseInstance3DBoxes): gt bboxes of each batch.
gt_labels_3d (list[Tensor]): gt class labels of each batch.
pts_semantic_mask (None | list[Tensor]): point-wise semantic
......@@ -54,57 +49,57 @@ class VoteNet(SingleStageDetector):
Returns:
dict: Losses.
"""
points_cat = torch.stack(points) # tmp
points_cat = torch.stack(points)
x = self.extract_feat(points_cat)
bbox_preds = self.bbox_head(x, self.train_cfg.sample_mod)
loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_meta)
pts_instance_mask, img_metas)
losses = self.bbox_head.loss(
bbox_preds, *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
def forward_test(self, **kwargs):
return self.simple_test(**kwargs)
def forward(self, return_loss=True, **kwargs):
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def simple_test(self,
points,
img_meta,
gt_bboxes_3d=None,
gt_labels_3d=None,
pts_semantic_mask=None,
pts_instance_mask=None,
rescale=False):
def simple_test(self, points, img_metas, imgs=None, rescale=False):
"""Forward of testing.
Args:
points (list[Tensor]): Points of each sample.
img_meta (list): Image metas.
gt_bboxes_3d (:obj:BaseInstance3DBoxes): gt bboxes of each sample.
gt_labels_3d (list[Tensor]): gt class labels of each sample.
pts_semantic_mask (None | list[Tensor]): point-wise semantic
label of each sample.
pts_instance_mask (None | list[Tensor]): point-wise instance
label of each sample.
img_metas (list): Image metas.
rescale (bool): Whether to rescale results.
Returns:
list: Predicted 3d boxes.
"""
points_cat = torch.stack(points) # tmp
points_cat = torch.stack(points)
x = self.extract_feat(points_cat)
bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
bbox_list = self.bbox_head.get_bboxes(
points_cat, bbox_preds, img_meta, rescale=rescale)
points_cat, bbox_preds, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
def aug_test(self, points, img_metas, imgs=None, rescale=False):
points_cat = [torch.stack(pts) for pts in points]
feats = self.extract_feats(points_cat, img_metas)
# only support aug_test for one sample
aug_bboxes = []
for x, pts_cat, img_meta in zip(feats, points_cat, img_metas):
bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
bbox_list = self.bbox_head.get_bboxes(
pts_cat, bbox_preds, img_meta, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
]
aug_bboxes.append(bbox_list[0])
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
return merged_bboxes
import torch
import torch.nn.functional as F
from mmdet3d.core import bbox3d2result
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS, SingleStageDetector
from mmdet.models import DETECTORS
from .. import builder
from .single_stage import SingleStage3DDetector
@DETECTORS.register_module()
class VoxelNet(SingleStageDetector):
class VoxelNet(SingleStage3DDetector):
def __init__(self,
voxel_layer,
......@@ -32,7 +33,7 @@ class VoxelNet(SingleStageDetector):
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder)
def extract_feat(self, points, img_meta):
def extract_feat(self, points, img_metas):
voxels, num_points, coors = self.voxelize(points)
voxel_features = self.voxel_encoder(voxels, num_points, coors)
batch_size = coors[-1, 0].item() + 1
......@@ -61,83 +62,45 @@ class VoxelNet(SingleStageDetector):
def forward_train(self,
points,
img_meta,
img_metas,
gt_bboxes_3d,
gt_labels_3d,
gt_bboxes_ignore=None):
x = self.extract_feat(points, img_meta)
x = self.extract_feat(points, img_metas)
outs = self.bbox_head(x)
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_meta)
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
losses = self.bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
def forward_test(self, **kwargs):
return self.simple_test(**kwargs)
def forward(self, return_loss=True, **kwargs):
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def simple_test(self, points, img_meta, gt_bboxes_3d=None, rescale=False):
x = self.extract_feat(points, img_meta)
def simple_test(self, points, img_metas, imgs=None, rescale=False):
x = self.extract_feat(points, img_metas)
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(*outs, img_meta, rescale=rescale)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
def aug_test(self, points, img_metas, imgs=None, rescale=False):
feats = self.extract_feats(points, img_metas)
@DETECTORS.register_module()
class DynamicVoxelNet(VoxelNet):
# only support aug_test for one sample
aug_bboxes = []
for x, img_meta in zip(feats, img_metas):
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
]
aug_bboxes.append(bbox_list[0])
def __init__(self,
voxel_layer,
voxel_encoder,
middle_encoder,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(DynamicVoxelNet, self).__init__(
voxel_layer=voxel_layer,
voxel_encoder=voxel_encoder,
middle_encoder=middle_encoder,
backbone=backbone,
neck=neck,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
)
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
def extract_feat(self, points, img_meta):
voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.voxel_encoder(voxels, coors)
batch_size = coors[-1, 0].item() + 1
x = self.middle_encoder(voxel_features, feature_coors, batch_size)
x = self.backbone(x)
if self.with_neck:
x = self.neck(x)
return x
@torch.no_grad()
def voxelize(self, points):
coors = []
# dynamic voxelization only provide a coors mapping
for res in points:
res_coors = self.voxel_layer(res)
coors.append(res_coors)
points = torch.cat(points, dim=0)
coors_batch = []
for i, coor in enumerate(coors):
coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
coors_batch.append(coor_pad)
coors_batch = torch.cat(coors_batch, dim=0)
return points, coors_batch
return merged_bboxes
......@@ -23,18 +23,34 @@ def point_sample(
padding_mode='zeros',
align_corners=True,
):
"""sample image features using point coordinates
"""Obtain image features using points
Arguments:
Args:
img_features (Tensor): 1xCxHxW image features
points (Tensor): Nx3 point cloud coordinates
P (Tensor): 4x4 transformation matrix
scale_factor (Tensor): scale_factor of images
img_pad_shape (int, int): int tuple indicates the h & w after padding,
this is necessary to obtain features in feature map
img_shape (int, int): int tuple indicates the h & w before padding
points (Tensor): Nx3 point cloud in LiDAR coordinates
lidar2img_rt (Tensor): 4x4 transformation matrix
pcd_rotate_mat (Tensor): 3x3 rotation matrix of points
during augmentation
img_scale_factor (Tensor): (w_scale, h_scale)
img_crop_offset (Tensor): (w_offset, h_offset) offset used to crop
image during data augmentation
pcd_trans_factor (Tensor): Translation of points during augmentation
pcd_scale_factor (float): Scale factor of points during
data augmentation
pcd_flip (bool): Whether the points are flipped.
img_flip (bool): Whether the image is flipped.
img_pad_shape (tuple[int]): int tuple indicates the h & w after
padding, this is necessary to obtain features in feature map
img_shape (tuple[int]): int tuple indicates the h & w before padding
after scaling, this is necessary for flipping coordinates
return:
aligned (bool, optional): Whether to use bilinear interpolation when
sampling image features for each point. Defaults to True.
padding_mode (str, optional): Padding mode when padding values for
features of out-of-image points. Defaults to 'zeros'.
align_corners (bool, optional): Whether to align corners when
sampling image features for each point. Defaults to True.
Returns:
(Tensor): NxC image features sampled by point coordinates
"""
# aug order: flip -> trans -> scale -> rot
......@@ -97,7 +113,36 @@ def point_sample(
@FUSION_LAYERS.register_module()
class PointFusion(nn.Module):
"""Fuse image features from fused single scale features
"""Fuse image features from multi-scale features
Args:
img_channels (list[int] | int): Channels of image features.
It could be a list if the input is multi-scale image features.
pts_channels (int): Channels of point features
mid_channels (int): Channels of middle layers
out_channels (int): Channels of output fused features
img_levels (int, optional): Number of image levels. Defaults to 3.
conv_cfg (dict, optional): Dict config of conv layers of middle
layers. Defaults to None.
norm_cfg (dict, optional): Dict config of norm layers of middle
layers. Defaults to None.
act_cfg (dict, optional): Dict config of activation layers.
Defaults to None.
activate_out (bool, optional): Whether to apply relu activation
to output features. Defaults to True.
fuse_out (bool, optional): Whether apply conv layer to the fused
features. Defaults to False.
dropout_ratio (int, float, optional): Dropout ratio of image
features to prevent overfitting. Defaults to 0.
aligned (bool, optional): Whether apply aligned feature fusion.
Defaults to True.
align_corners (bool, optional): Whether to align corner when
sampling features according to points. Defaults to True.
padding_mode (str, optional): Mode used to pad the features of
points that do not have corresponding image features.
Defaults to 'zeros'.
lateral_conv (bool, optional): Whether to apply lateral convs
to image features. Defaults to True.
"""
def __init__(self,
......@@ -179,15 +224,20 @@ class PointFusion(nn.Module):
if isinstance(m, (nn.Conv2d, nn.Linear)):
xavier_init(m, distribution='uniform')
def forward(self, img_feats, pts, pts_feats, img_meta):
"""
img_feats (List[Tensor]): img features
pts: [List[Tensor]]: a batch of points with shape Nx3
pts_feats (Tensor): a tensor consist of point features of the
total batch
def forward(self, img_feats, pts, pts_feats, img_metas):
"""Forward function
Args:
img_feats (list[Tensor]): img features
pts (list[Tensor]): a batch of points with shape Nx3
pts_feats (Tensor): a tensor consisting of point features of the
total batch
img_metas (list[dict]): meta information of images
Returns:
torch.Tensor: fused features of each point.
"""
img_pts = self.obtain_mlvl_feats(img_feats, pts, img_meta)
img_pts = self.obtain_mlvl_feats(img_feats, pts, img_metas)
img_pre_fuse = self.img_transform(img_pts)
if self.training and self.dropout_ratio > 0:
img_pre_fuse = F.dropout(img_pre_fuse, self.dropout_ratio)
......@@ -201,7 +251,7 @@ class PointFusion(nn.Module):
return fuse_out
def obtain_mlvl_feats(self, img_feats, pts, img_meta):
def obtain_mlvl_feats(self, img_feats, pts, img_metas):
if self.lateral_convs is not None:
img_ins = [
lateral_conv(img_feats[i])
......@@ -211,7 +261,7 @@ class PointFusion(nn.Module):
img_ins = img_feats
img_feats_per_point = []
# Sample multi-level features
for i in range(len(img_meta)):
for i in range(len(img_metas)):
mlvl_img_feats = []
for level in range(len(self.img_levels)):
if torch.isnan(img_ins[level][i:i + 1]).any():
......@@ -219,7 +269,7 @@ class PointFusion(nn.Module):
pdb.set_trace()
mlvl_img_feats.append(
self.sample_single(img_ins[level][i:i + 1], pts[i][:, :3],
img_meta[i]))
img_metas[i]))
mlvl_img_feats = torch.cat(mlvl_img_feats, dim=-1)
img_feats_per_point.append(mlvl_img_feats)
......
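Putting the documented shapes together, a hedged sketch of how the fusion layer would be invoked, assuming `point_fusion` is a constructed PointFusion instance (shapes are taken from the docstrings above):

# img_feats : list of multi-scale image feature maps
# pts       : list of per-sample point tensors, each Nx3 (plus any extra dims)
# pts_feats : point features of the whole batch, concatenated along the first dim
# img_metas : list of per-image meta dicts used to project points into the images
fused = point_fusion(img_feats, pts, pts_feats, img_metas)  # -> per-point fused features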