Unverified Commit 9073a3b5 authored by Tai-Wang, committed by GitHub

[Refactor] Support imvoxelnet at SUN RGB-D on 1.x branch (#2141)

* Support imvoxelnet@sunrgbd on 1.x branch

* Add unit tests

* Update README.md

* Update imvoxelnet_2xb4_sunrgbd-3d-10class.py

* Add typehints

* Fix lint

* Fix BC-breaking caused by updated keys

* Add coord_type in the imvoxelnet kitti config
parent bd1525ec
......@@ -26,6 +26,12 @@ Results for SUN RGB-D, ScanNet and nuScenes are currently available in ImVoxelNe
| :--------------------------------------------: | :---: | :-----: | :------: | :------------: | :---: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [ResNet-50](./imvoxelnet_8xb4_kitti-3d-car.py) | Car | 3x | | | 17.26 | [model](https://download.openmmlab.com/mmdetection3d/v1.0.0_models/imvoxelnet/imvoxelnet_4x8_kitti-3d-car/imvoxelnet_4x8_kitti-3d-car_20210830_003014-3d0ffdf4.pth) \| [log](https://download.openmmlab.com/mmdetection3d/v1.0.0_models/imvoxelnet/imvoxelnet_4x8_kitti-3d-car/imvoxelnet_4x8_kitti-3d-car_20210830_003014.log.json) |
### SUN RGB-D
| Backbone | Lr schd | Mem (GB) | Inf time (fps) | mAP@0.25 | mAP@0.5 | Download |
| :-------------------------------------------------: | :-----: | :------: | :------------: | :------: | :-----: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [ResNet-50](./imvoxelnet_2xb4_sunrgbd-3d-10class.py) | 2x | 7.2 | 22.5 | 40.96 | 13.50 | [model](https://download.openmmlab.com/mmdetection3d/v1.0.0_models/imvoxelnet/imvoxelnet_4x2_sunrgbd-3d-10class/imvoxelnet_4x2_sunrgbd-3d-10class_20220809_184416-29ca7d2e.pth) \| [log](https://download.openmmlab.com/mmdetection3d/v1.0.0_models/imvoxelnet/imvoxelnet_4x2_sunrgbd-3d-10class/imvoxelnet_4x2_sunrgbd-3d-10class_20220809_184416.log.json) |
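The detector defined by this config can be built directly with MMEngine for a quick sanity check, mirroring the unit tests added in this commit (the repository-relative config path below is an assumption about your working directory):

```python
# A minimal sketch, modeled on the unit tests added in this commit; the
# config path is an assumption about where the repo is checked out.
import mmdet3d.models  # noqa: F401, triggers model registration
from mmengine.config import Config
from mmengine.registry import DefaultScope

from mmdet3d.registry import MODELS

DefaultScope.get_instance('imvoxelnet_sunrgbd_demo', scope_name='mmdet3d')
cfg = Config.fromfile(
    'configs/imvoxelnet/imvoxelnet_2xb4_sunrgbd-3d-10class.py')
model = MODELS.build(cfg.model)
print(type(model).__name__)  # ImVoxelNet
```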
## Citation
```latex
......
_base_ = [
'../_base_/schedules/mmdet-schedule-1x.py', '../_base_/default_runtime.py'
]
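# The prior generator covers a 6.4 x 6.4 x 2.56 m volume in front of the
# camera (DEPTH coordinates); with n_voxels=[40, 40, 16] below this
# corresponds to a voxel size of 0.16 m.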
prior_generator = dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[-3.2, -0.2, -2.28, 3.2, 6.2, 0.28]],
rotations=[.0])
model = dict(
type='ImVoxelNet',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=32),
backbone=dict(
type='mmdet.ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
style='pytorch'),
neck=dict(
type='mmdet.FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
neck_3d=dict(
type='IndoorImVoxelNeck',
in_channels=256,
out_channels=128,
n_blocks=[1, 1, 1]),
bbox_head=dict(
type='ImVoxelHead',
n_classes=10,
n_levels=3,
n_channels=128,
n_reg_outs=7,
pts_assign_threshold=27,
pts_center_threshold=18,
prior_generator=prior_generator),
prior_generator=prior_generator,
n_voxels=[40, 40, 16],
coord_type='DEPTH',
train_cfg=dict(),
test_cfg=dict(nms_pre=1000, iou_thr=.25, score_thr=.01))
dataset_type = 'SUNRGBDDataset'
data_root = 'data/sunrgbd/'
class_names = [
'bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser',
'night_stand', 'bookshelf', 'bathtub'
]
metainfo = dict(CLASSES=class_names)
file_client_args = dict(backend='disk')
# Uncomment the following if using ceph or other file clients.
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
# for more details.
# file_client_args = dict(
# backend='petrel',
# path_mapping=dict({
# './data/sunrgbd/':
# 's3://openmmlab/datasets/detection3d/sunrgbd_processed/',
# 'data/sunrgbd/':
# 's3://openmmlab/datasets/detection3d/sunrgbd_processed/'
# }))
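# ImVoxelNet is a monocular detector: the pipelines only load images and 3D
# box annotations, no point clouds are required.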
train_pipeline = [
dict(type='LoadAnnotations3D'),
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='RandomResize', scale=[(512, 384), (768, 576)], keep_ratio=True),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(type='Pack3DDetInputs', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(640, 480), keep_ratio=True),
dict(type='Pack3DDetInputs', keys=['img'])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=2,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='sunrgbd_infos_train.pkl',
pipeline=train_pipeline,
test_mode=False,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo)))
val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='sunrgbd_infos_val.pkl',
pipeline=test_pipeline,
test_mode=True,
box_type_3d='Depth',
metainfo=metainfo))
test_dataloader = val_dataloader
val_evaluator = dict(
type='IndoorMetric',
ann_file=data_root + 'sunrgbd_infos_val.pkl',
metric='bbox')
test_evaluator = val_evaluator
# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}),
clip_grad=dict(max_norm=35., norm_type=2))
param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]
# hooks
default_hooks = dict(checkpoint=dict(type='CheckpointHook', max_keep_ckpts=1))
# runtime
find_unused_parameters = True # only 1 of 4 FPN outputs is used
......@@ -52,7 +52,8 @@ model = dict(
type='mmdet.CrossEntropyLoss', use_sigmoid=False,
loss_weight=0.2)),
n_voxels=[216, 248, 12],
anchor_generator=dict(
coord_type='LIDAR',
prior_generator=dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[-0.16, -39.68, -3.08, 68.96, 39.68, 0.76]],
rotations=[.0]),
......
......@@ -9,6 +9,7 @@ from .fcaf3d_head import FCAF3DHead
from .fcos_mono3d_head import FCOSMono3DHead
from .free_anchor3d_head import FreeAnchor3DHead
from .groupfree3d_head import GroupFree3DHead
from .imvoxel_head import ImVoxelHead
from .monoflex_head import MonoFlexHead
from .parta2_rpn_head import PartA2RPNHead
from .pgd_head import PGDHead
......@@ -23,5 +24,5 @@ __all__ = [
'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead',
'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead',
'GroupFree3DHead', 'PointRPNHead', 'SMOKEMono3DHead', 'PGDHead',
'MonoFlexHead', 'Base3DDenseHead', 'FCAF3DHead'
'MonoFlexHead', 'Base3DDenseHead', 'FCAF3DHead', 'ImVoxelHead'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import torch
from mmcv.cnn import Scale
from mmcv.ops import nms3d, nms3d_normal
from mmdet.models.utils import multi_apply
from mmdet.utils import reduce_mean
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, bias_init_with_prob, normal_init
from mmengine.structures import InstanceData
from torch import Tensor, nn
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.bbox_3d.utils import rotation_3d_in_axis
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils.typing_utils import (ConfigType, InstanceList,
OptConfigType, OptInstanceList)
@MODELS.register_module()
class ImVoxelHead(BaseModule):
r"""`ImVoxelNet<https://arxiv.org/abs/2106.01178>`_ head for indoor
datasets.
Args:
n_classes (int): Number of classes.
n_levels (int): Number of feature levels.
n_channels (int): Number of channels in input tensors.
n_reg_outs (int): Number of regression layer channels.
pts_assign_threshold (int): Min number of locations inside a box
    required for a feature level to be assigned to that box.
pts_center_threshold (int): Max number of locations per box that
    are kept as positives, ranked by centerness.
prior_generator (dict): Config of the prior points generator.
center_loss (dict, optional): Config of centerness loss.
Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
bbox_loss (dict, optional): Config of bbox loss.
Default: dict(type='RotatedIoU3DLoss').
cls_loss (dict, optional): Config of classification loss.
Default: dict(type='FocalLoss').
train_cfg (dict, optional): Config for train stage. Defaults to None.
test_cfg (dict, optional): Config for test stage. Defaults to None.
init_cfg (dict, optional): Config for weight initialization.
Defaults to None.
"""
def __init__(self,
n_classes: int,
n_levels: int,
n_channels: int,
n_reg_outs: int,
pts_assign_threshold: int,
pts_center_threshold: int,
prior_generator: ConfigType,
center_loss: ConfigType = dict(
type='mmdet.CrossEntropyLoss', use_sigmoid=True),
bbox_loss: ConfigType = dict(type='RotatedIoU3DLoss'),
cls_loss: ConfigType = dict(type='mmdet.FocalLoss'),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptConfigType = None):
super(ImVoxelHead, self).__init__(init_cfg)
self.pts_assign_threshold = pts_assign_threshold
self.pts_center_threshold = pts_center_threshold
self.prior_generator = TASK_UTILS.build(prior_generator)
self.center_loss = MODELS.build(center_loss)
self.bbox_loss = MODELS.build(bbox_loss)
self.cls_loss = MODELS.build(cls_loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self._init_layers(n_channels, n_reg_outs, n_classes, n_levels)
def _init_layers(self, n_channels, n_reg_outs, n_classes, n_levels):
"""Initialize neural network layers of the head."""
self.conv_center = nn.Conv3d(n_channels, 1, 3, padding=1, bias=False)
self.conv_reg = nn.Conv3d(
n_channels, n_reg_outs, 3, padding=1, bias=False)
self.conv_cls = nn.Conv3d(n_channels, n_classes, 3, padding=1)
self.scales = nn.ModuleList([Scale(1.) for _ in range(n_levels)])
def init_weights(self):
"""Initialize all layer weights."""
normal_init(self.conv_center, std=.01)
normal_init(self.conv_reg, std=.01)
normal_init(self.conv_cls, std=.01, bias=bias_init_with_prob(.01))
def _forward_single(self, x: Tensor, scale: Scale):
"""Forward pass per level.
Args:
x (Tensor): Per level 3d neck output tensor.
scale (mmcv.cnn.Scale): Per level multiplication weight.
Returns:
tuple[Tensor]: Centerness, bbox and classification predictions.
"""
reg_final = self.conv_reg(x)
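# The first 6 regression channels are distances to the 6 box faces
# (exp keeps them positive); the 7th channel is the yaw angle.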
reg_distance = torch.exp(scale(reg_final[:, :6]))
reg_angle = reg_final[:, 6:]
bbox_pred = torch.cat((reg_distance, reg_angle), dim=1)
return self.conv_center(x), bbox_pred, self.conv_cls(x)
def forward(self, x: List[Tensor]):
"""Forward function.
Args:
x (list[Tensor]): Features from 3d neck.
Returns:
tuple[Tensor]: Centerness, bbox and classification predictions.
"""
return multi_apply(self._forward_single, x, self.scales)
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
**kwargs) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
Returns:
dict: A dictionary of loss components.
"""
valid_pred = x[-1]
outs = self(x[:-1])
batch_gt_instances_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instances_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
loss_inputs = outs + (valid_pred, batch_gt_instances_3d,
batch_input_metas, batch_gt_instances_ignore)
losses = self.loss_by_feat(*loss_inputs)
return losses
def loss_and_predict(self,
x: Tuple[Tensor],
batch_data_samples: SampleList,
proposal_cfg: Optional[ConfigDict] = None,
**kwargs) -> Tuple[dict, InstanceList]:
"""Perform forward propagation of the head, then calculate loss and
predictions from the features and data samples.
Args:
x (tuple[Tensor]): Features from FPN.
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each image and
corresponding annotations.
proposal_cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
Returns:
tuple: the return value is a tuple contains:
- losses: (dict[str, Tensor]): A dictionary of loss components.
- predictions (list[:obj:`InstanceData`]): Detection
results of each image after the post process.
"""
batch_gt_instances_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instances_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
valid_pred = x[-1]
outs = self(x[:-1])
loss_inputs = outs + (valid_pred, batch_gt_instances_3d,
batch_input_metas, batch_gt_instances_ignore)
losses = self.loss_by_feat(*loss_inputs)
predictions = self.predict_by_feat(
*outs,
valid_pred=valid_pred,
batch_input_metas=batch_input_metas,
cfg=proposal_cfg)
return losses, predictions
def predict(self,
x: Tuple[Tensor],
batch_data_samples: SampleList,
rescale: bool = False) -> InstanceList:
"""Perform forward propagation of the 3D detection head and predict
detection results on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`, `gt_pts_panoptic_seg` and
`gt_pts_sem_seg`.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[:obj:`InstanceData`]: Detection results of each sample
after the post process.
Each item usually contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where
C >= 7.
"""
batch_input_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
valid_pred = x[-1]
outs = self(x[:-1])
predictions = self.predict_by_feat(
*outs,
valid_pred=valid_pred,
batch_input_metas=batch_input_metas,
rescale=rescale)
return predictions
def _loss_by_feat_single(self, center_preds, bbox_preds, cls_preds,
valid_preds, input_meta, gt_bboxes, gt_labels):
"""Per scene loss function.
Args:
center_preds (list[Tensor]): Centerness predictions for all levels.
bbox_preds (list[Tensor]): Bbox predictions for all levels.
cls_preds (list[Tensor]): Classification predictions for all
levels.
valid_preds (list[Tensor]): Valid mask predictions for all levels.
input_meta (dict): Scene meta info.
gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
gt_labels (Tensor): Ground truth labels.
Returns:
tuple[Tensor]: Centerness, bbox, and classification loss values.
"""
points = self._get_points(center_preds)
center_targets, bbox_targets, cls_targets = self._get_targets(
points, gt_bboxes, gt_labels)
center_preds = torch.cat(
[x.permute(1, 2, 3, 0).reshape(-1) for x in center_preds])
bbox_preds = torch.cat([
x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in bbox_preds
])
cls_preds = torch.cat(
[x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in cls_preds])
valid_preds = torch.cat(
[x.permute(1, 2, 3, 0).reshape(-1) for x in valid_preds])
points = torch.cat(points)
# cls loss
pos_inds = torch.nonzero(
torch.logical_and(cls_targets >= 0, valid_preds)).squeeze(1)
n_pos = points.new_tensor(len(pos_inds))
n_pos = max(reduce_mean(n_pos), 1.)
if torch.any(valid_preds):
cls_loss = self.cls_loss(
cls_preds[valid_preds],
cls_targets[valid_preds],
avg_factor=n_pos)
else:
cls_loss = cls_preds[valid_preds].sum()
# bbox and centerness losses
pos_center_preds = center_preds[pos_inds]
pos_bbox_preds = bbox_preds[pos_inds]
if len(pos_inds) > 0:
pos_center_targets = center_targets[pos_inds]
pos_bbox_targets = bbox_targets[pos_inds]
pos_points = points[pos_inds]
center_loss = self.center_loss(
pos_center_preds, pos_center_targets, avg_factor=n_pos)
bbox_loss = self.bbox_loss(
self._bbox_pred_to_bbox(pos_points, pos_bbox_preds),
pos_bbox_targets,
weight=pos_center_targets,
avg_factor=pos_center_targets.sum())
else:
center_loss = pos_center_preds.sum()
bbox_loss = pos_bbox_preds.sum()
return center_loss, bbox_loss, cls_loss
def loss_by_feat(self,
center_preds: List[List[Tensor]],
bbox_preds: List[List[Tensor]],
cls_preds: List[List[Tensor]],
valid_pred: Tensor,
batch_gt_instances_3d: InstanceList,
batch_input_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None,
**kwargs) -> dict:
"""Per scene loss function.
Args:
center_preds (list[list[Tensor]]): Centerness predictions for
all scenes. The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
cls_preds (list[list[Tensor]]): Classification predictions for all
scenes. The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
valid_pred (Tensor): Valid mask prediction for all scenes.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instance_3d. It usually includes ``bboxes_3d``, ``labels_3d``,
``depths``, ``centers_2d`` and attributes.
batch_input_metas (list[dict]): Meta information of each image,
e.g., image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict: Centerness, bbox, and classification loss values.
"""
valid_preds = self._upsample_valid_preds(valid_pred, center_preds)
center_losses, bbox_losses, cls_losses = [], [], []
for i in range(len(batch_input_metas)):
center_loss, bbox_loss, cls_loss = self._loss_by_feat_single(
center_preds=[x[i] for x in center_preds],
bbox_preds=[x[i] for x in bbox_preds],
cls_preds=[x[i] for x in cls_preds],
valid_preds=[x[i] for x in valid_preds],
input_meta=batch_input_metas[i],
gt_bboxes=batch_gt_instances_3d[i].bboxes_3d,
gt_labels=batch_gt_instances_3d[i].labels_3d)
center_losses.append(center_loss)
bbox_losses.append(bbox_loss)
cls_losses.append(cls_loss)
return dict(
center_loss=torch.mean(torch.stack(center_losses)),
bbox_loss=torch.mean(torch.stack(bbox_losses)),
cls_loss=torch.mean(torch.stack(cls_losses)))
def _predict_by_feat_single(self, center_preds: List[Tensor],
bbox_preds: List[Tensor],
cls_preds: List[Tensor],
valid_preds: List[Tensor],
input_meta: dict) -> InstanceData:
"""Generate boxes for single sample.
Args:
center_preds (list[Tensor]): Centerness predictions for all levels.
bbox_preds (list[Tensor]): Bbox predictions for all levels.
cls_preds (list[Tensor]): Classification predictions for all
levels.
valid_preds (tuple[Tensor]): Upsampled valid masks for all feature
levels.
input_meta (dict): Scene meta info.
Returns:
:obj:`InstanceData`: Predicted bounding boxes, scores and labels.
"""
points = self._get_points(center_preds)
mlvl_bboxes, mlvl_scores = [], []
for center_pred, bbox_pred, cls_pred, valid_pred, point in zip(
center_preds, bbox_preds, cls_preds, valid_preds, points):
center_pred = center_pred.permute(1, 2, 3, 0).reshape(-1, 1)
bbox_pred = bbox_pred.permute(1, 2, 3,
0).reshape(-1, bbox_pred.shape[0])
cls_pred = cls_pred.permute(1, 2, 3,
0).reshape(-1, cls_pred.shape[0])
valid_pred = valid_pred.permute(1, 2, 3, 0).reshape(-1, 1)
scores = cls_pred.sigmoid() * center_pred.sigmoid() * valid_pred
max_scores, _ = scores.max(dim=1)
if len(scores) > self.test_cfg.nms_pre > 0:
_, ids = max_scores.topk(self.test_cfg.nms_pre)
bbox_pred = bbox_pred[ids]
scores = scores[ids]
point = point[ids]
bboxes = self._bbox_pred_to_bbox(point, bbox_pred)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
bboxes = torch.cat(mlvl_bboxes)
scores = torch.cat(mlvl_scores)
bboxes, scores, labels = self._single_scene_multiclass_nms(
bboxes, scores, input_meta)
bboxes = input_meta['box_type_3d'](
bboxes,
box_dim=bboxes.shape[1],
with_yaw=bboxes.shape[1] == 7,
origin=(.5, .5, .5))
results = InstanceData()
results.bboxes_3d = bboxes
results.scores_3d = scores
results.labels_3d = labels
return results
def predict_by_feat(self, center_preds: List[List[Tensor]],
bbox_preds: List[List[Tensor]],
cls_preds: List[List[Tensor]], valid_pred: Tensor,
batch_input_metas: List[dict],
**kwargs) -> List[InstanceData]:
"""Generate boxes for all scenes.
Args:
center_preds (list[list[Tensor]]): Centerness predictions for
all scenes.
bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
cls_preds (list[list[Tensor]]): Classification predictions for all
scenes.
valid_pred (Tensor): Valid mask prediction for all scenes.
batch_input_metas (list[dict]): Meta infos for all scenes.
Returns:
list[:obj:`InstanceData`]: Predicted bboxes, scores and labels for
all scenes.
"""
valid_preds = self._upsample_valid_preds(valid_pred, center_preds)
results = []
for i in range(len(batch_input_metas)):
results.append(
self._predict_by_feat_single(
center_preds=[x[i] for x in center_preds],
bbox_preds=[x[i] for x in bbox_preds],
cls_preds=[x[i] for x in cls_preds],
valid_preds=[x[i] for x in valid_preds],
input_meta=batch_input_metas[i]))
return results
@staticmethod
def _upsample_valid_preds(valid_pred, features):
"""Upsample valid mask predictions.
Args:
valid_pred (Tensor): Valid mask prediction.
features (list[Tensor]): Feature tensors for all levels.
Returns:
list[Tensor]: Upsampled valid masks for all feature levels.
"""
return [
nn.Upsample(size=x.shape[-3:],
mode='trilinear')(valid_pred).round().bool()
for x in features
]
def _get_points(self, features):
"""Generate final locations.
Args:
features (list[Tensor]): Feature tensors for all feature levels.
Returns:
list(Tensor): Final locations for all feature levels.
"""
points = []
for x in features:
n_voxels = x.size()[-3:][::-1]
points.append(
self.prior_generator.grid_anchors(
[n_voxels],
device=x.device)[0][:, :3].reshape(n_voxels +
(3, )).permute(
2, 1, 0,
3).reshape(-1, 3))
return points
@staticmethod
def _bbox_pred_to_bbox(points, bbox_pred):
"""Transform predicted bbox parameters to bbox.
Args:
points (Tensor): Final locations of shape (N, 3).
bbox_pred (Tensor): Predicted bbox parameters of shape (N, 7).
Returns:
Tensor: Transformed 3D box of shape (N, 7).
"""
if bbox_pred.shape[0] == 0:
return bbox_pred
# dx_min, dx_max, dy_min, dy_max, dz_min, dz_max, alpha ->
# x_center, y_center, z_center, w, l, h, alpha
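# The half-difference of opposite face distances, rotated by the predicted
# yaw, gives the offset from the location to the box center; the sums of
# opposite distances give the box sizes.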
shift = torch.stack(((bbox_pred[:, 1] - bbox_pred[:, 0]) / 2,
(bbox_pred[:, 3] - bbox_pred[:, 2]) / 2,
(bbox_pred[:, 5] - bbox_pred[:, 4]) / 2),
dim=-1).view(-1, 1, 3)
shift = rotation_3d_in_axis(shift, bbox_pred[:, 6], axis=2)[:, 0, :]
center = points + shift
size = torch.stack(
(bbox_pred[:, 0] + bbox_pred[:, 1], bbox_pred[:, 2] +
bbox_pred[:, 3], bbox_pred[:, 4] + bbox_pred[:, 5]),
dim=-1)
return torch.cat((center, size, bbox_pred[:, 6:7]), dim=-1)
# The function is directly copied from FCAF3DHead.
@staticmethod
def _get_face_distances(points, boxes):
"""Calculate distances from point to box faces.
Args:
points (Tensor): Final locations of shape (N_points, N_boxes, 3).
boxes (Tensor): 3D boxes of shape (N_points, N_boxes, 7)
Returns:
Tensor: Face distances of shape (N_points, N_boxes, 6),
(dx_min, dx_max, dy_min, dy_max, dz_min, dz_max).
"""
shift = torch.stack(
(points[..., 0] - boxes[..., 0], points[..., 1] - boxes[..., 1],
points[..., 2] - boxes[..., 2]),
dim=-1).permute(1, 0, 2)
shift = rotation_3d_in_axis(
shift, -boxes[0, :, 6], axis=2).permute(1, 0, 2)
centers = boxes[..., :3] + shift
dx_min = centers[..., 0] - boxes[..., 0] + boxes[..., 3] / 2
dx_max = boxes[..., 0] + boxes[..., 3] / 2 - centers[..., 0]
dy_min = centers[..., 1] - boxes[..., 1] + boxes[..., 4] / 2
dy_max = boxes[..., 1] + boxes[..., 4] / 2 - centers[..., 1]
dz_min = centers[..., 2] - boxes[..., 2] + boxes[..., 5] / 2
dz_max = boxes[..., 2] + boxes[..., 5] / 2 - centers[..., 2]
return torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max),
dim=-1)
# The function is directly copied from FCAF3DHead.
@staticmethod
def _get_centerness(face_distances):
"""Compute point centerness w.r.t containing box.
Args:
face_distances (Tensor): Face distances of shape (B, N, 6),
(dx_min, dx_max, dy_min, dy_max, dz_min, dz_max).
Returns:
Tensor: Centerness of shape (B, N).
"""
x_dims = face_distances[..., [0, 1]]
y_dims = face_distances[..., [2, 3]]
z_dims = face_distances[..., [4, 5]]
centerness_targets = x_dims.min(dim=-1)[0] / x_dims.max(dim=-1)[0] * \
y_dims.min(dim=-1)[0] / y_dims.max(dim=-1)[0] * \
z_dims.min(dim=-1)[0] / z_dims.max(dim=-1)[0]
return torch.sqrt(centerness_targets)
# The function is directly copied from FCAF3DHead.
@torch.no_grad()
def _get_targets(self, points, gt_bboxes, gt_labels):
"""Compute targets for final locations for a single scene.
Args:
points (list[Tensor]): Final locations for all levels.
gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
gt_labels (Tensor): Ground truth labels.
Returns:
tuple[Tensor]: Centerness, bbox and classification
targets for all locations.
"""
float_max = points[0].new_tensor(1e8)
n_levels = len(points)
levels = torch.cat([
points[i].new_tensor(i).expand(len(points[i]))
for i in range(len(points))
])
points = torch.cat(points)
gt_bboxes = gt_bboxes.to(points.device)
n_points = len(points)
n_boxes = len(gt_bboxes)
volumes = gt_bboxes.volume.unsqueeze(0).expand(n_points, n_boxes)
# condition 1: point inside box
boxes = torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),
dim=1)
boxes = boxes.expand(n_points, n_boxes, 7)
points = points.unsqueeze(1).expand(n_points, n_boxes, 3)
face_distances = self._get_face_distances(points, boxes)
inside_box_condition = face_distances.min(dim=-1).values > 0
# condition 2: positive points per level >= limit
# calculate positive points per scale
n_pos_points_per_level = []
for i in range(n_levels):
n_pos_points_per_level.append(
torch.sum(inside_box_condition[levels == i], dim=0))
# find best level
n_pos_points_per_level = torch.stack(n_pos_points_per_level, dim=0)
lower_limit_mask = n_pos_points_per_level < self.pts_assign_threshold
lower_index = torch.argmax(lower_limit_mask.int(), dim=0) - 1
lower_index = torch.where(lower_index < 0, 0, lower_index)
all_upper_limit_mask = torch.all(
torch.logical_not(lower_limit_mask), dim=0)
best_level = torch.where(all_upper_limit_mask, n_levels - 1,
lower_index)
# keep only points with best level
best_level = best_level.expand(n_points, n_boxes)
levels = torch.unsqueeze(levels, 1).expand(n_points, n_boxes)
level_condition = best_level == levels
# condition 3: limit topk points per box by centerness
centerness = self._get_centerness(face_distances)
centerness = torch.where(inside_box_condition, centerness,
torch.ones_like(centerness) * -1)
centerness = torch.where(level_condition, centerness,
torch.ones_like(centerness) * -1)
top_centerness = torch.topk(
centerness,
min(self.pts_center_threshold + 1, len(centerness)),
dim=0).values[-1]
topk_condition = centerness > top_centerness.unsqueeze(0)
# condition 4: min volume box per point
volumes = torch.where(inside_box_condition, volumes, float_max)
volumes = torch.where(level_condition, volumes, float_max)
volumes = torch.where(topk_condition, volumes, float_max)
min_volumes, min_inds = volumes.min(dim=1)
center_targets = centerness[torch.arange(n_points), min_inds]
bbox_targets = boxes[torch.arange(n_points), min_inds]
if not gt_bboxes.with_yaw:
bbox_targets = bbox_targets[:, :-1]
cls_targets = gt_labels[min_inds]
cls_targets = torch.where(min_volumes == float_max, -1, cls_targets)
return center_targets, bbox_targets, cls_targets
# Originally ImVoxelNet used 2d nms because mmdetection3d did not support
# 3d nms; since mmcv 1.5.2 we simply use nms3d here.
# The function is directly copied from FCAF3DHead.
def _single_scene_multiclass_nms(self, bboxes, scores, input_meta):
"""Multi-class nms for a single scene.
Args:
bboxes (Tensor): Predicted boxes of shape (N_boxes, 6) or
(N_boxes, 7).
scores (Tensor): Predicted scores of shape (N_boxes, N_classes).
input_meta (dict): Scene meta data.
Returns:
tuple[Tensor]: Predicted bboxes, scores and labels.
"""
n_classes = scores.shape[1]
with_yaw = bboxes.shape[1] == 7
nms_bboxes, nms_scores, nms_labels = [], [], []
for i in range(n_classes):
ids = scores[:, i] > self.test_cfg.score_thr
if not ids.any():
continue
class_scores = scores[ids, i]
class_bboxes = bboxes[ids]
if with_yaw:
nms_function = nms3d
else:
class_bboxes = torch.cat(
(class_bboxes, torch.zeros_like(class_bboxes[:, :1])),
dim=1)
nms_function = nms3d_normal
nms_ids = nms_function(class_bboxes, class_scores,
self.test_cfg.iou_thr)
nms_bboxes.append(class_bboxes[nms_ids])
nms_scores.append(class_scores[nms_ids])
nms_labels.append(
bboxes.new_full(
class_scores[nms_ids].shape, i, dtype=torch.long))
if len(nms_bboxes):
nms_bboxes = torch.cat(nms_bboxes, dim=0)
nms_scores = torch.cat(nms_scores, dim=0)
nms_labels = torch.cat(nms_labels, dim=0)
else:
nms_bboxes = bboxes.new_zeros((0, bboxes.shape[1]))
nms_scores = bboxes.new_zeros((0, ))
nms_labels = bboxes.new_zeros((0, ))
if with_yaw:
box_dim = 7
else:
box_dim = 6
nms_bboxes = nms_bboxes[:, :box_dim]
return nms_bboxes, nms_scores, nms_labels
......@@ -7,6 +7,7 @@ from mmengine.structures import InstanceData
from mmdet3d.models.detectors import Base3DDetector
from mmdet3d.models.layers.fusion_layers.point_fusion import point_sample
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.bbox_3d import get_proj_mat_by_coord_type
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import ConfigType, OptConfigType, OptInstanceList
......@@ -20,9 +21,11 @@ class ImVoxelNet(Base3DDetector):
neck (:obj:`ConfigDict` or dict): The neck config.
neck_3d (:obj:`ConfigDict` or dict): The 3D neck config.
bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
prior_generator (:obj:`ConfigDict` or dict): The prior points
generator config.
n_voxels (list): Number of voxels along x, y, z axis.
anchor_generator (:obj:`ConfigDict` or dict): The anchor generator
config.
coord_type (str): The type of coordinates of points cloud:
'DEPTH', 'LIDAR', or 'CAMERA'.
train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
training hyper-parameters. Defaults to None.
test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
......@@ -39,8 +42,9 @@ class ImVoxelNet(Base3DDetector):
neck: ConfigType,
neck_3d: ConfigType,
bbox_head: ConfigType,
prior_generator: ConfigType,
n_voxels: List,
anchor_generator: ConfigType,
coord_type: str,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
......@@ -53,8 +57,9 @@ class ImVoxelNet(Base3DDetector):
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = MODELS.build(bbox_head)
self.prior_generator = TASK_UTILS.build(prior_generator)
self.n_voxels = n_voxels
self.anchor_generator = TASK_UTILS.build(anchor_generator)
self.coord_type = coord_type
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......@@ -62,6 +67,8 @@ class ImVoxelNet(Base3DDetector):
batch_data_samples: SampleList):
"""Extract 3d features from the backbone -> fpn -> 3d projection.
-> 3d neck -> bbox_head.
Args:
batch_inputs_dict (dict): The model input dict which include
the 'imgs' key.
......@@ -72,7 +79,9 @@ class ImVoxelNet(Base3DDetector):
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
Returns:
torch.Tensor: of shape (N, C_out, N_x, N_y, N_z)
Tuple:
- torch.Tensor: Features of shape (N, C_out, N_x, N_y, N_z).
- torch.Tensor: Valid mask of shape (N, 1, N_x, N_y, N_z).
"""
img = batch_inputs_dict['imgs']
batch_img_metas = [
......@@ -80,9 +89,9 @@ class ImVoxelNet(Base3DDetector):
]
x = self.backbone(img)
x = self.neck(x)[0]
points = self.anchor_generator.grid_anchors(
[self.n_voxels[::-1]], device=img.device)[0][:, :3]
volumes = []
points = self.prior_generator.grid_anchors([self.n_voxels[::-1]],
device=img.device)[0][:, :3]
volumes, valid_preds = [], []
for feature, img_meta in zip(x, batch_img_metas):
img_scale_factor = (
points.new_tensor(img_meta['scale_factor'][:2])
......@@ -91,13 +100,14 @@ class ImVoxelNet(Base3DDetector):
img_crop_offset = (
points.new_tensor(img_meta['img_crop_offset'])
if 'img_crop_offset' in img_meta.keys() else 0)
lidar2img = points.new_tensor(img_meta['lidar2img'])
proj_mat = points.new_tensor(
get_proj_mat_by_coord_type(img_meta, self.coord_type))
volume = point_sample(
img_meta,
img_features=feature[None, ...],
points=points,
proj_mat=lidar2img,
coord_type='LIDAR',
proj_mat=points.new_tensor(proj_mat),
coord_type=self.coord_type,
img_scale_factor=img_scale_factor,
img_crop_offset=img_crop_offset,
img_flip=img_flip,
......@@ -106,9 +116,11 @@ class ImVoxelNet(Base3DDetector):
aligned=False)
volumes.append(
volume.reshape(self.n_voxels[::-1] + [-1]).permute(3, 2, 1, 0))
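# Voxels that project outside the image get all-zero features; mark them
# as invalid so the indoor head can ignore them.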
valid_preds.append(
~torch.all(volumes[-1] == 0, dim=0, keepdim=True))
x = torch.stack(volumes)
x = self.neck_3d(x)
return x
return x, torch.stack(valid_preds).float()
def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs) -> Union[dict, list]:
......@@ -126,8 +138,12 @@ class ImVoxelNet(Base3DDetector):
Returns:
dict: A dictionary of loss components.
"""
x = self.extract_feat(batch_inputs_dict, batch_data_samples)
x, valid_preds = self.extract_feat(batch_inputs_dict,
batch_data_samples)
# For indoor datasets ImVoxelNet uses ImVoxelHead that handles
# mask of visible voxels.
if self.coord_type == 'DEPTH':
x += (valid_preds, )
losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
return losses
......@@ -159,8 +175,14 @@ class ImVoxelNet(Base3DDetector):
- bboxes_3d (Tensor): Contains a tensor with shape
(num_instances, C) where C >=7.
"""
x = self.extract_feat(batch_inputs_dict, batch_data_samples)
results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
x, valid_preds = self.extract_feat(batch_inputs_dict,
batch_data_samples)
# For indoor datasets ImVoxelNet uses ImVoxelHead that handles
# mask of visible voxels.
if self.coord_type == 'DEPTH':
x += (valid_preds, )
results_list = \
self.bbox_head.predict(x, batch_data_samples, **kwargs)
predictions = self.add_pred_to_datasample(batch_data_samples,
results_list)
return predictions
......@@ -182,7 +204,12 @@ class ImVoxelNet(Base3DDetector):
Returns:
tuple[list]: A tuple of features from ``bbox_head`` forward.
"""
x = self.extract_feat(batch_inputs_dict, batch_data_samples)
x, valid_preds = self.extract_feat(batch_inputs_dict,
batch_data_samples)
# For indoor datasets ImVoxelNet uses ImVoxelHead that handles
# mask of visible voxels.
if self.coord_type == 'DEPTH':
x += (valid_preds, )
results = self.bbox_head.forward(x)
return results
......
......@@ -2,10 +2,11 @@
from mmdet.models.necks.fpn import FPN
from .dla_neck import DLANeck
from .imvoxel_neck import OutdoorImVoxelNeck
from .imvoxel_neck import IndoorImVoxelNeck, OutdoorImVoxelNeck
from .pointnet2_fp_neck import PointNetFPNeck
from .second_fpn import SECONDFPN
__all__ = [
'FPN', 'SECONDFPN', 'OutdoorImVoxelNeck', 'PointNetFPNeck', 'DLANeck'
'FPN', 'SECONDFPN', 'OutdoorImVoxelNeck', 'PointNetFPNeck', 'DLANeck',
'IndoorImVoxelNeck'
]
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import nn
from mmdet3d.registry import MODELS
@MODELS.register_module()
class OutdoorImVoxelNeck(nn.Module):
class OutdoorImVoxelNeck(BaseModule):
"""Neck for ImVoxelNet outdoor scenario.
Args:
in_channels (int): Input channels of multi-scale feature map.
out_channels (int): Output channels of multi-scale feature map.
in_channels (int): Number of channels in an input tensor.
out_channels (int): Number of channels in all output tensors.
"""
def __init__(self, in_channels, out_channels):
super().__init__()
super(OutdoorImVoxelNeck, self).__init__()
self.model = nn.Sequential(
ResModule(in_channels),
ResModule(in_channels, in_channels),
ConvModule(
in_channels=in_channels,
out_channels=in_channels * 2,
......@@ -27,7 +28,7 @@ class OutdoorImVoxelNeck(nn.Module):
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d'),
act_cfg=dict(type='ReLU', inplace=True)),
ResModule(in_channels * 2),
ResModule(in_channels * 2, in_channels * 2),
ConvModule(
in_channels=in_channels * 2,
out_channels=in_channels * 4,
......@@ -37,7 +38,7 @@ class OutdoorImVoxelNeck(nn.Module):
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d'),
act_cfg=dict(type='ReLU', inplace=True)),
ResModule(in_channels * 4),
ResModule(in_channels * 4, in_channels * 4),
ConvModule(
in_channels=in_channels * 4,
out_channels=out_channels,
......@@ -66,31 +67,148 @@ class OutdoorImVoxelNeck(nn.Module):
pass
@MODELS.register_module()
class IndoorImVoxelNeck(BaseModule):
"""Neck for ImVoxelNet outdoor scenario.
Args:
in_channels (int): Number of channels in an input tensor.
out_channels (int): Number of channels in all output tensors.
n_blocks (list[int]): Number of blocks for each feature level.
"""
def __init__(self, in_channels, out_channels, n_blocks):
super(IndoorImVoxelNeck, self).__init__()
self.n_scales = len(n_blocks)
n_channels = in_channels
for i in range(len(n_blocks)):
stride = 1 if i == 0 else 2
self.__setattr__(f'down_layer_{i}',
self._make_layer(stride, n_channels, n_blocks[i]))
n_channels = n_channels * stride
if i > 0:
self.__setattr__(
f'up_block_{i}',
self._make_up_block(n_channels, n_channels // 2))
self.__setattr__(f'out_block_{i}',
self._make_block(n_channels, out_channels))
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): of shape (N, C_in, N_x, N_y, N_z).
Returns:
list[torch.Tensor]: of shape (N, C_out, N_xi, N_yi, N_zi).
"""
down_outs = []
for i in range(self.n_scales):
x = self.__getattr__(f'down_layer_{i}')(x)
down_outs.append(x)
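# Top-down pass: upsample, fuse with the corresponding bottom-up feature
# and produce one output per scale (finest first after the final reverse).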
outs = []
for i in range(self.n_scales - 1, -1, -1):
if i < self.n_scales - 1:
x = self.__getattr__(f'up_block_{i + 1}')(x)
x = down_outs[i] + x
out = self.__getattr__(f'out_block_{i}')(x)
outs.append(out)
return outs[::-1]
@staticmethod
def _make_layer(stride, n_channels, n_blocks):
"""Make a layer from several residual blocks.
Args:
stride (int): Stride of the first residual block.
n_channels (int): Number of channels of the first residual block.
n_blocks (int): Number of residual blocks.
Returns:
torch.nn.Module: With several residual blocks.
"""
blocks = []
for i in range(n_blocks):
if i == 0 and stride != 1:
blocks.append(ResModule(n_channels, n_channels * 2, stride))
n_channels = n_channels * 2
else:
blocks.append(ResModule(n_channels, n_channels))
return nn.Sequential(*blocks)
@staticmethod
def _make_block(in_channels, out_channels):
"""Make a convolutional block.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
Returns:
torch.nn.Module: Convolutional block.
"""
return nn.Sequential(
nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True))
@staticmethod
def _make_up_block(in_channels, out_channels):
"""Make upsampling convolutional block.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
Returns:
torch.nn.Module: Upsampling convolutional block.
"""
return nn.Sequential(
nn.ConvTranspose3d(in_channels, out_channels, 2, 2, bias=False),
nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True),
nn.Conv3d(out_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True))
class ResModule(nn.Module):
"""3d residual block for ImVoxelNeck.
Args:
n_channels (int): Input channels of a feature map.
in_channels (int): Number of channels in input tensor.
out_channels (int): Number of channels in output tensor.
stride (int, optional): Stride of the block. Defaults to 1.
"""
def __init__(self, n_channels):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv0 = ConvModule(
in_channels=n_channels,
out_channels=n_channels,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
padding=1,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d'),
act_cfg=dict(type='ReLU', inplace=True))
self.conv1 = ConvModule(
in_channels=n_channels,
out_channels=n_channels,
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d'),
act_cfg=None)
if stride != 1:
self.downsample = ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=0,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d'),
act_cfg=None)
self.stride = stride
self.activation = nn.ReLU(inplace=True)
def forward(self, x):
......@@ -105,6 +223,8 @@ class ResModule(nn.Module):
identity = x
x = self.conv0(x)
x = self.conv1(x)
x = identity + x
if self.stride != 1:
identity = self.downsample(identity)
x = x + identity
x = self.activation(x)
return x
......@@ -9,10 +9,10 @@ from mmdet3d.models.dense_heads import FCAF3DHead
from mmdet3d.testing import create_detector_inputs
class TestAnchor3DHead(TestCase):
class TestFCAF3DHead(TestCase):
def test_fcaf3d_head_loss(self):
"""Test anchor head loss when truth is empty and non-empty."""
"""Test fcaf3d head loss when truth is empty and non-empty."""
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
......
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
from mmdet3d import * # noqa
from mmdet3d.models.dense_heads import ImVoxelHead
from mmdet3d.testing import create_detector_inputs
class TestImVoxelHead(TestCase):
def test_imvoxel_head_loss(self):
"""Test imvoxel head loss when truth is empty and non-empty."""
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
# build head
prior_generator = dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[-3.2, -0.2, -2.28, 3.2, 6.2, 0.28]],
rotations=[.0])
imvoxel_head = ImVoxelHead(
n_classes=1,
n_levels=1,
n_channels=32,
n_reg_outs=7,
pts_assign_threshold=27,
pts_center_threshold=18,
prior_generator=prior_generator,
center_loss=dict(type='mmdet.CrossEntropyLoss', use_sigmoid=True),
bbox_loss=dict(type='RotatedIoU3DLoss'),
cls_loss=dict(type='mmdet.FocalLoss'),
)
imvoxel_head = imvoxel_head.cuda()
# fake input of head
# (x, valid_preds)
x = [
torch.randn(1, 32, 10, 10, 4).cuda(),
torch.ones(1, 1, 10, 10, 4).cuda()
]
# fake annotation
num_gt_instance = 1
packed_inputs = create_detector_inputs(
with_points=False,
with_img=True,
img_size=(128, 128),
num_gt_instance=num_gt_instance,
with_pts_semantic_mask=False,
with_pts_instance_mask=False)
data_samples = [
sample.cuda() for sample in packed_inputs['data_samples']
]
losses = imvoxel_head.loss(x, data_samples)
print(losses)
self.assertGreaterEqual(losses['center_loss'], 0)
self.assertGreaterEqual(losses['bbox_loss'], 0)
self.assertGreaterEqual(losses['cls_loss'], 0)
......@@ -10,11 +10,12 @@ from mmdet3d.testing import (create_detector_inputs, get_detector_cfg,
class TestImVoxelNet(unittest.TestCase):
def test_imvoxelnet(self):
def test_imvoxelnet_kitti(self):
import mmdet3d.models
assert hasattr(mmdet3d.models, 'ImVoxelNet')
DefaultScope.get_instance('test_ImVoxelNet', scope_name='mmdet3d')
DefaultScope.get_instance(
'test_imvoxelnet_kitti', scope_name='mmdet3d')
setup_seed(0)
imvoxel_net_cfg = get_detector_cfg(
'imvoxelnet/imvoxelnet_8xb4_kitti-3d-car.py')
......@@ -47,3 +48,42 @@ class TestImVoxelNet(unittest.TestCase):
self.assertGreaterEqual(losses['loss_cls'][0], 0)
self.assertGreaterEqual(losses['loss_bbox'][0], 0)
self.assertGreaterEqual(losses['loss_dir'][0], 0)
def test_imvoxelnet_sunrgbd(self):
import mmdet3d.models
assert hasattr(mmdet3d.models, 'ImVoxelNet')
DefaultScope.get_instance(
'test_imvoxelnet_sunrgbd', scope_name='mmdet3d')
setup_seed(0)
imvoxel_net_cfg = get_detector_cfg(
'imvoxelnet/imvoxelnet_2xb4_sunrgbd-3d-10class.py')
model = MODELS.build(imvoxel_net_cfg)
num_gt_instance = 1
packed_inputs = create_detector_inputs(
with_points=False,
with_img=True,
img_size=(128, 128),
num_gt_instance=num_gt_instance,
with_pts_semantic_mask=False,
with_pts_instance_mask=False)
if torch.cuda.is_available():
model = model.cuda()
# test simple_test
with torch.no_grad():
data = model.data_preprocessor(packed_inputs, True)
torch.cuda.empty_cache()
results = model.forward(**data, mode='predict')
self.assertEqual(len(results), 1)
self.assertIn('bboxes_3d', results[0].pred_instances_3d)
self.assertIn('scores_3d', results[0].pred_instances_3d)
self.assertIn('labels_3d', results[0].pred_instances_3d)
# save the memory
with torch.no_grad():
losses = model.forward(**data, mode='loss')
self.assertGreaterEqual(losses['center_loss'], 0)
self.assertGreaterEqual(losses['bbox_loss'], 0)
self.assertGreaterEqual(losses['cls_loss'], 0)