"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "de431201b67655c9dbf4da83b6d5efb0cb1c5641"
Unverified commit cbddb7f9, authored by Danila Rukhovich and committed by GitHub.
[Feature] Add TR3D detector to projects (#2274)

* first tr3d commit

* all tr3d files added

* all tr3d is ok

* fix comments

* fix config imports and readme

* fix comments

* update links in readme

* fix lint
# TR3D: Towards Real-Time Indoor 3D Object Detection
> [TR3D: Towards Real-Time Indoor 3D Object Detection](https://arxiv.org/abs/2302.02858)
## Abstract
Recently, sparse 3D convolutions have changed 3D object detection. Performing on par with the voting-based approaches, 3D CNNs are memory-efficient and scale to large scenes better. However, there is still room for improvement. With a conscious, practice-oriented approach to problem-solving, we analyze the performance of such methods and localize the weaknesses. Applying modifications that resolve the found issues one by one, we end up with TR3D: a fast fully-convolutional 3D object detection model trained end-to-end, that achieves state-of-the-art results on the standard benchmarks, ScanNet v2, SUN RGB-D, and S3DIS. Moreover, to take advantage of both point cloud and RGB inputs, we introduce an early fusion of 2D and 3D features. We employ our fusion module to make conventional 3D object detection methods multimodal and demonstrate an impressive boost in performance. Our model with early feature fusion, which we refer to as TR3D+FF, outperforms existing 3D object detection approaches on the SUN RGB-D dataset. Overall, besides being accurate, both TR3D and TR3D+FF models are lightweight, memory-efficient, and fast, thereby marking another milestone on the way toward real-time 3D object detection.
<div align="center">
<img src="https://user-images.githubusercontent.com/6030962/219644780-646516ec-a6c1-4ec5-9b8c-63bbc9702d05.png" width="800"/>
</div>
## Usage
Training and inference in this project were tested with `mmdet3d==1.1.0rc3`.
### Training commands
In MMDet3D's root directory, run the following command to train the model:
```bash
python tools/train.py projects/TR3D/configs/tr3d_1xb16_scannet-3d-18class.py
```
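All configs in this project use a single GPU with a batch size of 16 (`1xb16`). If you prefer MMDetection3D's standard distributed launcher, a sketch is shown below; the GPU count is only an example, and since the learning rate is tuned for the single-GPU setting, results may differ slightly.

```bash
bash tools/dist_train.sh projects/TR3D/configs/tr3d_1xb16_scannet-3d-18class.py 2
```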
### Testing commands
In MMDet3D's root directory, run the following command to test the model:
```bash
python tools/test.py projects/TR3D/configs/tr3d_1xb16_scannet-3d-18class.py ${CHECKPOINT_PATH}
```
## Results and models
### ScanNet
| Backbone | Mem (GB) | Inf time (fps) | AP@0.25 | AP@0.5 | Download |
| :--------------------------------------------------------: | :------: | :------------: | :---------: | :---------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [MinkResNet34](./configs/tr3d_1xb16_scannet-3d-18class.py) | 8.6 | 23.7 | 72.9 (72.0) | 59.3 (57.4) | [model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_scannet-3d-18class/tr3d_1xb16_scannet-3d-18class.pth) \| [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_scannet-3d-18class/tr3d_1xb16_scannet-3d-18class.log.json) |
### SUN RGB-D
| Backbone | Mem (GB) | Inf time (fps) | AP@0.25 | AP@0.5 | Download |
| :--------------------------------------------------------: | :------: | :------------: | :---------: | :---------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [MinkResNet34](./configs/tr3d_1xb16_sunrgbd-3d-10class.py) | 3.8 | 27.5 | 67.1 (66.3) | 50.4 (49.6) | [model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_sunrgbd-3d-10class/tr3d_1xb16_sunrgbd-3d-10class.pth) \| [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_sunrgbd-3d-10class/tr3d_1xb16_sunrgbd-3d-10class.log.json) |
### S3DIS
| Backbone | Mem (GB) | Inf time (fps) | AP@0.25 | AP@0.5 | Download |
| :-----------------------------------------------------: | :------: | :------------: | :---------: | :---------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [MinkResNet34](./configs/tr3d_1xb16_s3dis-3d-5class.py) | 15.2 | 21.0 | 74.5 (72.1) | 51.7 (47.6) | [model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_s3dis-3d-5class/tr3d_1xb16_s3dis-3d-5class.pth) \| [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_s3dis-3d-5class/tr3d_1xb16_s3dis-3d-5class.log.json) |
**Note**
- Results are reported across 5 training runs followed by 5 test runs; median values are given in round brackets.
- Inference time is measured on a single NVIDIA GeForce RTX 4090 GPU.
## Citation
```latex
@article{rukhovich2023tr3d,
title={TR3D: Towards Real-Time Indoor 3D Object Detection},
author={Rukhovich, Danila and Vorontsova, Anna and Konushin, Anton},
journal={arXiv preprint arXiv:2302.02858},
year={2023}
}
```
## Checklist
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
  - [x] Finish the code
  - [x] Basic docstrings & proper citation
  - [x] Test-time correctness
  - [x] A full README
- [x] Milestone 2: Indicates a successful model implementation.
  - [x] Training-time correctness
- [ ] Milestone 3: Good to be a part of our core package!
  - [x] Type hints and docstrings
  - [ ] Unit tests
  - [ ] Code polishing
  - [ ] Metafile.yml
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
_base_ = ['mmdet3d::_base_/default_runtime.py']
custom_imports = dict(imports=['projects.TR3D.tr3d'])
model = dict(
type='MinkSingleStage3DDetector',
data_preprocessor=dict(type='Det3DDataPreprocessor'),
backbone=dict(
type='TR3DMinkResNet',
in_channels=3,
depth=34,
norm='batch',
num_planes=(64, 128, 128, 128)),
neck=dict(
type='TR3DNeck', in_channels=(64, 128, 128, 128), out_channels=128),
bbox_head=dict(
type='TR3DHead',
in_channels=128,
voxel_size=0.01,
pts_center_threshold=6,
num_reg_outs=6),
train_cfg=dict(),
test_cfg=dict(nms_pre=1000, iou_thr=0.5, score_thr=0.01))
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.001, weight_decay=0.0001),
clip_grad=dict(max_norm=10, norm_type=2))
# learning rate
param_scheduler = dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)]
# training schedule for 1x
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
_base_ = ['./tr3d.py', 'mmdet3d::_base_/datasets/s3dis-3d.py']
custom_imports = dict(imports=['projects.TR3D.tr3d'])
dataset_type = 'S3DISDataset'
data_root = 'data/s3dis/'
metainfo = dict(classes=('table', 'chair', 'sofa', 'bookcase', 'board'))
train_area = [1, 2, 3, 4, 6]
model = dict(bbox_head=dict(label2level=[1, 0, 1, 1, 0]))
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(type='LoadAnnotations3D'),
dict(type='PointSample', num_points=100000),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.1, 0.1, 0.1],
shift_height=False),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
train_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
dataset=dict(datasets=[
dict(
type=dataset_type,
data_root=data_root,
ann_file=f's3dis_infos_Area_{i}.pkl',
pipeline=train_pipeline,
filter_empty_gt=False,
metainfo=metainfo,
box_type_3d='Depth') for i in train_area
])))
_base_ = ['./tr3d.py', 'mmdet3d::_base_/datasets/scannet-3d.py']
custom_imports = dict(imports=['projects.TR3D.tr3d'])
model = dict(
bbox_head=dict(
label2level=[0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0]))
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(type='LoadAnnotations3D'),
dict(type='GlobalAlignment', rotation_axis=2),
    # We do not sample 100k points for ScanNet, as very few scenes have
    # significantly more than 100k points. So we sample 33% to 100% of them.
dict(type='TR3DPointSample', num_points=0.33),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.02, 0.02],
scale_ratio_range=[0.9, 1.1],
translation_std=[0.1, 0.1, 0.1],
shift_height=False),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
            # We do not sample 100k points for ScanNet, as very few scenes have
            # significantly more than 100k points. So it doesn't affect
            # inference time and we can accept all points.
# dict(type='PointSample', num_points=100000),
dict(type='NormalizePointsColor', color_mean=None),
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
train_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='RepeatDataset',
times=15,
dataset=dict(pipeline=train_pipeline, filter_empty_gt=False)))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
_base_ = ['./tr3d.py', 'mmdet3d::_base_/datasets/sunrgbd-3d.py']
custom_imports = dict(imports=['projects.TR3D.tr3d'])
model = dict(
bbox_head=dict(
num_reg_outs=8,
label2level=[1, 1, 1, 0, 0, 1, 0, 0, 1, 0],
bbox_loss=dict(
type='TR3DRotatedIoU3DLoss', mode='diou', reduction='none')))
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(type='LoadAnnotations3D'),
dict(type='PointSample', num_points=100000),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[.85, 1.15],
translation_std=[.1, .1, .1],
shift_height=False),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(type='PointSample', num_points=100000),
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
train_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='RepeatDataset',
times=5,
dataset=dict(pipeline=train_pipeline, filter_empty_gt=False)))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
from .axis_aligned_iou_loss import TR3DAxisAlignedIoULoss
from .mink_resnet import TR3DMinkResNet
from .rotated_iou_loss import TR3DRotatedIoU3DLoss
from .tr3d_head import TR3DHead
from .tr3d_neck import TR3DNeck
from .transforms_3d import TR3DPointSample
__all__ = [
'TR3DAxisAlignedIoULoss', 'TR3DMinkResNet', 'TR3DRotatedIoU3DLoss',
'TR3DHead', 'TR3DNeck', 'TR3DPointSample'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmdet.models.losses.utils import weighted_loss
from torch import Tensor
from torch import nn as nn
from mmdet3d.models import axis_aligned_iou_loss
from mmdet3d.registry import MODELS
from mmdet3d.structures import AxisAlignedBboxOverlaps3D
@weighted_loss
def axis_aligned_diou_loss(pred: Tensor, target: Tensor) -> Tensor:
"""Calculate the DIoU loss (1-DIoU) of two sets of axis aligned bounding
    boxes. Note that predictions and targets are in one-to-one correspondence.
Args:
pred (torch.Tensor): Bbox predictions with shape [..., 6]
(x1, y1, z1, x2, y2, z2).
target (torch.Tensor): Bbox targets (gt) with shape [..., 6]
(x1, y1, z1, x2, y2, z2).
Returns:
torch.Tensor: DIoU loss between predictions and targets.
"""
axis_aligned_iou = AxisAlignedBboxOverlaps3D()(
pred, target, is_aligned=True)
iou_loss = 1 - axis_aligned_iou
xp1, yp1, zp1, xp2, yp2, zp2 = pred.split(1, dim=-1)
xt1, yt1, zt1, xt2, yt2, zt2 = target.split(1, dim=-1)
xpc = (xp1 + xp2) / 2
ypc = (yp1 + yp2) / 2
zpc = (zp1 + zp2) / 2
xtc = (xt1 + xt2) / 2
ytc = (yt1 + yt2) / 2
ztc = (zt1 + zt2) / 2
r2 = (xpc - xtc)**2 + (ypc - ytc)**2 + (zpc - ztc)**2
x_min = torch.minimum(xp1, xt1)
x_max = torch.maximum(xp2, xt2)
y_min = torch.minimum(yp1, yt1)
y_max = torch.maximum(yp2, yt2)
z_min = torch.minimum(zp1, zt1)
z_max = torch.maximum(zp2, zt2)
c2 = (x_min - x_max)**2 + (y_min - y_max)**2 + (z_min - z_max)**2
diou_loss = iou_loss + (r2 / c2)[:, 0]
return diou_loss
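# Illustrative sanity check (not part of the module): for two identical
# axis-aligned unit boxes the IoU is 1 and the center distance is 0, so the
# DIoU loss should be exactly 0. Reduction is handled by the @weighted_loss
# decorator.
# >>> pred = torch.tensor([[0., 0., 0., 1., 1., 1.]])
# >>> target = torch.tensor([[0., 0., 0., 1., 1., 1.]])
# >>> axis_aligned_diou_loss(pred, target, reduction='none')
# tensor([0.])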
@MODELS.register_module()
class TR3DAxisAlignedIoULoss(nn.Module):
"""Calculate the IoU loss (1-IoU) of axis aligned bounding boxes. The only
difference with original AxisAlignedIoULoss is the addition of DIoU mode.
These classes should be merged in the future.
Args:
mode (str): 'iou' for intersection over union or 'diou' for
distance-iou loss. Defaults to 'iou'.
reduction (str): Method to reduce losses.
            The valid reduction methods are 'none', 'sum' or 'mean'.
Defaults to 'mean'.
loss_weight (float): Weight of loss. Defaults to 1.0.
"""
def __init__(self,
mode: str = 'iou',
reduction: str = 'mean',
loss_weight: float = 1.0) -> None:
super(TR3DAxisAlignedIoULoss, self).__init__()
assert mode in ['iou', 'diou']
self.loss = axis_aligned_iou_loss if mode == 'iou' \
else axis_aligned_diou_loss
assert reduction in ['none', 'sum', 'mean']
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred: Tensor,
target: Tensor,
weight: Optional[Tensor] = None,
avg_factor: Optional[float] = None,
reduction_override: Optional[str] = None,
**kwargs) -> Tensor:
"""Forward function of loss calculation.
Args:
            pred (Tensor): Bbox predictions with shape [..., 6]
                (x1, y1, z1, x2, y2, z2).
            target (Tensor): Bbox targets (gt) with shape [..., 6]
                (x1, y1, z1, x2, y2, z2).
weight (Tensor, optional): Weight of loss.
Defaults to None.
avg_factor (float, optional): Average factor that is used to
average the loss. Defaults to None.
reduction_override (str, optional): Method to reduce losses.
                The valid reduction methods are 'none', 'sum' or 'mean'.
Defaults to None.
Returns:
Tensor: IoU loss between predictions and targets.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if (weight is not None) and (not torch.any(weight > 0)) and (
reduction != 'none'):
return (pred * weight).sum()
return self.loss(
pred,
target,
weight=weight,
avg_factor=avg_factor,
reduction=reduction) * self.loss_weight
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
try:
import MinkowskiEngine as ME
except ImportError:
# Please follow getting_started.md to install MinkowskiEngine.
ME = SparseTensor = None
pass
from mmdet3d.models.backbones import MinkResNet
from mmdet3d.registry import MODELS
@MODELS.register_module()
class TR3DMinkResNet(MinkResNet):
r"""Minkowski ResNet backbone. See `4D Spatio-Temporal ConvNets
    <https://arxiv.org/abs/1904.08755>`_ for more details. The only difference
with MinkResNet is the `norm` and `num_planes` parameters. These classes
should be merged in the future.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input channels, 3 for RGB.
num_stages (int): Resnet stages. Defaults to 4.
pool (bool): Whether to add max pooling after first conv.
Defaults to True.
norm (str): Norm type ('instance' or 'batch') for stem layer.
            ResNet usually uses BatchNorm, but the original MinkResNet
            uses InstanceNorm. Defaults to 'instance'.
num_planes (tuple[int]): Number of planes per block before
block.expansion. Defaults to (64, 128, 256, 512).
"""
def __init__(self,
depth: int,
in_channels: int,
num_stages: int = 4,
pool: bool = True,
norm: str = 'instance',
num_planes: Tuple[int] = (64, 128, 256, 512)):
super(TR3DMinkResNet, self).__init__(depth, in_channels, num_stages,
pool)
block, stage_blocks = self.arch_settings[depth]
self.inplanes = 64
norm_layer = ME.MinkowskiInstanceNorm if norm == 'instance' else \
ME.MinkowskiBatchNorm
self.norm1 = norm_layer(self.inplanes)
for i in range(len(stage_blocks)):
setattr(
self, f'layer{i + 1}',
self._make_layer(
block, num_planes[i], stage_blocks[i], stride=2))
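# Illustrative instantiation (requires MinkowskiEngine; values mirror the base
# TR3D config, which uses BatchNorm and a slimmer channel layout than the
# default MinkResNet):
# >>> backbone = TR3DMinkResNet(depth=34, in_channels=3, norm='batch',
# ...                           num_planes=(64, 128, 128, 128))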
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmcv.ops.diff_iou_rotated import box2corners, oriented_box_intersection_2d
from mmdet.models.losses.utils import weighted_loss
from torch import Tensor
from torch import nn as nn
from mmdet3d.models import rotated_iou_3d_loss
from mmdet3d.registry import MODELS
def diff_diou_rotated_3d(box3d1: Tensor, box3d2: Tensor) -> Tensor:
"""Calculate differentiable DIoU of rotated 3d boxes.
Args:
box3d1 (Tensor): (B, N, 3+3+1) First box (x,y,z,w,h,l,alpha).
box3d2 (Tensor): (B, N, 3+3+1) Second box (x,y,z,w,h,l,alpha).
Returns:
Tensor: (B, N) DIoU.
"""
box1 = box3d1[..., [0, 1, 3, 4, 6]]
box2 = box3d2[..., [0, 1, 3, 4, 6]]
corners1 = box2corners(box1)
corners2 = box2corners(box2)
intersection, _ = oriented_box_intersection_2d(corners1, corners2)
zmax1 = box3d1[..., 2] + box3d1[..., 5] * 0.5
zmin1 = box3d1[..., 2] - box3d1[..., 5] * 0.5
zmax2 = box3d2[..., 2] + box3d2[..., 5] * 0.5
zmin2 = box3d2[..., 2] - box3d2[..., 5] * 0.5
z_overlap = (torch.min(zmax1, zmax2) -
torch.max(zmin1, zmin2)).clamp_(min=0.)
intersection_3d = intersection * z_overlap
volume1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5]
volume2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5]
union_3d = volume1 + volume2 - intersection_3d
x1_max = torch.max(corners1[..., 0], dim=2)[0]
x1_min = torch.min(corners1[..., 0], dim=2)[0]
y1_max = torch.max(corners1[..., 1], dim=2)[0]
y1_min = torch.min(corners1[..., 1], dim=2)[0]
x2_max = torch.max(corners2[..., 0], dim=2)[0]
x2_min = torch.min(corners2[..., 0], dim=2)[0]
y2_max = torch.max(corners2[..., 1], dim=2)[0]
y2_min = torch.min(corners2[..., 1], dim=2)[0]
x_max = torch.max(x1_max, x2_max)
x_min = torch.min(x1_min, x2_min)
y_max = torch.max(y1_max, y2_max)
y_min = torch.min(y1_min, y2_min)
z_max = torch.max(zmax1, zmax2)
z_min = torch.min(zmin1, zmin2)
r2 = ((box1[..., :3] - box2[..., :3])**2).sum(dim=-1)
c2 = (x_min - x_max)**2 + (y_min - y_max)**2 + (z_min - z_max)**2
return intersection_3d / union_3d - r2 / c2
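# Illustrative shape check (not part of the module): boxes are batched as
# (B, N, 7) tensors (x, y, z, w, l, h, alpha) and the returned DIoU is (B, N).
# >>> boxes1 = torch.rand(2, 4, 7) + 0.1  # keep sizes strictly positive
# >>> boxes2 = torch.rand(2, 4, 7) + 0.1
# >>> diff_diou_rotated_3d(boxes1, boxes2).shape
# torch.Size([2, 4])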
@weighted_loss
def rotated_diou_3d_loss(pred: Tensor, target: Tensor) -> Tensor:
"""Calculate the DIoU loss (1-DIoU) of two sets of rotated bounding boxes.
    Note that predictions and targets are in one-to-one correspondence.
Args:
pred (torch.Tensor): Bbox predictions with shape [N, 7]
(x, y, z, w, l, h, alpha).
target (torch.Tensor): Bbox targets (gt) with shape [N, 7]
(x, y, z, w, l, h, alpha).
Returns:
        torch.Tensor: DIoU loss between predictions and targets.
"""
diou_loss = 1 - diff_diou_rotated_3d(
pred.unsqueeze(0), target.unsqueeze(0))[0]
return diou_loss
@MODELS.register_module()
class TR3DRotatedIoU3DLoss(nn.Module):
"""Calculate the IoU loss (1-IoU) of rotated bounding boxes. The only
difference with original RotatedIoU3DLoss is the addition of DIoU mode.
These classes should be merged in the future.
Args:
mode (str): 'iou' for intersection over union or 'diou' for
distance-iou loss. Defaults to 'iou'.
reduction (str): Method to reduce losses.
            The valid reduction methods are 'none', 'sum' or 'mean'.
Defaults to 'mean'.
loss_weight (float): Weight of loss. Defaults to 1.0.
"""
def __init__(self,
mode: str = 'iou',
reduction: str = 'mean',
loss_weight: float = 1.0) -> None:
super(TR3DRotatedIoU3DLoss, self).__init__()
assert mode in ['iou', 'diou']
self.loss = rotated_iou_3d_loss if mode == 'iou' \
else rotated_diou_3d_loss
assert reduction in ['none', 'sum', 'mean']
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred: Tensor,
target: Tensor,
weight: Optional[Tensor] = None,
avg_factor: Optional[float] = None,
reduction_override: Optional[str] = None,
**kwargs) -> Tensor:
"""Forward function of loss calculation.
Args:
pred (Tensor): Bbox predictions with shape [..., 7]
(x, y, z, w, l, h, alpha).
target (Tensor): Bbox targets (gt) with shape [..., 7]
(x, y, z, w, l, h, alpha).
weight (Tensor, optional): Weight of loss.
Defaults to None.
avg_factor (float, optional): Average factor that is used to
average the loss. Defaults to None.
reduction_override (str, optional): Method to reduce losses.
                The valid reduction methods are 'none', 'sum' or 'mean'.
Defaults to None.
Returns:
Tensor: IoU loss between predictions and targets.
"""
if weight is not None and not torch.any(weight > 0):
return pred.sum() * weight.sum() # 0
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if weight is not None and weight.dim() > 1:
weight = weight.mean(-1)
loss = self.loss_weight * self.loss(
pred,
target,
weight,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/SamsungLabs/tr3d/blob/master/mmdet3d/models/dense_heads/tr3d_head.py # noqa
from typing import List, Optional, Tuple
try:
import MinkowskiEngine as ME
from MinkowskiEngine import SparseTensor
except ImportError:
# Please follow getting_started.md to install MinkowskiEngine.
ME = SparseTensor = None
pass
import torch
from mmcv.ops import nms3d, nms3d_normal
from mmengine.model import bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor, nn
from mmdet3d.models import Base3DDenseHead
from mmdet3d.registry import MODELS
from mmdet3d.structures import BaseInstance3DBoxes
from mmdet3d.utils import InstanceList, OptInstanceList
@MODELS.register_module()
class TR3DHead(Base3DDenseHead):
r"""Bbox head of `TR3D <https://arxiv.org/abs/2302.02858>`_.
Args:
in_channels (int): Number of channels in input tensors.
num_reg_outs (int): Number of regression layer channels.
voxel_size (float): Voxel size in meters.
pts_center_threshold (int): Box to location assigner parameter.
After feature level for the box is determined, assigner selects
pts_center_threshold locations closest to the box center.
        label2level (tuple[int]): Per-class index (0 or 1) of the feature
            level used to assign ground truth boxes of that class.
        bbox_loss (dict): Config of bbox loss. Defaults to
            dict(type='TR3DAxisAlignedIoULoss', mode='diou', reduction='none').
        cls_loss (dict): Config of classification loss. Defaults to
            dict(type='mmdet.FocalLoss', reduction='none').
train_cfg (dict, optional): Config for train stage. Defaults to None.
test_cfg (dict, optional): Config for test stage. Defaults to None.
init_cfg (dict, optional): Config for weight initialization.
Defaults to None.
"""
def __init__(self,
in_channels: int,
num_reg_outs: int,
                 voxel_size: float,
pts_center_threshold: int,
label2level: Tuple[int],
bbox_loss: dict = dict(
type='TR3DAxisAlignedIoULoss',
mode='diou',
reduction='none'),
cls_loss: dict = dict(
type='mmdet.FocalLoss', reduction='none'),
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super(TR3DHead, self).__init__(init_cfg)
if ME is None:
raise ImportError(
                'Please follow `getting_started.md` to install MinkowskiEngine.'  # noqa: E501
)
self.voxel_size = voxel_size
self.pts_center_threshold = pts_center_threshold
self.label2level = label2level
self.bbox_loss = MODELS.build(bbox_loss)
self.cls_loss = MODELS.build(cls_loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self._init_layers(len(self.label2level), in_channels, num_reg_outs)
def _init_layers(self, num_classes: int, in_channels: int,
num_reg_outs: int):
"""Initialize layers.
Args:
in_channels (int): Number of channels in input tensors.
num_reg_outs (int): Number of regression layer channels.
num_classes (int): Number of classes.
"""
self.conv_reg = ME.MinkowskiConvolution(
in_channels, num_reg_outs, kernel_size=1, bias=True, dimension=3)
self.conv_cls = ME.MinkowskiConvolution(
in_channels, num_classes, kernel_size=1, bias=True, dimension=3)
def init_weights(self):
"""Initialize weights."""
nn.init.normal_(self.conv_reg.kernel, std=.01)
nn.init.normal_(self.conv_cls.kernel, std=.01)
nn.init.constant_(self.conv_cls.bias, bias_init_with_prob(.01))
def _forward_single(self, x: SparseTensor) -> Tuple[Tensor, ...]:
"""Forward pass per level.
Args:
x (SparseTensor): Per level neck output tensor.
Returns:
tuple[Tensor]: Per level head predictions.
"""
reg_final = self.conv_reg(x).features
reg_distance = torch.exp(reg_final[:, 3:6])
reg_angle = reg_final[:, 6:]
bbox_pred = torch.cat((reg_final[:, :3], reg_distance, reg_angle),
dim=1)
cls_pred = self.conv_cls(x).features
bbox_preds, cls_preds, points = [], [], []
for permutation in x.decomposition_permutations:
bbox_preds.append(bbox_pred[permutation])
cls_preds.append(cls_pred[permutation])
points.append(x.coordinates[permutation][:, 1:] * self.voxel_size)
return bbox_preds, cls_preds, points
def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], ...]:
"""Forward pass.
Args:
x (list[Tensor]): Features from the backbone.
Returns:
Tuple[List[Tensor], ...]: Predictions of the head.
"""
bbox_preds, cls_preds, points = [], [], []
for i in range(len(x)):
bbox_pred, cls_pred, point = self._forward_single(x[i])
bbox_preds.append(bbox_pred)
cls_preds.append(cls_pred)
points.append(point)
return bbox_preds, cls_preds, points
def _loss_by_feat_single(self, bbox_preds: List[Tensor],
cls_preds: List[Tensor], points: List[Tensor],
gt_bboxes: BaseInstance3DBoxes, gt_labels: Tensor,
input_meta: dict) -> Tuple[Tensor, ...]:
"""Loss function of single sample.
Args:
bbox_preds (list[Tensor]): Bbox predictions for all levels.
cls_preds (list[Tensor]): Classification predictions for all
levels.
points (list[Tensor]): Final location coordinates for all levels.
gt_bboxes (:obj:`BaseInstance3DBoxes`): Ground truth boxes.
gt_labels (Tensor): Ground truth labels.
input_meta (dict): Scene meta info.
Returns:
tuple[Tensor, ...]: Bbox and classification loss
values and a boolean mask of assigned points.
"""
num_classes = cls_preds[0].shape[1]
bbox_targets, cls_targets = self.get_targets(points, gt_bboxes,
gt_labels, num_classes)
bbox_preds = torch.cat(bbox_preds)
cls_preds = torch.cat(cls_preds)
points = torch.cat(points)
# cls loss
cls_loss = self.cls_loss(cls_preds, cls_targets)
# bbox loss
pos_mask = cls_targets < num_classes
pos_bbox_preds = bbox_preds[pos_mask]
if pos_mask.sum() > 0:
pos_points = points[pos_mask]
pos_bbox_preds = bbox_preds[pos_mask]
pos_bbox_targets = bbox_targets[pos_mask]
bbox_loss = self.bbox_loss(
self._bbox_to_loss(
self._bbox_pred_to_bbox(pos_points, pos_bbox_preds)),
self._bbox_to_loss(pos_bbox_targets))
else:
bbox_loss = pos_bbox_preds
return bbox_loss, cls_loss, pos_mask
def loss_by_feat(self,
bbox_preds: List[List[Tensor]],
cls_preds: List[List[Tensor]],
points: List[List[Tensor]],
batch_gt_instances_3d: InstanceList,
batch_input_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None,
**kwargs) -> dict:
"""Loss function about feature.
Args:
bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
cls_preds (list[list[Tensor]]): Classification predictions for all
scenes. The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
points (list[list[Tensor]]): Final location coordinates for all
scenes. The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d. It usually includes ``bboxes_3d``,
                ``labels_3d``, ``depths``, ``centers_2d`` and attributes.
batch_input_metas (list[dict]): Meta information of each image,
e.g., image size, scaling factor, etc.
Returns:
dict: Bbox, and classification losses.
"""
bbox_losses, cls_losses, pos_masks = [], [], []
for i in range(len(batch_input_metas)):
bbox_loss, cls_loss, pos_mask = self._loss_by_feat_single(
bbox_preds=[x[i] for x in bbox_preds],
cls_preds=[x[i] for x in cls_preds],
points=[x[i] for x in points],
input_meta=batch_input_metas[i],
gt_bboxes=batch_gt_instances_3d[i].bboxes_3d,
gt_labels=batch_gt_instances_3d[i].labels_3d)
if len(bbox_loss) > 0:
bbox_losses.append(bbox_loss)
cls_losses.append(cls_loss)
pos_masks.append(pos_mask)
return dict(
bbox_loss=torch.mean(torch.cat(bbox_losses)),
cls_loss=torch.sum(torch.cat(cls_losses)) /
torch.sum(torch.cat(pos_masks)))
def _predict_by_feat_single(self, bbox_preds: List[Tensor],
cls_preds: List[Tensor], points: List[Tensor],
input_meta: dict) -> InstanceData:
"""Generate boxes for single sample.
Args:
bbox_preds (list[Tensor]): Bbox predictions for all levels.
cls_preds (list[Tensor]): Classification predictions for all
levels.
points (list[Tensor]): Final location coordinates for all levels.
input_meta (dict): Scene meta info.
Returns:
InstanceData: Predicted bounding boxes, scores and labels.
"""
scores = torch.cat(cls_preds).sigmoid()
bbox_preds = torch.cat(bbox_preds)
points = torch.cat(points)
max_scores, _ = scores.max(dim=1)
if len(scores) > self.test_cfg.nms_pre > 0:
_, ids = max_scores.topk(self.test_cfg.nms_pre)
bbox_preds = bbox_preds[ids]
scores = scores[ids]
points = points[ids]
bboxes = self._bbox_pred_to_bbox(points, bbox_preds)
bboxes, scores, labels = self._single_scene_multiclass_nms(
bboxes, scores, input_meta)
bboxes = input_meta['box_type_3d'](
bboxes,
box_dim=bboxes.shape[1],
with_yaw=bboxes.shape[1] == 7,
origin=(.5, .5, .5))
results = InstanceData()
results.bboxes_3d = bboxes
results.scores_3d = scores
results.labels_3d = labels
return results
    def predict_by_feat(self, bbox_preds: List[List[Tensor]],
                        cls_preds: List[List[Tensor]],
points: List[List[Tensor]],
batch_input_metas: List[dict],
**kwargs) -> List[InstanceData]:
"""Generate boxes for all scenes.
Args:
bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
cls_preds (list[list[Tensor]]): Classification predictions for all
scenes.
points (list[list[Tensor]]): Final location coordinates for all
scenes.
batch_input_metas (list[dict]): Meta infos for all scenes.
Returns:
list[InstanceData]: Predicted bboxes, scores, and labels for
all scenes.
"""
results = []
for i in range(len(batch_input_metas)):
result = self._predict_by_feat_single(
bbox_preds=[x[i] for x in bbox_preds],
cls_preds=[x[i] for x in cls_preds],
points=[x[i] for x in points],
input_meta=batch_input_metas[i])
results.append(result)
return results
@staticmethod
def _bbox_to_loss(bbox):
"""Transform box to the axis-aligned or rotated iou loss format.
Args:
bbox (Tensor): 3D box of shape (N, 6) or (N, 7).
Returns:
Tensor: Transformed 3D box of shape (N, 6) or (N, 7).
"""
# rotated iou loss accepts (x, y, z, w, h, l, heading)
if bbox.shape[-1] != 6:
return bbox
# axis-aligned case: x, y, z, w, h, l -> x1, y1, z1, x2, y2, z2
return torch.stack(
(bbox[..., 0] - bbox[..., 3] / 2, bbox[..., 1] - bbox[..., 4] / 2,
bbox[..., 2] - bbox[..., 5] / 2, bbox[..., 0] + bbox[..., 3] / 2,
bbox[..., 1] + bbox[..., 4] / 2, bbox[..., 2] + bbox[..., 5] / 2),
dim=-1)
@staticmethod
def _bbox_pred_to_bbox(points, bbox_pred):
"""Transform predicted bbox parameters to bbox.
Args:
points (Tensor): Final locations of shape (N, 3)
bbox_pred (Tensor): Predicted bbox parameters of shape (N, 6)
or (N, 8).
Returns:
Tensor: Transformed 3D box of shape (N, 6) or (N, 7).
"""
if bbox_pred.shape[0] == 0:
return bbox_pred
x_center = points[:, 0] + bbox_pred[:, 0]
y_center = points[:, 1] + bbox_pred[:, 1]
z_center = points[:, 2] + bbox_pred[:, 2]
base_bbox = torch.stack([
x_center, y_center, z_center, bbox_pred[:, 3], bbox_pred[:, 4],
bbox_pred[:, 5]
], -1)
# axis-aligned case
if bbox_pred.shape[1] == 6:
return base_bbox
# rotated case: ..., sin(2a)ln(q), cos(2a)ln(q)
scale = bbox_pred[:, 3] + bbox_pred[:, 4]
q = torch.exp(
torch.sqrt(
torch.pow(bbox_pred[:, 6], 2) + torch.pow(bbox_pred[:, 7], 2)))
alpha = 0.5 * torch.atan2(bbox_pred[:, 6], bbox_pred[:, 7])
return torch.stack(
(x_center, y_center, z_center, scale / (1 + q), scale /
(1 + q) * q, bbox_pred[:, 5] + bbox_pred[:, 4], alpha),
dim=-1)
@torch.no_grad()
def get_targets(self, points: Tensor, gt_bboxes: BaseInstance3DBoxes,
gt_labels: Tensor, num_classes: int) -> Tuple[Tensor, ...]:
"""Compute targets for final locations for a single scene.
Args:
points (list[Tensor]): Final locations for all levels.
gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
gt_labels (Tensor): Ground truth labels.
num_classes (int): Number of classes.
Returns:
tuple[Tensor, ...]: Bbox and classification targets for all
locations.
"""
float_max = points[0].new_tensor(1e8)
levels = torch.cat([
points[i].new_tensor(i, dtype=torch.long).expand(len(points[i]))
for i in range(len(points))
])
points = torch.cat(points)
n_points = len(points)
n_boxes = len(gt_bboxes)
if len(gt_labels) == 0:
return points.new_tensor([]), \
gt_labels.new_full((n_points,), num_classes)
boxes = torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),
dim=1)
boxes = boxes.to(points.device).expand(n_points, n_boxes, 7)
points = points.unsqueeze(1).expand(n_points, n_boxes, 3)
# condition 1: fix level for label
label2level = gt_labels.new_tensor(self.label2level)
label_levels = label2level[gt_labels].unsqueeze(0).expand(
n_points, n_boxes)
point_levels = torch.unsqueeze(levels, 1).expand(n_points, n_boxes)
level_condition = label_levels == point_levels
# condition 2: keep topk location per box by center distance
center = boxes[..., :3]
center_distances = torch.sum(torch.pow(center - points, 2), dim=-1)
center_distances = torch.where(level_condition, center_distances,
float_max)
topk_distances = torch.topk(
center_distances,
min(self.pts_center_threshold + 1, len(center_distances)),
largest=False,
dim=0).values[-1]
topk_condition = center_distances < topk_distances.unsqueeze(0)
# condition 3: min center distance to box per point
center_distances = torch.where(topk_condition, center_distances,
float_max)
min_values, min_ids = center_distances.min(dim=1)
min_inds = torch.where(min_values < float_max, min_ids, -1)
bbox_targets = boxes[0][min_inds]
if not gt_bboxes.with_yaw:
bbox_targets = bbox_targets[:, :-1]
cls_targets = torch.where(min_inds >= 0, gt_labels[min_inds],
num_classes)
return bbox_targets, cls_targets
def _single_scene_multiclass_nms(self, bboxes: Tensor, scores: Tensor,
input_meta: dict) -> Tuple[Tensor, ...]:
"""Multi-class nms for a single scene.
Args:
bboxes (Tensor): Predicted boxes of shape (N_boxes, 6) or
(N_boxes, 7).
scores (Tensor): Predicted scores of shape (N_boxes, N_classes).
input_meta (dict): Scene meta data.
Returns:
tuple[Tensor, ...]: Predicted bboxes, scores and labels.
"""
num_classes = scores.shape[1]
with_yaw = bboxes.shape[1] == 7
nms_bboxes, nms_scores, nms_labels = [], [], []
for i in range(num_classes):
ids = scores[:, i] > self.test_cfg.score_thr
if not ids.any():
continue
class_scores = scores[ids, i]
class_bboxes = bboxes[ids]
if with_yaw:
nms_function = nms3d
else:
class_bboxes = torch.cat(
(class_bboxes, torch.zeros_like(class_bboxes[:, :1])),
dim=1)
nms_function = nms3d_normal
nms_ids = nms_function(class_bboxes, class_scores,
self.test_cfg.iou_thr)
nms_bboxes.append(class_bboxes[nms_ids])
nms_scores.append(class_scores[nms_ids])
nms_labels.append(
bboxes.new_full(
class_scores[nms_ids].shape, i, dtype=torch.long))
if len(nms_bboxes):
nms_bboxes = torch.cat(nms_bboxes, dim=0)
nms_scores = torch.cat(nms_scores, dim=0)
nms_labels = torch.cat(nms_labels, dim=0)
else:
nms_bboxes = bboxes.new_zeros((0, bboxes.shape[1]))
nms_scores = bboxes.new_zeros((0, ))
nms_labels = bboxes.new_zeros((0, ))
if not with_yaw:
nms_bboxes = nms_bboxes[:, :6]
return nms_bboxes, nms_scores, nms_labels
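# Illustrative instantiation (requires MinkowskiEngine; values mirror the
# ScanNet config, where `label2level` assigns each of the 18 classes to one of
# the two feature levels used by the assigner):
# >>> head = TR3DHead(
# ...     in_channels=128, num_reg_outs=6, voxel_size=0.01,
# ...     pts_center_threshold=6,
# ...     label2level=[0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0])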
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/SamsungLabs/tr3d/blob/master/mmdet3d/models/necks/tr3d_neck.py # noqa
from typing import List, Tuple
try:
import MinkowskiEngine as ME
from MinkowskiEngine import SparseTensor
except ImportError:
# Please follow getting_started.md to install MinkowskiEngine.
ME = SparseTensor = None
pass
from mmengine.model import BaseModule
from torch import nn
from mmdet3d.registry import MODELS
@MODELS.register_module()
class TR3DNeck(BaseModule):
r"""Neck of `TR3D <https://arxiv.org/abs/2302.02858>`_.
Args:
in_channels (tuple[int]): Number of channels in input tensors.
out_channels (int): Number of channels in output tensors.
"""
def __init__(self, in_channels: Tuple[int], out_channels: int):
super(TR3DNeck, self).__init__()
self._init_layers(in_channels[1:], out_channels)
def _init_layers(self, in_channels: Tuple[int], out_channels: int):
"""Initialize layers.
Args:
in_channels (tuple[int]): Number of channels in input tensors.
out_channels (int): Number of channels in output tensors.
"""
for i in range(len(in_channels)):
if i > 0:
self.add_module(
f'up_block_{i}',
self._make_block(in_channels[i], in_channels[i - 1], True,
2))
if i < len(in_channels) - 1:
self.add_module(
f'lateral_block_{i}',
self._make_block(in_channels[i], in_channels[i]))
self.add_module(f'out_block_{i}',
self._make_block(in_channels[i], out_channels))
def init_weights(self):
"""Initialize weights."""
for m in self.modules():
if isinstance(m, ME.MinkowskiConvolution):
ME.utils.kaiming_normal_(
m.kernel, mode='fan_out', nonlinearity='relu')
if isinstance(m, ME.MinkowskiBatchNorm):
nn.init.constant_(m.bn.weight, 1)
nn.init.constant_(m.bn.bias, 0)
def forward(self, x: List[SparseTensor]) -> List[SparseTensor]:
"""Forward pass.
Args:
x (list[SparseTensor]): Features from the backbone.
Returns:
List[Tensor]: Output features from the neck.
"""
x = x[1:]
outs = []
inputs = x
x = inputs[-1]
for i in range(len(inputs) - 1, -1, -1):
if i < len(inputs) - 1:
x = self.__getattr__(f'up_block_{i + 1}')(x)
x = inputs[i] + x
x = self.__getattr__(f'lateral_block_{i}')(x)
out = self.__getattr__(f'out_block_{i}')(x)
outs.append(out)
return outs[::-1]
@staticmethod
def _make_block(in_channels: int,
out_channels: int,
generative: bool = False,
stride: int = 1) -> nn.Module:
"""Construct Conv-Norm-Act block.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
generative (bool): Use generative convolution if True.
Defaults to False.
stride (int): Stride of the convolution. Defaults to 1.
Returns:
torch.nn.Module: With corresponding layers.
"""
conv = ME.MinkowskiGenerativeConvolutionTranspose if generative \
else ME.MinkowskiConvolution
return nn.Sequential(
conv(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
dimension=3), ME.MinkowskiBatchNorm(out_channels),
ME.MinkowskiReLU(inplace=True))
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import numpy as np
from mmdet3d.datasets import PointSample
from mmdet3d.registry import TRANSFORMS
from mmdet3d.structures.points import BasePoints
@TRANSFORMS.register_module()
class TR3DPointSample(PointSample):
"""The only difference with PointSample is the support of float num_points
parameter.
In this case we sample random fraction of points from num_points to 100%
points. These classes should be merged in the future.
"""
def _points_random_sampling(
self,
points: BasePoints,
num_samples: Union[int, float],
sample_range: Optional[float] = None,
replace: bool = False,
return_choices: bool = False
) -> Union[Tuple[BasePoints, np.ndarray], BasePoints]:
"""Points random sampling.
Sample points to a certain number.
Args:
points (:obj:`BasePoints`): 3D Points.
            num_samples (int or float): Number of points to sample. If a
                float in (0, 1), the lower bound of the randomly sampled
                fraction of points.
sample_range (float, optional): Indicating the range where the
points will be sampled. Defaults to None.
replace (bool): Sampling with or without replacement.
Defaults to False.
return_choices (bool): Whether return choice. Defaults to False.
Returns:
tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`:
- points (:obj:`BasePoints`): 3D Points.
- choices (np.ndarray, optional): The generated random samples.
"""
if isinstance(num_samples, float):
assert num_samples < 1
num_samples = int(
np.random.uniform(self.num_points, 1.) * points.shape[0])
if not replace:
replace = (points.shape[0] < num_samples)
point_range = range(len(points))
if sample_range is not None and not replace:
# Only sampling the near points when len(points) >= num_samples
dist = np.linalg.norm(points.coord.numpy(), axis=1)
far_inds = np.where(dist >= sample_range)[0]
near_inds = np.where(dist < sample_range)[0]
# in case there are too many far points
if len(far_inds) > num_samples:
far_inds = np.random.choice(
far_inds, num_samples, replace=False)
point_range = near_inds
num_samples -= len(far_inds)
choices = np.random.choice(point_range, num_samples, replace=replace)
if sample_range is not None and not replace:
choices = np.concatenate((far_inds, choices))
# Shuffle points after sampling
np.random.shuffle(choices)
if return_choices:
return points[choices], choices
else:
return points[choices]
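# Illustrative usage matching the ScanNet config: with a float num_points,
# every call keeps a random fraction of the input points between 33% and 100%.
# >>> transform = TR3DPointSample(num_points=0.33)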