Unverified Commit 8fb2cf6c authored by Jingwei Zhang, committed by GitHub

[Feature] Support inference of DSVT in `projects` (#2606)

* support inference

* align inference precision

* add readme

* polish docs

* polish docs
parent 456b7403
@@ -579,6 +579,8 @@ class LoadPointsFromFile(BaseTransform):
use_color (bool): Whether to use color features. Defaults to False.
        norm_intensity (bool): Whether to normalize the intensity. Defaults to
            False.
        norm_elongation (bool): Whether to normalize the elongation. This is
            usually used in Waymo dataset. Defaults to False.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
"""
@@ -590,6 +592,7 @@ class LoadPointsFromFile(BaseTransform):
shift_height: bool = False,
use_color: bool = False,
norm_intensity: bool = False,
norm_elongation: bool = False,
backend_args: Optional[dict] = None) -> None:
self.shift_height = shift_height
self.use_color = use_color
@@ -603,6 +606,7 @@ class LoadPointsFromFile(BaseTransform):
self.load_dim = load_dim
self.use_dim = use_dim
self.norm_intensity = norm_intensity
self.norm_elongation = norm_elongation
self.backend_args = backend_args
def _load_points(self, pts_filename: str) -> np.ndarray:
@@ -646,6 +650,10 @@ class LoadPointsFromFile(BaseTransform):
assert len(self.use_dim) >= 4, \
f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}' # noqa: E501
points[:, 3] = np.tanh(points[:, 3])
if self.norm_elongation:
assert len(self.use_dim) >= 5, \
f'When using elongation norm, expect used dimensions >= 5, got {len(self.use_dim)}' # noqa: E501
points[:, 4] = np.tanh(points[:, 4])
attribute_dims = None
if self.shift_height:
@@ -682,6 +690,8 @@ class LoadPointsFromFile(BaseTransform):
repr_str += f'backend_args={self.backend_args}, '
repr_str += f'load_dim={self.load_dim}, '
        repr_str += f'use_dim={self.use_dim}, '
        repr_str += f'norm_intensity={self.norm_intensity}, '
        repr_str += f'norm_elongation={self.norm_elongation})'
return repr_str
@@ -74,7 +74,8 @@ class SECONDFPN(BaseModule):
"""Forward function.
Args:
x (torch.Tensor): 4D Tensor in (N, C, H, W) shape.
x (List[torch.Tensor]): Multi-level features with 4D Tensor in
(N, C, H, W) shape.
Returns:
list[torch.Tensor]: Multi-level feature maps.
# DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets
> [DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets](https://arxiv.org/abs/2301.06051)
<!-- [ALGORITHM] -->
## Abstract
Designing an efficient yet deployment-friendly 3D backbone to handle sparse point clouds is a fundamental problem
in 3D perception. Compared with the customized sparse
convolution, the attention mechanism in Transformers is
more appropriate for flexibly modeling long-range relationships and is easier to be deployed in real-world applications.
However, due to the sparse characteristics of point clouds,
it is non-trivial to apply a standard transformer on sparse
points. In this paper, we present Dynamic Sparse Voxel
Transformer (DSVT), a single-stride window-based voxel
Transformer backbone for outdoor 3D perception. In order
to efficiently process sparse points in parallel, we propose
Dynamic Sparse Window Attention, which partitions a series
of local regions in each window according to its sparsity
and then computes the features of all regions in a fully parallel manner. To allow the cross-set connection, we design
a rotated set partitioning strategy that alternates between
two partitioning configurations in consecutive self-attention
layers. To support effective downsampling and better encode geometric information, we also propose an attention-style 3D pooling module on sparse points, which is powerful
and deployment-friendly without utilizing any customized
CUDA operations. Our model achieves state-of-the-art performance with a broad range of 3D perception tasks. More
importantly, DSVT can be easily deployed by TensorRT with
real-time inference speed (27Hz). Code will be available at
https://github.com/Haiyang-W/DSVT.
<div align=center>
<img src="https://github-production-user-asset-6210df.s3.amazonaws.com/34888372/245692705-e61be20c-2a7d-4ab9-85e3-b36f662c1bdf.png" width="800"/>
</div>
## Introduction
We implement DSVT and provide the results on Waymo dataset.
## Usage
<!-- For a typical model, this section should contain the commands for training and testing. You are also suggested to dump your environment specification to env.yml by `conda env export > env.yml`. -->
### Installation
```shell
pip install torch_scatter==2.0.9
python projects/DSVT/setup.py develop # compile `ingroup_inds_op` cuda operation
```
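You can verify that `torch_scatter` matches your current PyTorch/CUDA build with a quick import check (a minimal sanity check, not part of the official instructions):
```shell
python -c "import torch_scatter; print(torch_scatter.__version__)"
```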
### Testing commands
In MMDetection3D's root directory, run the following command to test the model:
```bash
python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py ${CHECKPOINT_PATH}
```
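Multi-GPU evaluation should also work with the standard MMDetection3D launcher (a sketch assuming 8 GPUs; adjust the GPU count to your setup):
```bash
bash tools/dist_test.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py ${CHECKPOINT_PATH} 8
```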
### Training commands
Support for training DSVT is on the way.
## Results and models
### Waymo
| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | | | 75.2 | 72.2 | 68.9 | 66.1 | |
**Note** that `ResSECOND` denotes a SECOND backbone whose basic blocks contain residual layers.
## Citation
```latex
@inproceedings{wang2023dsvt,
title={DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets},
    author={Haiyang Wang and Chen Shi and Shaoshuai Shi and Meng Lei and Sen Wang and Di He and Bernt Schiele and Liwei Wang},
booktitle={CVPR},
year={2023}
}
```
_base_ = ['../../../configs/_base_/default_runtime.py']
custom_imports = dict(
imports=['projects.DSVT.dsvt'], allow_failed_imports=False)
voxel_size = [0.32, 0.32, 6]
grid_size = [468, 468, 1]
point_cloud_range = [-74.88, -74.88, -2, 74.88, 74.88, 4.0]
data_root = 'data/waymo/kitti_format/'
class_names = ['Car', 'Pedestrian', 'Cyclist']
metainfo = dict(classes=class_names)
input_modality = dict(use_lidar=True, use_camera=False)
backend_args = None
model = dict(
type='DSVT',
data_preprocessor=dict(type='Det3DDataPreprocessor', voxel=False),
voxel_encoder=dict(
type='DynamicPillarVFE3D',
with_distance=False,
use_absolute_xyz=True,
use_norm=True,
num_filters=[192, 192],
num_point_features=5,
voxel_size=voxel_size,
grid_size=grid_size,
point_cloud_range=point_cloud_range),
middle_encoder=dict(
type='DSVTMiddleEncoder',
input_layer=dict(
sparse_shape=grid_size,
downsample_stride=[],
dim_model=[192],
set_info=[[36, 4]],
window_shape=[[12, 12, 1]],
hybrid_factor=[2, 2, 1], # x, y, z
shift_list=[[[0, 0, 0], [6, 6, 0]]],
normalize_pos=False),
set_info=[[36, 4]],
dim_model=[192],
dim_feedforward=[384],
stage_num=1,
nhead=[8],
conv_out_channel=192,
output_shape=[468, 468],
dropout=0.,
activation='gelu'),
map2bev=dict(
type='PointPillarsScatter3D',
output_shape=grid_size,
num_bev_feats=192),
backbone=dict(
type='ResSECOND',
in_channels=192,
out_channels=[128, 128, 256],
blocks_nums=[1, 2, 2],
layer_strides=[1, 2, 2]),
neck=dict(
type='SECONDFPN',
in_channels=[128, 128, 256],
out_channels=[128, 128, 128],
upsample_strides=[1, 2, 4],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False),
use_conv_for_no_stride=False),
bbox_head=dict(
type='DSVTCenterHead',
in_channels=sum([128, 128, 128]),
tasks=[dict(num_class=3, class_names=class_names)],
common_heads=dict(
reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), iou=(1, 2)),
share_conv_channel=64,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.01),
bbox_coder=dict(
type='DSVTBBoxCoder',
pc_range=point_cloud_range,
max_num=500,
post_center_range=[-80, -80, -10.0, 80, 80, 10.0],
score_threshold=0.1,
out_size_factor=1,
voxel_size=voxel_size[:2],
code_size=7),
separate_head=dict(
type='SeparateHead',
init_bias=-2.19,
final_kernel=3,
norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.01)),
loss_cls=dict(
type='mmdet.GaussianFocalLoss', reduction='mean', loss_weight=1.0),
loss_bbox=dict(type='mmdet.L1Loss', reduction='mean', loss_weight=2.0),
norm_bbox=True),
# model training and testing settings
train_cfg=dict(
pts=dict(
grid_size=grid_size,
voxel_size=voxel_size,
out_size_factor=4,
dense_reg=1,
gaussian_overlap=0.1,
max_objs=500,
min_radius=2,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
test_cfg=dict(
max_per_img=500,
max_pool_nms=False,
min_radius=[4, 12, 10, 1, 0.85, 0.175],
iou_rectifier=[[0.68, 0.71, 0.65]],
pc_range=[-80, -80],
out_size_factor=4,
voxel_size=voxel_size[:2],
nms_type='rotate',
multi_class_nms=True,
pre_max_size=[[4096, 4096, 4096]],
post_max_size=[[500, 500, 500]],
nms_thr=[[0.7, 0.6, 0.55]]))
db_sampler = dict(
data_root=data_root,
info_path=data_root + 'waymo_dbinfos_train.pkl',
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
classes=class_names,
sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10),
points_loader=dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=[0, 1, 2, 3, 4],
backend_args=backend_args),
backend_args=backend_args)
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
norm_intensity=True,
backend_args=backend_args),
# Add this if using `MultiFrameDeformableDecoderRPN`
# dict(
# type='LoadPointsFromMultiSweeps',
# sweeps_num=9,
# load_dim=6,
# use_dim=[0, 1, 2, 3, 4],
# pad_empty_sweeps=True,
# remove_close=True),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.5, 0.5, 0]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
dataset_type = 'WaymoDataset'
val_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
ann_file='waymo_infos_val.pkl',
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
metainfo=metainfo,
box_type_3d='LiDAR',
backend_args=backend_args))
test_dataloader = val_dataloader
val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
backend_args=backend_args,
convert_kitti_format=False,
idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
# runtime settings
val_cfg = dict()
test_cfg = dict()
# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (1 sample per GPU).
# auto_scale_lr = dict(enable=False, base_batch_size=8)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
checkpoint=dict(type='CheckpointHook', interval=5))
from .dsvt import DSVT
from .dsvt_head import DSVTCenterHead
from .dsvt_transformer import DSVTMiddleEncoder
from .dynamic_pillar_vfe import DynamicPillarVFE3D
from .map2bev import PointPillarsScatter3D
from .res_second import ResSECOND
from .utils import DSVTBBoxCoder
__all__ = [
'DSVTCenterHead', 'DSVT', 'DSVTMiddleEncoder', 'DynamicPillarVFE3D',
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder'
]
from typing import Dict, List, Optional
import torch
from torch import Tensor
from mmdet3d.models import Base3DDetector
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample
@MODELS.register_module()
class DSVT(Base3DDetector):
"""DSVT detector."""
def __init__(self,
voxel_encoder: Optional[dict] = None,
middle_encoder: Optional[dict] = None,
backbone: Optional[dict] = None,
neck: Optional[dict] = None,
map2bev: Optional[dict] = None,
bbox_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None,
**kwargs):
super(DSVT, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor, **kwargs)
if voxel_encoder:
self.voxel_encoder = MODELS.build(voxel_encoder)
if middle_encoder:
self.middle_encoder = MODELS.build(middle_encoder)
if backbone:
self.backbone = MODELS.build(backbone)
self.map2bev = MODELS.build(map2bev)
if neck is not None:
self.neck = MODELS.build(neck)
if bbox_head:
bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
self.bbox_head = MODELS.build(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
@property
def with_bbox(self):
"""bool: Whether the detector has a 3D box head."""
return hasattr(self, 'bbox_head') and self.bbox_head is not None
@property
def with_backbone(self):
"""bool: Whether the detector has a 3D backbone."""
return hasattr(self, 'backbone') and self.backbone is not None
@property
def with_voxel_encoder(self):
"""bool: Whether the detector has a voxel encoder."""
return hasattr(self,
'voxel_encoder') and self.voxel_encoder is not None
@property
def with_middle_encoder(self):
"""bool: Whether the detector has a middle encoder."""
return hasattr(self,
'middle_encoder') and self.middle_encoder is not None
def _forward(self):
pass
def extract_feat(self, batch_inputs_dict: dict) -> tuple:
"""Extract features from images and points.
Args:
batch_inputs_dict (dict): Dict of batch inputs. It
contains
- points (List[tensor]): Point cloud of multiple inputs.
- imgs (tensor): Image tensor with shape (B, C, H, W).
        Returns:
            tuple: Multi-level point cloud features extracted by the
                backbone and neck.
"""
batch_out_dict = self.voxel_encoder(batch_inputs_dict)
batch_out_dict = self.middle_encoder(batch_out_dict)
batch_out_dict = self.map2bev(batch_out_dict)
multi_feats = self.backbone(batch_out_dict['spatial_features'])
feats = self.neck(multi_feats)
return feats
def loss(self, batch_inputs_dict: Dict[List, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
"""
Args:
batch_inputs_dict (dict): The model input dict which include
'points' and `imgs` keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor): Tensor of batch images, has shape
(B, C, H ,W)
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
                `gt_instance_3d`.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
pass
def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
"""Forward of testing.
Args:
batch_inputs_dict (dict): The model input dict which include
'points' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
                input sample. Each Det3DDataSample usually contains
                'pred_instances_3d'. And the ``pred_instances_3d`` usually
                contains the following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
                - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
                    contains a tensor with shape (num_instances, 7).
"""
pts_feats = self.extract_feat(batch_inputs_dict)
results_list_3d = self.bbox_head.predict(pts_feats, batch_data_samples)
detsamples = self.add_pred_to_datasample(batch_data_samples,
results_list_3d)
return detsamples
from typing import Dict, List, Tuple
import torch
from mmdet.models.utils import multi_apply
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet3d.models import CenterHead
from mmdet3d.models.layers import circle_nms, nms_bev
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample, xywhr2xyxyr
@MODELS.register_module()
class DSVTCenterHead(CenterHead):
"""CenterHead for DSVT.
This head adds IoU prediction branch based on the original CenterHead.
"""
def __init__(self, *args, **kwargs):
super(DSVTCenterHead, self).__init__(*args, **kwargs)
def forward_single(self, x: Tensor) -> dict:
"""Forward function for CenterPoint.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
list[dict]: Output results for tasks.
"""
ret_dicts = []
x = self.shared_conv(x)
for task in self.task_heads:
ret_dicts.append(task(x))
return ret_dicts
def forward(self, feats: List[Tensor]) -> Tuple[List[Tensor]]:
"""Forward pass.
Args:
feats (list[torch.Tensor]): Multi-level features, e.g.,
features produced by FPN.
Returns:
tuple(list[dict]): Output results for tasks.
"""
return multi_apply(self.forward_single, feats)
def loss(self, pts_feats: List[Tensor],
batch_data_samples: List[Det3DDataSample], *args,
**kwargs) -> Dict[str, Tensor]:
"""Forward function of training.
Args:
pts_feats (list[torch.Tensor]): Features of point cloud branch
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
                `gt_instance_3d`.
Returns:
dict: Losses of each branch.
"""
pass
def loss_by_feat(self, preds_dicts: Tuple[List[dict]],
batch_gt_instances_3d: List[InstanceData], *args,
**kwargs):
"""Loss function for CenterHead.
Args:
            preds_dicts (tuple[list[dict]]): Prediction results of
                multiple tasks. The outer tuple indicates different
                task heads, and the inner list indicates different
                FPN levels.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d`` and
                ``labels_3d`` attributes.
Returns:
dict[str,torch.Tensor]: Loss of heatmap and bbox of each task.
"""
pass
def predict(self,
pts_feats: Tuple[torch.Tensor],
batch_data_samples: List[Det3DDataSample],
rescale=True,
**kwargs) -> List[InstanceData]:
"""
Args:
            pts_feats (Tuple[torch.Tensor]): Point features.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes meta information of data.
            rescale (bool): Whether to rescale the results to
                the original scale.
Returns:
list[:obj:`InstanceData`]: List of processed predictions. Each
InstanceData contains 3d Bounding boxes and corresponding
scores and labels.
"""
preds_dict = self(pts_feats)
batch_size = len(batch_data_samples)
batch_input_metas = []
for batch_index in range(batch_size):
metainfo = batch_data_samples[batch_index].metainfo
batch_input_metas.append(metainfo)
results_list = self.predict_by_feat(
preds_dict, batch_input_metas, rescale=rescale, **kwargs)
return results_list
def predict_by_feat(self, preds_dicts: Tuple[List[dict]],
batch_input_metas: List[dict], *args,
**kwargs) -> List[InstanceData]:
"""Generate bboxes from bbox head predictions.
Args:
            preds_dicts (tuple[list[dict]]): Prediction results of
                multiple tasks. The outer tuple indicates different
                task heads, and the inner list indicates different
                FPN levels.
batch_input_metas (list[dict]): Meta info of multiple
inputs.
Returns:
list[:obj:`InstanceData`]: Instance prediction
results of each sample after the post process.
Each item usually contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instance, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
                - bboxes_3d (:obj:`LiDARInstance3DBoxes`): Prediction
                    of bboxes, contains a tensor with shape
                    (num_instances, 7) or (num_instances, 9), and
                    the last 2 dimensions of 9 are velocity.
"""
rets = []
for task_id, preds_dict in enumerate(preds_dicts):
num_class_with_bg = self.num_classes[task_id]
batch_size = preds_dict[0]['heatmap'].shape[0]
batch_heatmap = preds_dict[0]['heatmap'].sigmoid()
batch_reg = preds_dict[0]['reg']
batch_hei = preds_dict[0]['height']
if self.norm_bbox:
batch_dim = torch.exp(preds_dict[0]['dim'])
else:
batch_dim = preds_dict[0]['dim']
batch_rotc = preds_dict[0]['rot'][:, 0].unsqueeze(1)
batch_rots = preds_dict[0]['rot'][:, 1].unsqueeze(1)
batch_iou = (preds_dict[0]['iou'] +
1) * 0.5 if 'iou' in preds_dict[0] else None
if 'vel' in preds_dict[0]:
batch_vel = preds_dict[0]['vel']
else:
batch_vel = None
temp = self.bbox_coder.decode(
batch_heatmap,
batch_rots,
batch_rotc,
batch_hei,
batch_dim,
batch_vel,
reg=batch_reg,
iou=batch_iou)
assert self.test_cfg['nms_type'] in ['circle', 'rotate']
batch_reg_preds, batch_cls_preds, batch_cls_labels, batch_iou_preds = [], [], [], [] # noqa: E501
for box in temp:
batch_reg_preds.append(box['bboxes'])
batch_cls_preds.append(box['scores'])
batch_cls_labels.append(box['labels'].long())
batch_iou_preds.append(box['iou'])
if self.test_cfg['nms_type'] == 'circle':
ret_task = []
for i in range(batch_size):
boxes3d = temp[i]['bboxes']
scores = temp[i]['scores']
labels = temp[i]['labels']
centers = boxes3d[:, [0, 1]]
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
keep = torch.tensor(
circle_nms(
boxes.detach().cpu().numpy(),
self.test_cfg['min_radius'][task_id],
post_max_size=self.test_cfg['post_max_size']),
dtype=torch.long,
device=boxes.device)
boxes3d = boxes3d[keep]
scores = scores[keep]
labels = labels[keep]
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
ret_task.append(ret)
rets.append(ret_task)
else:
rets.append(
self.get_task_detections(task_id, num_class_with_bg,
batch_cls_preds, batch_reg_preds,
batch_iou_preds, batch_cls_labels,
batch_input_metas))
# Merge branches results
num_samples = len(rets[0])
ret_list = []
for i in range(num_samples):
temp_instances = InstanceData()
for k in rets[0][i].keys():
if k == 'bboxes':
bboxes = torch.cat([ret[i][k] for ret in rets])
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = batch_input_metas[i]['box_type_3d'](
bboxes, self.bbox_coder.code_size)
elif k == 'scores':
scores = torch.cat([ret[i][k] for ret in rets])
elif k == 'labels':
flag = 0
for j, num_class in enumerate(self.num_classes):
rets[j][i][k] += flag
flag += num_class
labels = torch.cat([ret[i][k].int() for ret in rets])
temp_instances.bboxes_3d = bboxes
temp_instances.scores_3d = scores
temp_instances.labels_3d = labels
ret_list.append(temp_instances)
return ret_list
def get_task_detections(self, task_id, num_class_with_bg, batch_cls_preds,
batch_reg_preds, batch_iou_preds, batch_cls_labels,
img_metas):
"""Rotate nms for each task.
        Args:
            task_id (int): Index of the current task.
            num_class_with_bg (int): Number of classes for the current task.
batch_cls_preds (list[torch.Tensor]): Prediction score with the
shape of [N].
batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
shape of [N, 9].
batch_iou_preds (list[torch.Tensor]): Prediction IoU with the
shape of [N].
batch_cls_labels (list[torch.Tensor]): Prediction label with the
shape of [N].
img_metas (list[dict]): Meta information of each sample.
Returns:
            list[dict[str, torch.Tensor]]: Contains the following keys:
                - bboxes (torch.Tensor): Prediction bboxes after nms with the
                    shape of [N, 9].
                - scores (torch.Tensor): Prediction scores after nms with the
                    shape of [N].
                - labels (torch.Tensor): Prediction labels after nms with the
                    shape of [N].
"""
predictions_dicts = []
for i, (box_preds, cls_preds, iou_preds, cls_labels) in enumerate(
zip(batch_reg_preds, batch_cls_preds, batch_iou_preds,
batch_cls_labels)):
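            # IoU-aware score rectification: the final score is
            # score^(1 - alpha) * iou^alpha, where alpha is the per-class
            # weight taken from test_cfg['iou_rectifier'].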
pred_iou = torch.clamp(iou_preds, min=0, max=1.0)
iou_rectifier = pred_iou.new_tensor(
self.test_cfg['iou_rectifier'][task_id])
cls_preds = torch.pow(cls_preds,
1 - iou_rectifier[cls_labels]) * torch.pow(
pred_iou, iou_rectifier[cls_labels])
# Apply NMS in bird eye view
# get the highest score per prediction, then apply nms
# to remove overlapped box.
if num_class_with_bg == 1:
top_scores = cls_preds
top_labels = torch.zeros(
cls_preds.shape[0],
device=cls_preds.device,
dtype=torch.long)
else:
top_labels = cls_labels.long()
top_scores = cls_preds
if top_scores.shape[0] != 0:
boxes_for_nms = xywhr2xyxyr(img_metas[i]['box_type_3d'](
box_preds[:, :], self.bbox_coder.code_size).bev)
pre_max_size = self.test_cfg['pre_max_size'][task_id]
post_max_size = self.test_cfg['post_max_size'][task_id]
# cls_label_per_task = self.cls_id_mapping_per_task[task_id]
all_selected_mask = torch.zeros_like(top_labels, dtype=bool)
all_indices = torch.arange(top_labels.size(0)).to(
top_labels.device)
# Mind this when training on the new coordinate
# Transform to old mmdet3d coordinate
boxes_for_nms[:, 4] = (-boxes_for_nms[:, 4] + torch.pi / 2 * 1)
boxes_for_nms[:, 4] = (boxes_for_nms[:, 4] +
torch.pi) % (2 * torch.pi) - torch.pi
                for cls_id, nms_thr in enumerate(
                        self.test_cfg['nms_thr'][task_id]):
                    label_mask = top_labels == cls_id
                    selected = nms_bev(
                        boxes_for_nms[label_mask],
                        top_scores[label_mask],
                        thresh=nms_thr,
                        pre_max_size=pre_max_size[cls_id],
                        post_max_size=post_max_size[cls_id])
indices = all_indices[label_mask][selected]
all_selected_mask.scatter_(0, indices, True)
else:
all_selected_mask = []
selected_boxes = box_preds[all_selected_mask]
selected_labels = top_labels[all_selected_mask]
selected_scores = top_scores[all_selected_mask]
# finally generate predictions.
if selected_boxes.shape[0] != 0:
box_preds = selected_boxes
scores = selected_scores
label_preds = selected_labels
final_box_preds = box_preds
final_scores = scores
final_labels = label_preds
predictions_dict = dict(
bboxes=final_box_preds,
scores=final_scores,
labels=final_labels)
else:
dtype = batch_reg_preds[0].dtype
device = batch_reg_preds[0].device
predictions_dict = dict(
bboxes=torch.zeros([0, self.bbox_coder.code_size],
dtype=dtype,
device=device),
scores=torch.zeros([0], dtype=dtype, device=device),
labels=torch.zeros([0],
dtype=top_labels.dtype,
device=device))
predictions_dicts.append(predictions_dict)
return predictions_dicts
# modified from https://github.com/Haiyang-W/DSVT
from math import ceil
import torch
from torch import nn
from .utils import (PositionEmbeddingLearned, get_continous_inds,
get_inner_win_inds_cuda, get_pooling_index,
get_window_coors)
class DSVTInputLayer(nn.Module):
'''
This class converts the output of vfe to dsvt input.
    We do the following in this class:
1. Window partition: partition voxels to non-overlapping windows.
2. Set partition: generate non-overlapped and size-equivalent local sets
within each window.
3. Pre-compute the downsample information between two consecutive stages.
4. Pre-compute the position embedding vectors.
Args:
sparse_shape (tuple[int, int, int]): Shape of input space
(xdim, ydim, zdim).
window_shape (list[list[int, int, int]]): Window shapes
(winx, winy, winz) in different stages. Length: stage_num.
downsample_stride (list[list[int, int, int]]): Downsample
strides between two consecutive stages.
Element i is [ds_x, ds_y, ds_z], which is used between stage_i and
stage_{i+1}. Length: stage_num - 1.
dim_model (list[int]): Number of input channels for each stage. Length:
stage_num.
        set_info (list[list[int, int]]): A list of set configs for each stage.
            Element i contains [set_size, block_num], where set_size is the
            number of voxels in a set and block_num is the number of blocks
            for stage i. Length: stage_num.
hybrid_factor (list[int, int, int]): Control the window shape in
different blocks.
e.g. for block_{0} and block_{1} in stage_0, window shapes are
[win_x, win_y, win_z] and
[win_x * h[0], win_y * h[1], win_z * h[2]] respectively.
shift_list (list): Shift window. Length: stage_num.
normalize_pos (bool): Whether to normalize coordinates in position
embedding.
'''
def __init__(self, sparse_shape, window_shape, downsample_stride,
dim_model, set_info, hybrid_factor, shift_list,
normalize_pos):
super().__init__()
self.sparse_shape = sparse_shape
self.window_shape = window_shape
self.downsample_stride = downsample_stride
self.dim_model = dim_model
self.set_info = set_info
self.stage_num = len(self.dim_model)
self.hybrid_factor = hybrid_factor
self.window_shape = [[
self.window_shape[s_id],
[
self.window_shape[s_id][coord_id] *
self.hybrid_factor[coord_id] for coord_id in range(3)
]
] for s_id in range(self.stage_num)]
self.shift_list = shift_list
self.normalize_pos = normalize_pos
self.num_shifts = [
2,
] * len(self.window_shape)
self.sparse_shape_list = [self.sparse_shape]
# compute sparse shapes for each stage
for ds_stride in self.downsample_stride:
last_sparse_shape = self.sparse_shape_list[-1]
self.sparse_shape_list.append(
(ceil(last_sparse_shape[0] / ds_stride[0]),
ceil(last_sparse_shape[1] / ds_stride[1]),
ceil(last_sparse_shape[2] / ds_stride[2])))
# position embedding layers
self.posembed_layers = nn.ModuleList()
for i in range(len(self.set_info)):
input_dim = 3 if self.sparse_shape_list[i][-1] > 1 else 2
stage_posembed_layers = nn.ModuleList()
for j in range(self.set_info[i][1]):
block_posembed_layers = nn.ModuleList()
for s in range(self.num_shifts[i]):
block_posembed_layers.append(
PositionEmbeddingLearned(input_dim, self.dim_model[i]))
stage_posembed_layers.append(block_posembed_layers)
self.posembed_layers.append(stage_posembed_layers)
def forward(self, batch_dict):
'''
Args:
            batch_dict (dict):
The dict contains the following keys
- voxel_features (Tensor[float]): Voxel features after VFE
with shape (N, dim_model[0]),
where N is the number of input voxels.
- voxel_coords (Tensor[int]): Shape of (N, 4), corresponding
voxel coordinates of each voxels.
Each row is (batch_id, z, y, x).
- ...
Returns:
voxel_info (dict):
The dict contains the following keys
- voxel_coors_stage{i} (Tensor[int]): Shape of (N_i, 4). N is
the number of voxels in stage_i.
Each row is (batch_id, z, y, x).
- set_voxel_inds_stage{i}_shift{j} (Tensor[int]): Set partition
index with shape (2, set_num, set_info[i][0]).
2 indicates x-axis partition and y-axis partition.
                - set_voxel_mask_stage{i}_shift{j} (Tensor[bool]): Key mask
                    used in set attention with shape
                    (2, set_num, set_info[i][0]).
                - pos_embed_stage{i}_block{b}_shift{j} (Tensor[float]):
                    Position embedding vectors with shape (N_i, dim_model[i]).
                    N_i is the number of remaining voxels in stage_i.
- pooling_mapping_index_stage{i} (Tensor[int]): Pooling region
index used in pooling operation between stage_{i-1}
and stage_{i} with shape (N_{i-1}).
                - pooling_index_in_pool_stage{i} (Tensor[int]): Index within
                    the pooling region with shape (N_{i-1}). Combined with
                    pooling_mapping_index_stage{i}, we can map each voxel in
                    stage_{i-1} to pooling_preholder_feats_stage{i}, which
                    are the input of the downsample operation.
                - pooling_preholder_feats_stage{i} (Tensor[int]): Placeholder
                    features initialized with value 0.
                    Shape of (N_{i}, downsample_stride[i-1].prod(),
                    dim_model[i-1]), where prod() returns the product of
                    all elements.
- ...
'''
voxel_feats = batch_dict['voxel_features']
voxel_coors = batch_dict['voxel_coords'].long()
voxel_info = {}
voxel_info['voxel_feats_stage0'] = voxel_feats.clone()
voxel_info['voxel_coors_stage0'] = voxel_coors.clone()
for stage_id in range(self.stage_num):
# window partition of corresponding stage-map
voxel_info = self.window_partition(voxel_info, stage_id)
# generate set id of corresponding stage-map
voxel_info = self.get_set(voxel_info, stage_id)
for block_id in range(self.set_info[stage_id][1]):
for shift_id in range(self.num_shifts[stage_id]):
layer_name = f'pos_embed_stage{stage_id}_block{block_id}_shift{shift_id}' # noqa: E501
pos_name = f'coors_in_win_stage{stage_id}_shift{shift_id}'
voxel_info[layer_name] = self.get_pos_embed(
voxel_info[pos_name], stage_id, block_id, shift_id)
# compute pooling information
if stage_id < self.stage_num - 1:
voxel_info = self.subm_pooling(voxel_info, stage_id)
return voxel_info
@torch.no_grad()
def subm_pooling(self, voxel_info, stage_id):
# x,y,z stride
cur_stage_downsample = self.downsample_stride[stage_id]
# batch_win_coords is from 1 of x, y
batch_win_inds, _, index_in_win, batch_win_coors = get_pooling_index(
voxel_info[f'voxel_coors_stage{stage_id}'],
self.sparse_shape_list[stage_id], cur_stage_downsample)
# compute pooling mapping index
unique_batch_win_inds, contiguous_batch_win_inds = torch.unique(
batch_win_inds, return_inverse=True)
voxel_info[
f'pooling_mapping_index_stage{stage_id+1}'] = \
contiguous_batch_win_inds
# generate empty placeholder features
placeholder_prepool_feats = voxel_info['voxel_feats_stage0'].new_zeros(
(len(unique_batch_win_inds),
torch.prod(torch.IntTensor(cur_stage_downsample)).item(),
self.dim_model[stage_id]))
voxel_info[f'pooling_index_in_pool_stage{stage_id+1}'] = index_in_win
voxel_info[
f'pooling_preholder_feats_stage{stage_id+1}'] = \
placeholder_prepool_feats
# compute pooling coordinates
unique, inverse = unique_batch_win_inds.clone(
), contiguous_batch_win_inds.clone()
perm = torch.arange(
inverse.size(0), dtype=inverse.dtype, device=inverse.device)
inverse, perm = inverse.flip([0]), perm.flip([0])
perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
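        # perm now holds, for each unique pooled window, the index of its
        # first voxel, so pool_coors picks one representative coordinate per
        # pooled voxel below.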
pool_coors = batch_win_coors[perm]
voxel_info[f'voxel_coors_stage{stage_id+1}'] = pool_coors
return voxel_info
def get_set(self, voxel_info, stage_id):
'''
        This is one of the core operations of DSVT.
        Given voxels' window ids and relative-coords inner window, we
        partition them into window-bounded and size-equivalent local sets. To
        make it clear and easy to follow, we do not use a loop to process the
        two shifts.
Args:
voxel_info (dict):
The dict contains the following keys
                - batch_win_inds_stage{i}_shift{j} (Tensor[float]): Window
                    indices of each voxel with shape (N), computed by
                    'window_partition'.
                - coors_in_win_stage{i}_shift{j} (Tensor[int]): Relative
                    coords inner window of each voxel with shape (N, 3),
                    computed by 'window_partition'. Each row is (z, y, x).
- ...
Returns:
See from 'forward' function.
'''
batch_win_inds_shift0 = voxel_info[
f'batch_win_inds_stage{stage_id}_shift0']
coors_in_win_shift0 = voxel_info[
f'coors_in_win_stage{stage_id}_shift0']
set_voxel_inds_shift0 = self.get_set_single_shift(
batch_win_inds_shift0,
stage_id,
shift_id=0,
coors_in_win=coors_in_win_shift0)
voxel_info[
f'set_voxel_inds_stage{stage_id}_shift0'] = set_voxel_inds_shift0
# compute key masks, voxel duplication must happen continuously
prefix_set_voxel_inds_s0 = torch.roll(
set_voxel_inds_shift0.clone(), shifts=1, dims=-1)
prefix_set_voxel_inds_s0[:, :, 0] = -1
set_voxel_mask_s0 = (set_voxel_inds_shift0 == prefix_set_voxel_inds_s0)
voxel_info[
f'set_voxel_mask_stage{stage_id}_shift0'] = set_voxel_mask_s0
batch_win_inds_shift1 = voxel_info[
f'batch_win_inds_stage{stage_id}_shift1']
coors_in_win_shift1 = voxel_info[
f'coors_in_win_stage{stage_id}_shift1']
set_voxel_inds_shift1 = self.get_set_single_shift(
batch_win_inds_shift1,
stage_id,
shift_id=1,
coors_in_win=coors_in_win_shift1)
voxel_info[
f'set_voxel_inds_stage{stage_id}_shift1'] = set_voxel_inds_shift1
# compute key masks, voxel duplication must happen continuously
prefix_set_voxel_inds_s1 = torch.roll(
set_voxel_inds_shift1.clone(), shifts=1, dims=-1)
prefix_set_voxel_inds_s1[:, :, 0] = -1
set_voxel_mask_s1 = (set_voxel_inds_shift1 == prefix_set_voxel_inds_s1)
voxel_info[
f'set_voxel_mask_stage{stage_id}_shift1'] = set_voxel_mask_s1
return voxel_info
def get_set_single_shift(self,
batch_win_inds,
stage_id,
shift_id=None,
coors_in_win=None):
device = batch_win_inds.device
# the number of voxels assigned to a set
voxel_num_set = self.set_info[stage_id][0]
# max number of voxels in a window
max_voxel = self.window_shape[stage_id][shift_id][
0] * self.window_shape[stage_id][shift_id][1] * self.window_shape[
stage_id][shift_id][2]
# get unique set indices
contiguous_win_inds = torch.unique(
batch_win_inds, return_inverse=True)[1]
voxelnum_per_win = torch.bincount(contiguous_win_inds)
win_num = voxelnum_per_win.shape[0]
setnum_per_win_float = voxelnum_per_win / voxel_num_set
setnum_per_win = torch.ceil(setnum_per_win_float).long()
set_win_inds, set_inds_in_win = get_continous_inds(setnum_per_win)
        # computation of Eq. 3 in 'DSVT: Dynamic Sparse Voxel Transformer with
        # Rotated Sets' - https://arxiv.org/abs/2301.06051;
        # for each window, we can get voxel indices belonging to different sets.
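        # Illustrative example (not from the original code): with
        # voxel_num_set = 36 and a window holding 50 voxels,
        # setnum_per_win = ceil(50 / 36) = 2; set 0 then covers inner indices
        # floor([0..35] * 50 / 72), i.e. 0..24 (with duplicates), and set 1
        # covers floor([36..71] * 50 / 72), i.e. 25..49, so the two sets
        # evenly tile the window.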
offset_idx = set_inds_in_win[:, None].repeat(
1, voxel_num_set) * voxel_num_set
base_idx = torch.arange(0, voxel_num_set, 1, device=device)
base_select_idx = offset_idx + base_idx
base_select_idx = base_select_idx * voxelnum_per_win[
set_win_inds][:, None]
base_select_idx = base_select_idx.double() / (
setnum_per_win[set_win_inds] * voxel_num_set)[:, None].double()
base_select_idx = torch.floor(base_select_idx)
# obtain unique indices in whole space
select_idx = base_select_idx
select_idx = select_idx + set_win_inds.view(-1, 1) * max_voxel
# this function will return unordered inner window indices of
# each voxel
inner_voxel_inds = get_inner_win_inds_cuda(contiguous_win_inds)
global_voxel_inds = contiguous_win_inds * max_voxel + inner_voxel_inds
_, order1 = torch.sort(global_voxel_inds)
# get y-axis partition results
global_voxel_inds_sorty = contiguous_win_inds * max_voxel + \
coors_in_win[:, 1] * self.window_shape[stage_id][shift_id][0] * \
self.window_shape[stage_id][shift_id][2] + coors_in_win[:, 2] * \
self.window_shape[stage_id][shift_id][2] + \
coors_in_win[:, 0]
_, order2 = torch.sort(global_voxel_inds_sorty)
inner_voxel_inds_sorty = -torch.ones_like(inner_voxel_inds)
inner_voxel_inds_sorty.scatter_(
dim=0, index=order2, src=inner_voxel_inds[order1]
) # get y-axis ordered inner window indices of each voxel
voxel_inds_in_batch_sorty = inner_voxel_inds_sorty + max_voxel * \
contiguous_win_inds
voxel_inds_padding_sorty = -1 * torch.ones(
(win_num * max_voxel), dtype=torch.long, device=device)
voxel_inds_padding_sorty[voxel_inds_in_batch_sorty] = torch.arange(
0,
voxel_inds_in_batch_sorty.shape[0],
dtype=torch.long,
device=device)
set_voxel_inds_sorty = voxel_inds_padding_sorty[select_idx.long()]
# get x-axis partition results
global_voxel_inds_sortx = contiguous_win_inds * max_voxel + \
coors_in_win[:, 2] * self.window_shape[stage_id][shift_id][1] * \
self.window_shape[stage_id][shift_id][2] + \
coors_in_win[:, 1] * self.window_shape[stage_id][shift_id][2] + \
coors_in_win[:, 0]
_, order2 = torch.sort(global_voxel_inds_sortx)
inner_voxel_inds_sortx = -torch.ones_like(inner_voxel_inds)
inner_voxel_inds_sortx.scatter_(
dim=0, index=order2, src=inner_voxel_inds[order1]
) # get x-axis ordered inner window indices of each voxel
voxel_inds_in_batch_sortx = inner_voxel_inds_sortx + max_voxel * \
contiguous_win_inds
voxel_inds_padding_sortx = -1 * torch.ones(
(win_num * max_voxel), dtype=torch.long, device=device)
voxel_inds_padding_sortx[voxel_inds_in_batch_sortx] = torch.arange(
0,
voxel_inds_in_batch_sortx.shape[0],
dtype=torch.long,
device=device)
set_voxel_inds_sortx = voxel_inds_padding_sortx[select_idx.long()]
all_set_voxel_inds = torch.stack(
(set_voxel_inds_sorty, set_voxel_inds_sortx), dim=0)
return all_set_voxel_inds
@torch.no_grad()
def window_partition(self, voxel_info, stage_id):
for i in range(2):
batch_win_inds, coors_in_win = get_window_coors(
voxel_info[f'voxel_coors_stage{stage_id}'],
self.sparse_shape_list[stage_id],
self.window_shape[stage_id][i], i == 1,
self.shift_list[stage_id][i])
voxel_info[
f'batch_win_inds_stage{stage_id}_shift{i}'] = batch_win_inds
voxel_info[f'coors_in_win_stage{stage_id}_shift{i}'] = coors_in_win
return voxel_info
def get_pos_embed(self, coors_in_win, stage_id, block_id, shift_id):
'''
Args:
coors_in_win: shape=[N, 3], order: z, y, x
'''
# [N,]
window_shape = self.window_shape[stage_id][shift_id]
embed_layer = self.posembed_layers[stage_id][block_id][shift_id]
if len(window_shape) == 2:
ndim = 2
win_x, win_y = window_shape
win_z = 0
elif window_shape[-1] == 1:
ndim = 2
win_x, win_y = window_shape[:2]
win_z = 0
else:
win_x, win_y, win_z = window_shape
ndim = 3
assert coors_in_win.size(1) == 3
z, y, x = coors_in_win[:, 0] - win_z / 2,\
coors_in_win[:, 1] - win_y / 2,\
coors_in_win[:, 2] - win_x / 2
if self.normalize_pos:
x = x / win_x * 2 * 3.1415 # [-pi, pi]
y = y / win_y * 2 * 3.1415 # [-pi, pi]
z = z / win_z * 2 * 3.1415 # [-pi, pi]
if ndim == 2:
location = torch.stack((x, y), dim=-1)
else:
location = torch.stack((x, y, z), dim=-1)
pos_embed = embed_layer(location)
return pos_embed
# modified from https://github.com/Haiyang-W/DSVT
import torch
import torch.nn as nn
from mmdet3d.registry import MODELS
from .dsvt_input_layer import DSVTInputLayer
@MODELS.register_module()
class DSVTMiddleEncoder(nn.Module):
'''Dynamic Sparse Voxel Transformer Backbone.
Args:
        input_layer (dict): Config of the input layer, which converts the
            output of vfe to dsvt input.
        block_name (list[string]): Name of blocks for each stage. Length:
            stage_num.
        set_info (list[list[int, int]]): A list of set configs for each stage.
            Element i contains [set_size, block_num], where set_size is the
            number of voxels in a set and block_num is the number of blocks
            for stage i. Length: stage_num.
dim_model (list[int]): Number of input channels for each stage.
Length: stage_num.
nhead (list[int]): Number of attention heads for each stage.
Length: stage_num.
dim_feedforward (list[int]): Dimensions of the feedforward network in
set attention for each stage. Length: stage num.
dropout (float): Drop rate of set attention.
activation (string): Name of activation layer in set attention.
reduction_type (string): Pooling method between stages.
One of: "attention", "maxpool", "linear".
output_shape (tuple[int, int]): Shape of output bev feature.
conv_out_channel (int): Number of output channels.
'''
def __init__(
self,
input_layer=dict(
sparse_shape=[468, 468, 1],
downsample_stride=[],
dim_model=[192],
set_info=[[36, 4]],
window_shape=[[12, 12, 1]],
hybrid_factor=[2, 2, 1], # x, y, z
            shift_list=[[[0, 0, 0], [6, 6, 0]]],
normalize_pos=False),
stage_num=1,
output_shape=[468, 468],
reduction_type='attention',
downsample_stride=[],
set_info=[[36, 4]],
dim_model=[192],
dim_feedforward=[384],
nhead=[8],
conv_out_channel=192,
dropout=0.,
activation='gelu'):
super().__init__()
self.input_layer = DSVTInputLayer(**input_layer)
self.reduction_type = reduction_type
# Sparse Regional Attention Blocks
for stage_id in range(stage_num):
num_blocks_this_stage = set_info[stage_id][-1]
dmodel_this_stage = dim_model[stage_id]
dfeed_this_stage = dim_feedforward[stage_id]
num_head_this_stage = nhead[stage_id]
block_list = []
norm_list = []
for i in range(num_blocks_this_stage):
block_list.append(
DSVTBlock(
dmodel_this_stage,
num_head_this_stage,
dfeed_this_stage,
dropout,
activation,
batch_first=True))
norm_list.append(nn.LayerNorm(dmodel_this_stage))
self.__setattr__(f'stage_{stage_id}', nn.ModuleList(block_list))
self.__setattr__(f'residual_norm_stage_{stage_id}',
nn.ModuleList(norm_list))
# apply pooling except the last stage
if stage_id < stage_num - 1:
downsample_window = downsample_stride[stage_id]
dmodel_next_stage = dim_model[stage_id + 1]
pool_volume = torch.IntTensor(downsample_window).prod().item()
if self.reduction_type == 'linear':
cat_feat_dim = dmodel_this_stage * torch.IntTensor(
downsample_window).prod().item()
self.__setattr__(
f'stage_{stage_id}_reduction',
StageReductionBlock(cat_feat_dim, dmodel_next_stage))
elif self.reduction_type == 'maxpool':
self.__setattr__(f'stage_{stage_id}_reduction',
torch.nn.MaxPool1d(pool_volume))
elif self.reduction_type == 'attention':
self.__setattr__(
f'stage_{stage_id}_reduction',
StageReductionAttBlock(dmodel_this_stage, pool_volume))
else:
raise NotImplementedError
self.num_shifts = [2] * stage_num
self.output_shape = output_shape
self.stage_num = stage_num
self.set_info = set_info
self.num_point_features = conv_out_channel
self._reset_parameters()
def forward(self, batch_dict):
'''
Args:
            batch_dict (dict):
The dict contains the following keys
- voxel_features (Tensor[float]): Voxel features after VFE.
Shape of (N, dim_model[0]),
where N is the number of input voxels.
- voxel_coords (Tensor[int]): Shape of (N, 4), corresponding
voxel coordinates of each voxels.
Each row is (batch_id, z, y, x).
- ...
Returns:
            batch_dict (dict):
The dict contains the following keys
- pillar_features (Tensor[float]):
- voxel_coords (Tensor[int]):
- ...
'''
voxel_info = self.input_layer(batch_dict)
voxel_feat = voxel_info['voxel_feats_stage0']
set_voxel_inds_list = [[
voxel_info[f'set_voxel_inds_stage{s}_shift{i}']
for i in range(self.num_shifts[s])
] for s in range(self.stage_num)]
set_voxel_masks_list = [[
voxel_info[f'set_voxel_mask_stage{s}_shift{i}']
for i in range(self.num_shifts[s])
] for s in range(self.stage_num)]
pos_embed_list = [[[
voxel_info[f'pos_embed_stage{s}_block{b}_shift{i}']
for i in range(self.num_shifts[s])
] for b in range(self.set_info[s][1])] for s in range(self.stage_num)]
pooling_mapping_index = [
voxel_info[f'pooling_mapping_index_stage{s+1}']
for s in range(self.stage_num - 1)
]
pooling_index_in_pool = [
voxel_info[f'pooling_index_in_pool_stage{s+1}']
for s in range(self.stage_num - 1)
]
pooling_preholder_feats = [
voxel_info[f'pooling_preholder_feats_stage{s+1}']
for s in range(self.stage_num - 1)
]
output = voxel_feat
block_id = 0
for stage_id in range(self.stage_num):
block_layers = self.__getattr__(f'stage_{stage_id}')
residual_norm_layers = self.__getattr__(
f'residual_norm_stage_{stage_id}')
for i in range(len(block_layers)):
block = block_layers[i]
residual = output.clone()
output = block(
output,
set_voxel_inds_list[stage_id],
set_voxel_masks_list[stage_id],
pos_embed_list[stage_id][i],
block_id=block_id)
output = residual_norm_layers[i](output + residual)
block_id += 1
if stage_id < self.stage_num - 1:
# pooling
prepool_features = pooling_preholder_feats[stage_id].type_as(
output)
pooled_voxel_num = prepool_features.shape[0]
pool_volume = prepool_features.shape[1]
prepool_features[pooling_mapping_index[stage_id],
pooling_index_in_pool[stage_id]] = output
prepool_features = prepool_features.view(
prepool_features.shape[0], -1)
if self.reduction_type == 'linear':
output = self.__getattr__(f'stage_{stage_id}_reduction')(
prepool_features)
elif self.reduction_type == 'maxpool':
prepool_features = prepool_features.view(
pooled_voxel_num, pool_volume, -1).permute(0, 2, 1)
output = self.__getattr__(f'stage_{stage_id}_reduction')(
prepool_features).squeeze(-1)
elif self.reduction_type == 'attention':
prepool_features = prepool_features.view(
pooled_voxel_num, pool_volume, -1).permute(0, 2, 1)
key_padding_mask = torch.zeros(
(pooled_voxel_num,
pool_volume)).to(prepool_features.device).int()
output = self.__getattr__(f'stage_{stage_id}_reduction')(
prepool_features, key_padding_mask)
else:
raise NotImplementedError
batch_dict['pillar_features'] = batch_dict['voxel_features'] = output
batch_dict['voxel_coords'] = voxel_info[
f'voxel_coors_stage{self.stage_num - 1}']
return batch_dict
def _reset_parameters(self):
for name, p in self.named_parameters():
if p.dim() > 1 and 'scaler' not in name:
nn.init.xavier_uniform_(p)
class DSVTBlock(nn.Module):
"""Consist of two encoder layer, shift and shift back."""
def __init__(self,
dim_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation='relu',
batch_first=True):
super().__init__()
encoder_1 = DSVTEncoderLayer(dim_model, nhead, dim_feedforward,
dropout, activation, batch_first)
encoder_2 = DSVTEncoderLayer(dim_model, nhead, dim_feedforward,
dropout, activation, batch_first)
self.encoder_list = nn.ModuleList([encoder_1, encoder_2])
def forward(
self,
src,
set_voxel_inds_list,
set_voxel_masks_list,
pos_embed_list,
block_id,
):
num_shifts = 2
output = src
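        # Rotated-set attention: both encoder layers in this block share the
        # same window shift (block_id % 2), while the set partition axis
        # alternates between y (i = 0) and x (i = 1).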
for i in range(num_shifts):
set_id = i
shift_id = block_id % 2
pos_embed_id = i
set_voxel_inds = set_voxel_inds_list[shift_id][set_id]
set_voxel_masks = set_voxel_masks_list[shift_id][set_id]
pos_embed = pos_embed_list[pos_embed_id]
layer = self.encoder_list[i]
output = layer(output, set_voxel_inds, set_voxel_masks, pos_embed)
return output
class DSVTEncoderLayer(nn.Module):
def __init__(self,
dim_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation='relu',
batch_first=True,
mlp_dropout=0):
super().__init__()
self.win_attn = SetAttention(dim_model, nhead, dropout,
dim_feedforward, activation, batch_first,
mlp_dropout)
self.norm = nn.LayerNorm(dim_model)
self.dim_model = dim_model
def forward(self, src, set_voxel_inds, set_voxel_masks, pos=None):
identity = src
src = self.win_attn(src, pos, set_voxel_masks, set_voxel_inds)
src = src + identity
src = self.norm(src)
return src
class SetAttention(nn.Module):
def __init__(self,
dim_model,
nhead,
dropout,
dim_feedforward=2048,
activation='relu',
batch_first=True,
mlp_dropout=0):
super().__init__()
self.nhead = nhead
if batch_first:
self.self_attn = nn.MultiheadAttention(
dim_model, nhead, dropout=dropout, batch_first=batch_first)
else:
self.self_attn = nn.MultiheadAttention(
dim_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(dim_model, dim_feedforward)
self.dropout = nn.Dropout(mlp_dropout)
self.linear2 = nn.Linear(dim_feedforward, dim_model)
self.dim_model = dim_model
self.norm1 = nn.LayerNorm(dim_model)
self.norm2 = nn.LayerNorm(dim_model)
self.dropout1 = nn.Identity()
self.dropout2 = nn.Identity()
self.activation = _get_activation_fn(activation)
def forward(self, src, pos=None, key_padding_mask=None, voxel_inds=None):
'''
Args:
src (Tensor[float]): Voxel features with shape (N, C), where N is
the number of voxels.
pos (Tensor[float]): Position embedding vectors with shape (N, C).
key_padding_mask (Tensor[bool]): Mask for redundant voxels
within set. Shape of (set_num, set_size).
voxel_inds (Tensor[int]): Voxel indices for each set.
Shape of (set_num, set_size).
Returns:
src (Tensor[float]): Voxel features.
'''
set_features = src[voxel_inds]
if pos is not None:
set_pos = pos[voxel_inds]
else:
set_pos = None
if pos is not None:
query = set_features + set_pos
key = set_features + set_pos
value = set_features
if key_padding_mask is not None:
src2 = self.self_attn(query, key, value, key_padding_mask)[0]
else:
src2 = self.self_attn(query, key, value)[0]
# map voxel features from set space to voxel space:
# (set_num, set_size, C) --> (N, C)
flatten_inds = voxel_inds.reshape(-1)
unique_flatten_inds, inverse = torch.unique(
flatten_inds, return_inverse=True)
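        # The flip + scatter trick below keeps, for each unique voxel index,
        # the position of its first occurrence in voxel_inds, so duplicated
        # (padded) entries within a set are resolved deterministically.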
perm = torch.arange(
inverse.size(0), dtype=inverse.dtype, device=inverse.device)
inverse, perm = inverse.flip([0]), perm.flip([0])
perm = inverse.new_empty(unique_flatten_inds.size(0)).scatter_(
0, inverse, perm)
src2 = src2.reshape(-1, self.dim_model)[perm]
# FFN layer
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
class StageReductionBlock(nn.Module):
def __init__(self, input_channel, output_channel):
super().__init__()
self.linear1 = nn.Linear(input_channel, output_channel, bias=False)
self.norm = nn.LayerNorm(output_channel)
def forward(self, x):
src = x
src = self.norm(self.linear1(x))
return src
class StageReductionAttBlock(nn.Module):
def __init__(self, input_channel, pool_volume):
super().__init__()
self.pool_volume = pool_volume
self.query_func = torch.nn.MaxPool1d(pool_volume)
self.norm = nn.LayerNorm(input_channel)
self.self_attn = nn.MultiheadAttention(
input_channel, 8, batch_first=True)
self.pos_embedding = nn.Parameter(
torch.randn(pool_volume, input_channel))
nn.init.normal_(self.pos_embedding, std=.01)
def forward(self, x, key_padding_mask):
# x: [voxel_num, c_dim, pool_volume]
src = self.query_func(x).permute(0, 2, 1) # voxel_num, 1, c_dim
key = value = x.permute(0, 2, 1)
key = key + self.pos_embedding.unsqueeze(0).repeat(src.shape[0], 1, 1)
query = src.clone()
output = self.self_attn(query, key, value, key_padding_mask)[0]
src = self.norm(output + src).squeeze(1)
return src
def _get_activation_fn(activation):
"""Return an activation function given a string."""
if activation == 'relu':
return torch.nn.functional.relu
if activation == 'gelu':
return torch.nn.functional.gelu
if activation == 'glu':
return torch.nn.functional.glu
    raise RuntimeError(f'activation should be relu/gelu/glu, not {activation}.')
# modified from https://github.com/Haiyang-W/DSVT
import torch
import torch.nn as nn
import torch_scatter
from mmdet3d.registry import MODELS
class PFNLayerV2(nn.Module):
def __init__(self,
in_channels,
out_channels,
use_norm=True,
last_layer=False):
super().__init__()
self.last_vfe = last_layer
self.use_norm = use_norm
if not self.last_vfe:
out_channels = out_channels // 2
if self.use_norm:
self.linear = nn.Linear(in_channels, out_channels, bias=False)
self.norm = nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01)
else:
self.linear = nn.Linear(in_channels, out_channels, bias=True)
self.relu = nn.ReLU()
def forward(self, inputs, unq_inv):
x = self.linear(inputs)
x = self.norm(x) if self.use_norm else x
x = self.relu(x)
x_max = torch_scatter.scatter_max(x, unq_inv, dim=0)[0]
if self.last_vfe:
return x_max
else:
x_concatenated = torch.cat([x, x_max[unq_inv, :]], dim=1)
return x_concatenated
@MODELS.register_module()
class DynamicPillarVFE3D(nn.Module):
"""The difference between `DynamicPillarVFE3D` and `DynamicPillarVFE` is
that the voxel in this module is along 3 dims: (x, y, z)."""
def __init__(self, with_distance, use_absolute_xyz, use_norm, num_filters,
num_point_features, voxel_size, grid_size, point_cloud_range):
super().__init__()
self.use_norm = use_norm
self.with_distance = with_distance
self.use_absolute_xyz = use_absolute_xyz
num_point_features += 6 if self.use_absolute_xyz else 3
if self.with_distance:
num_point_features += 1
self.num_filters = num_filters
assert len(self.num_filters) > 0
num_filters = [num_point_features] + list(self.num_filters)
pfn_layers = []
for i in range(len(num_filters) - 1):
in_filters = num_filters[i]
out_filters = num_filters[i + 1]
pfn_layers.append(
PFNLayerV2(
in_filters,
out_filters,
self.use_norm,
last_layer=(i >= len(num_filters) - 2)))
self.pfn_layers = nn.ModuleList(pfn_layers)
self.voxel_x = voxel_size[0]
self.voxel_y = voxel_size[1]
self.voxel_z = voxel_size[2]
self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
self.z_offset = self.voxel_z / 2 + point_cloud_range[2]
self.scale_xyz = grid_size[0] * grid_size[1] * grid_size[2]
self.scale_yz = grid_size[1] * grid_size[2]
self.scale_z = grid_size[2]
self.grid_size = torch.tensor(grid_size).cuda()
self.voxel_size = torch.tensor(voxel_size).cuda()
self.point_cloud_range = torch.tensor(point_cloud_range).cuda()
def get_output_feature_dim(self):
return self.num_filters[-1]
def forward(self, batch_dict, **kwargs):
"""Forward function.
Args:
batch_dict (dict[list]): Batch input data:
- points [list[Tensor]]: list of batch input points.
Returns:
dict: Voxelization outputs:
- points:
- pillar_features/voxel_features:
- voxel_coords
"""
batch_prefix_points = []
for batch_idx, points in enumerate(batch_dict['points']):
prefix_batch_idx = torch.Tensor([batch_idx
]).tile(points.size(0),
1).to(points)
prefix_points = torch.cat((prefix_batch_idx, points),
dim=1) # (batch_idx, x, y, z, i, e)
batch_prefix_points.append(prefix_points)
points = torch.cat(batch_prefix_points, dim=0)
del prefix_points, batch_prefix_points
points_coords = torch.floor(
(points[:, [1, 2, 3]] - self.point_cloud_range[[0, 1, 2]]) /
self.voxel_size[[0, 1, 2]]).int()
mask = ((points_coords >= 0) &
(points_coords < self.grid_size[[0, 1, 2]])).all(dim=1)
points = points[mask]
points_coords = points_coords[mask]
points_xyz = points[:, [1, 2, 3]].contiguous()
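        # Pack (batch_idx, x, y, z) into a single integer key so that
        # torch.unique can group points falling into the same voxel; the
        # packing is reversed below when voxel_coords is rebuilt.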
merge_coords = points[:, 0].int() * self.scale_xyz + \
points_coords[:, 0] * self.scale_yz + \
points_coords[:, 1] * self.scale_z + points_coords[:, 2]
unq_coords, unq_inv, unq_cnt = torch.unique(
merge_coords, return_inverse=True, return_counts=True, dim=0)
points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0)
f_cluster = points_xyz - points_mean[unq_inv, :]
f_center = torch.zeros_like(points_xyz)
f_center[:, 0] = points_xyz[:, 0] - (
points_coords[:, 0].to(points_xyz.dtype) * self.voxel_x +
self.x_offset)
f_center[:, 1] = points_xyz[:, 1] - (
points_coords[:, 1].to(points_xyz.dtype) * self.voxel_y +
self.y_offset)
# f_center[:, 2] = points_xyz[:, 2] - self.z_offset
f_center[:, 2] = points_xyz[:, 2] - (
points_coords[:, 2].to(points_xyz.dtype) * self.voxel_z +
self.z_offset)
if self.use_absolute_xyz:
features = [points[:, 1:], f_cluster, f_center]
else:
features = [points[:, 4:], f_cluster, f_center]
if self.with_distance:
points_dist = torch.norm(points[:, 1:4], 2, dim=1, keepdim=True)
features.append(points_dist)
features = torch.cat(features, dim=-1)
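        # Encode the decorated point features and pool them into per-voxel
        # features with the stacked PFN layers.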
for pfn in self.pfn_layers:
features = pfn(features, unq_inv)
# generate voxel coordinates
unq_coords = unq_coords.int()
voxel_coords = torch.stack(
(unq_coords // self.scale_xyz,
(unq_coords % self.scale_xyz) // self.scale_yz,
(unq_coords % self.scale_yz) // self.scale_z,
unq_coords % self.scale_z),
dim=1)
voxel_coords = voxel_coords[:, [0, 3, 2, 1]]
batch_dict['pillar_features'] = batch_dict['voxel_features'] = features
batch_dict['voxel_coords'] = voxel_coords
return batch_dict
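# A minimal, illustrative config sketch for this VFE (the values below are
# assumptions for demonstration, not the exact settings shipped with DSVT):
#   voxel_encoder=dict(
#       type='DynamicPillarVFE3D',
#       with_distance=False,
#       use_absolute_xyz=True,
#       use_norm=True,
#       num_filters=[128, 128],
#       num_point_features=5,
#       voxel_size=[0.32, 0.32, 0.1875],
#       grid_size=[468, 468, 32],
#       point_cloud_range=[-74.88, -74.88, -2, 74.88, 74.88, 4])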
# modified from https://github.com/Haiyang-W/DSVT
import torch
import torch.nn as nn
from mmdet3d.registry import MODELS
@MODELS.register_module()
class PointPillarsScatter3D(nn.Module):
"""The difference between `PointPillarsScatter3D` and `PointPillarsScatter`
    is that this module scatters voxels along all 3 dims: (x, y, z)."""
def __init__(self, output_shape, num_bev_feats, **kwargs):
super().__init__()
self.nx, self.ny, self.nz = output_shape
self.num_bev_feats = num_bev_feats
self.num_bev_feats_ori = num_bev_feats // self.nz
def forward(self, batch_dict, **kwargs):
pillar_features, coords = batch_dict['pillar_features'], batch_dict[
'voxel_coords']
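        # Scatter the sparse pillar/voxel features onto a dense
        # (nz * ny * nx) canvas, one canvas per sample in the batch.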
batch_spatial_features = []
batch_size = coords[:, 0].max().int().item() + 1
for batch_idx in range(batch_size):
spatial_feature = torch.zeros(
self.num_bev_feats_ori,
self.nz * self.nx * self.ny,
dtype=pillar_features.dtype,
device=pillar_features.device)
batch_mask = coords[:, 0] == batch_idx
this_coords = coords[batch_mask, :]
indices = this_coords[:, 1] * self.ny * self.nx + \
this_coords[:, 2] * self.nx + this_coords[:, 3]
indices = indices.type(torch.long)
pillars = pillar_features[batch_mask, :]
pillars = pillars.t()
spatial_feature[:, indices] = pillars
batch_spatial_features.append(spatial_feature)
batch_spatial_features = torch.stack(batch_spatial_features, 0)
batch_spatial_features = batch_spatial_features.view(
batch_size, self.num_bev_feats_ori * self.nz, self.ny, self.nx)
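        # Shape: (batch_size, num_bev_feats_ori * nz, ny, nx); the z dimension
        # is folded into the channel dimension to form a BEV feature map.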
batch_dict['spatial_features'] = batch_spatial_features
return batch_dict
import torch
from torch.autograd import Function
try:
    from . import ingroup_inds_cuda
    ingroup_indices = ingroup_inds_cuda
except ImportError:
    ingroup_indices = None
    print('Cannot import ingroup indices')
class IngroupIndicesFunction(Function):
@staticmethod
def forward(ctx, group_inds):
out_inds = torch.zeros_like(group_inds) - 1
ingroup_indices.forward(group_inds, out_inds)
ctx.mark_non_differentiable(out_inds)
return out_inds
@staticmethod
def backward(ctx, g):
return None
ingroup_inds = IngroupIndicesFunction.apply
#pragma once
#include <stdio.h>
#define CHECK_CALL(call) \
do \
{ \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", \
cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
#include <assert.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
#define CHECK_CUDA(x) \
    TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDA tensor ")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
void ingroup_inds_launcher(
const long *group_inds_data,
long *out_inds_data,
int N,
int max_group_id
);
void ingroup_inds_gpu(
at::Tensor group_inds,
at::Tensor out_inds
);
void ingroup_inds_gpu(
at::Tensor group_inds,
at::Tensor out_inds
) {
CHECK_INPUT(group_inds);
CHECK_INPUT(out_inds);
int N = group_inds.size(0);
int max_group_id = group_inds.max().item().toLong();
long *group_inds_data = group_inds.data_ptr<long>();
long *out_inds_data = out_inds.data_ptr<long>();
ingroup_inds_launcher(
group_inds_data,
out_inds_data,
N,
max_group_id
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &ingroup_inds_gpu, "cuda version of get_inner_win_inds of SST");
}
#include <assert.h>
#include <vector>
#include <math.h>
#include <stdio.h>
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include <torch/types.h>
#include "cuda_fp16.h"
// #include "error.cuh"
#define CHECK_CALL(call) \
do \
{ \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line: %d\n", __LINE__); \
printf(" Error code: %d\n", error_code); \
printf(" Error text: %s\n", \
cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
// #define DEBUG
// #define ASSERTION
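// For each element, atomically fetch-and-increment the counter of its group,
// so out_inds[idx] becomes the (0-based) rank of element idx within its group.
// Note: the order within a group depends on thread scheduling.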
__global__ void ingroup_inds_kernel(
const long *group_inds,
long *out_inds,
int *ingroup_counter,
int N
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) return;
long this_group_id = group_inds[idx];
int cnt = atomicAdd(&ingroup_counter[this_group_id], 1);
out_inds[idx] = cnt;
}
void ingroup_inds_launcher(
const long *group_inds,
long *out_inds,
int N,
int max_group_id
) {
int *ingroup_counter = NULL;
CHECK_CALL(cudaMalloc(&ingroup_counter, (max_group_id + 1) * sizeof(int)));
CHECK_CALL(cudaMemset(ingroup_counter, 0, (max_group_id + 1) * sizeof(int)));
dim3 blocks(DIVUP(N, THREADS_PER_BLOCK));
dim3 threads(THREADS_PER_BLOCK);
ingroup_inds_kernel<<<blocks, threads>>>(
group_inds,
out_inds,
ingroup_counter,
N
);
cudaFree(ingroup_counter);
#ifdef DEBUG
CHECK_CALL(cudaGetLastError());
CHECK_CALL(cudaDeviceSynchronize());
#endif
return;
}
# modified from https://github.com/Haiyang-W/DSVT
import warnings
from typing import Optional, Sequence, Tuple
from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn
from mmdet3d.registry import MODELS
from mmdet3d.utils import OptMultiConfig
class BasicResBlock(nn.Module):
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
padding: int = 1,
downsample: bool = False,
) -> None:
super().__init__()
self.conv1 = nn.Conv2d(
inplanes,
planes,
kernel_size=3,
stride=stride,
padding=padding,
bias=False)
self.bn1 = nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01)
self.relu2 = nn.ReLU()
self.downsample = downsample
if self.downsample:
self.downsample_layer = nn.Sequential(
nn.Conv2d(
inplanes,
planes,
kernel_size=1,
stride=stride,
padding=0,
bias=False),
nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01))
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
identity = self.downsample_layer(x)
out += identity
out = self.relu2(out)
return out
@MODELS.register_module()
class ResSECOND(BaseModule):
"""Backbone network for DSVT. The difference between `ResSECOND` and
`SECOND` is that the basic block in this module contains residual layers.
Args:
in_channels (int): Input channels.
out_channels (list[int]): Output channels for multi-scale feature maps.
blocks_nums (list[int]): Number of blocks in each stage.
layer_strides (list[int]): Strides of each stage.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
        pretrained (str, optional): Deprecated. Path of the pretrained model.
            Please use `init_cfg` instead. Defaults to None.
"""
def __init__(self,
in_channels: int = 128,
out_channels: Sequence[int] = [128, 128, 256],
blocks_nums: Sequence[int] = [1, 2, 2],
layer_strides: Sequence[int] = [2, 2, 2],
init_cfg: OptMultiConfig = None,
pretrained: Optional[str] = None) -> None:
super(ResSECOND, self).__init__(init_cfg=init_cfg)
assert len(layer_strides) == len(blocks_nums)
assert len(out_channels) == len(blocks_nums)
in_filters = [in_channels, *out_channels[:-1]]
blocks = []
for i, block_num in enumerate(blocks_nums):
cur_layers = [
BasicResBlock(
in_filters[i],
out_channels[i],
stride=layer_strides[i],
downsample=True)
]
for _ in range(block_num):
cur_layers.append(
BasicResBlock(out_channels[i], out_channels[i]))
blocks.append(nn.Sequential(*cur_layers))
self.blocks = nn.Sequential(*blocks)
assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
else:
self.init_cfg = dict(type='Kaiming', layer='Conv2d')
def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
"""Forward function.
Args:
x (torch.Tensor): Input with shape (N, C, H, W).
Returns:
tuple[torch.Tensor]: Multi-scale features.
"""
outs = []
for i in range(len(self.blocks)):
x = self.blocks[i](x)
outs.append(x)
return tuple(outs)
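# A minimal usage sketch (shapes are illustrative assumptions): with the
# default arguments above, a BEV input of shape (N, 128, H, W) produces three
# feature maps with 128, 128 and 256 channels, downsampled by 2x, 4x and 8x:
#   backbone = ResSECOND(in_channels=128)
#   outs = backbone(torch.rand(2, 128, 468, 468))
#   # [f.shape for f in outs] ->
#   # [(2, 128, 234, 234), (2, 128, 117, 117), (2, 256, 59, 59)]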
from typing import Dict, List, Optional
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from mmdet3d.models.task_modules import CenterPointBBoxCoder
from mmdet3d.registry import TASK_UTILS
from .ops.ingroup_inds.ingroup_inds_op import ingroup_inds
get_inner_win_inds_cuda = ingroup_inds
class PositionEmbeddingLearned(nn.Module):
"""Absolute pos embedding, learned."""
def __init__(self, input_channel, num_pos_feats):
super().__init__()
self.position_embedding_head = nn.Sequential(
nn.Linear(input_channel, num_pos_feats),
nn.BatchNorm1d(num_pos_feats), nn.ReLU(inplace=True),
nn.Linear(num_pos_feats, num_pos_feats))
def forward(self, xyz):
position_embedding = self.position_embedding_head(xyz)
return position_embedding
@torch.no_grad()
def get_window_coors(coors,
sparse_shape,
window_shape,
do_shift,
shift_list=None,
return_win_coors=False):
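    # Assign each voxel to a (possibly shifted) window and compute both a
    # global window index and the voxel's local coordinates inside that window.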
if len(window_shape) == 2:
win_shape_x, win_shape_y = window_shape
win_shape_z = sparse_shape[-1]
else:
win_shape_x, win_shape_y, win_shape_z = window_shape
sparse_shape_x, sparse_shape_y, sparse_shape_z = sparse_shape
    assert sparse_shape_z < sparse_shape_x, \
        'sparse_shape is expected in (x, y, z) order, with z the smallest'
max_num_win_x = int(np.ceil((sparse_shape_x / win_shape_x)) +
1) # plus one here to meet the needs of shift.
max_num_win_y = int(np.ceil((sparse_shape_y / win_shape_y)) +
1) # plus one here to meet the needs of shift.
max_num_win_z = int(np.ceil((sparse_shape_z / win_shape_z)) +
1) # plus one here to meet the needs of shift.
max_num_win_per_sample = max_num_win_x * max_num_win_y * max_num_win_z
if do_shift:
if shift_list is not None:
shift_x, shift_y, shift_z = shift_list[0], shift_list[
1], shift_list[2]
else:
shift_x, shift_y, shift_z = win_shape_x // 2, win_shape_y // 2, win_shape_z // 2 # noqa: E501
else:
if shift_list is not None:
shift_x, shift_y, shift_z = shift_list[0], shift_list[
1], shift_list[2]
else:
shift_x, shift_y, shift_z = win_shape_x, win_shape_y, win_shape_z
# compatibility between 2D window and 3D window
if sparse_shape_z == win_shape_z:
shift_z = 0
shifted_coors_x = coors[:, 3] + shift_x
shifted_coors_y = coors[:, 2] + shift_y
shifted_coors_z = coors[:, 1] + shift_z
win_coors_x = shifted_coors_x // win_shape_x
win_coors_y = shifted_coors_y // win_shape_y
win_coors_z = shifted_coors_z // win_shape_z
if len(window_shape) == 2:
assert (win_coors_z == 0).all()
batch_win_inds = coors[:, 0] * max_num_win_per_sample + \
win_coors_x * max_num_win_y * max_num_win_z + \
win_coors_y * max_num_win_z + win_coors_z
coors_in_win_x = shifted_coors_x % win_shape_x
coors_in_win_y = shifted_coors_y % win_shape_y
coors_in_win_z = shifted_coors_z % win_shape_z
coors_in_win = torch.stack(
[coors_in_win_z, coors_in_win_y, coors_in_win_x], dim=-1)
# coors_in_win = torch.stack([coors_in_win_x, coors_in_win_y], dim=-1)
if return_win_coors:
batch_win_coords = torch.stack([win_coors_z, win_coors_y, win_coors_x],
dim=-1)
return batch_win_inds, coors_in_win, batch_win_coords
return batch_win_inds, coors_in_win
def get_pooling_index(coors, sparse_shape, window_shape):
win_shape_x, win_shape_y, win_shape_z = window_shape
sparse_shape_x, sparse_shape_y, sparse_shape_z = sparse_shape
max_num_win_x = int(np.ceil((sparse_shape_x / win_shape_x)))
max_num_win_y = int(np.ceil((sparse_shape_y / win_shape_y)))
max_num_win_z = int(np.ceil((sparse_shape_z / win_shape_z)))
max_num_win_per_sample = max_num_win_x * max_num_win_y * max_num_win_z
coors_x = coors[:, 3]
coors_y = coors[:, 2]
coors_z = coors[:, 1]
win_coors_x = coors_x // win_shape_x
win_coors_y = coors_y // win_shape_y
win_coors_z = coors_z // win_shape_z
batch_win_inds = coors[:, 0] * max_num_win_per_sample + \
win_coors_x * max_num_win_y * max_num_win_z + \
win_coors_y * max_num_win_z + win_coors_z
coors_in_win_x = coors_x % win_shape_x
coors_in_win_y = coors_y % win_shape_y
coors_in_win_z = coors_z % win_shape_z
coors_in_win = torch.stack(
[coors_in_win_z, coors_in_win_y, coors_in_win_x], dim=-1)
index_in_win = coors_in_win_x * win_shape_y * win_shape_z + \
coors_in_win_y * win_shape_z + coors_in_win_z
batch_win_coords = torch.stack(
[coors[:, 0], win_coors_z, win_coors_y, win_coors_x], dim=-1)
return batch_win_inds, coors_in_win, index_in_win, batch_win_coords
def get_continous_inds(setnum_per_win):
'''
Args:
setnum_per_win (Tensor[int]): Number of sets assigned to each window
with shape (win_num).
Returns:
        set_win_inds (Tensor[int]): Window index of each set with shape
            (set_num).
        set_inds_in_win (Tensor[int]): Set index within its window with shape
            (set_num).
    Examples:
        setnum_per_win = torch.tensor([1, 2, 1, 3])
        set_win_inds, set_inds_in_win = get_continous_inds(setnum_per_win)
        # we can get: set_win_inds = tensor([0, 1, 1, 2, 3, 3, 3])
        #             set_inds_in_win = tensor([0, 0, 1, 0, 0, 1, 2])
'''
set_num = setnum_per_win.sum().item() # set_num = 7
setnum_per_win_cumsum = torch.cumsum(
setnum_per_win, dim=0)[:-1] # [1, 3, 4]
set_win_inds = torch.full((set_num, ), 0, device=setnum_per_win.device)
set_win_inds[setnum_per_win_cumsum] = 1 # [0, 1, 0, 1, 1, 0, 0]
set_win_inds = torch.cumsum(set_win_inds, dim=0) # [0, 1, 1, 2, 3, 3, 3]
roll_set_win_inds_left = torch.roll(set_win_inds,
-1) # [1, 1, 2, 3, 3, 3, 0]
diff = set_win_inds - roll_set_win_inds_left # [-1, 0, -1, -1, 0, 0, 3]
end_pos_mask = diff != 0
template = torch.ones_like(set_win_inds)
template[end_pos_mask] = (setnum_per_win -
1) * -1 # [ 0, 1, -1, 0, 1, 1, -2]
set_inds_in_win = torch.cumsum(template, dim=0) # [0, 1, 0, 0, 1, 2, 0]
set_inds_in_win[end_pos_mask] = setnum_per_win # [1, 1, 2, 1, 1, 2, 3]
set_inds_in_win = set_inds_in_win - 1 # [0, 0, 1, 0, 0, 1, 2]
return set_win_inds, set_inds_in_win
@TASK_UTILS.register_module()
class DSVTBBoxCoder(CenterPointBBoxCoder):
"""Bbox coder for DSVT.
    Compared with `CenterPointBBoxCoder`, this coder additionally decodes
    IoU predictions.
"""
def __init__(self, *args, **kwargs) -> None:
super(DSVTBBoxCoder, self).__init__(*args, **kwargs)
def decode(self,
heat: Tensor,
rot_sine: Tensor,
rot_cosine: Tensor,
hei: Tensor,
dim: Tensor,
vel: Tensor,
reg: Optional[Tensor] = None,
iou: Optional[Tensor] = None) -> List[Dict[str, Tensor]]:
"""
Args:
heat (torch.Tensor): Heatmap with the shape of [B, N, W, H].
rot_sine (torch.Tensor): Sine of rotation with the shape of
[B, 1, W, H].
rot_cosine (torch.Tensor): Cosine of rotation with the shape of
[B, 1, W, H].
hei (torch.Tensor): Height of the boxes with the shape
of [B, 1, W, H].
dim (torch.Tensor): Dim of the boxes with the shape of
[B, 1, W, H].
vel (torch.Tensor): Velocity with the shape of [B, 1, W, H].
            reg (torch.Tensor, optional): Regression value of the boxes in
                2D with the shape of [B, 2, W, H]. Default: None.
            iou (torch.Tensor, optional): Predicted IoU of the boxes with
                the shape of [B, 1, W, H]. Default: None.
Returns:
list[dict]: Decoded boxes.
"""
batch, cat, _, _ = heat.size()
scores, inds, clses, ys, xs = self._topk(heat, K=self.max_num)
if reg is not None:
reg = self._transpose_and_gather_feat(reg, inds)
reg = reg.view(batch, self.max_num, 2)
xs = xs.view(batch, self.max_num, 1) + reg[:, :, 0:1]
ys = ys.view(batch, self.max_num, 1) + reg[:, :, 1:2]
else:
xs = xs.view(batch, self.max_num, 1) + 0.5
ys = ys.view(batch, self.max_num, 1) + 0.5
# rotation value and direction label
rot_sine = self._transpose_and_gather_feat(rot_sine, inds)
rot_sine = rot_sine.view(batch, self.max_num, 1)
rot_cosine = self._transpose_and_gather_feat(rot_cosine, inds)
rot_cosine = rot_cosine.view(batch, self.max_num, 1)
rot = torch.atan2(rot_sine, rot_cosine)
# height in the bev
hei = self._transpose_and_gather_feat(hei, inds)
hei = hei.view(batch, self.max_num, 1)
# dim of the box
dim = self._transpose_and_gather_feat(dim, inds)
dim = dim.view(batch, self.max_num, 3)
# class label
clses = clses.view(batch, self.max_num).float()
scores = scores.view(batch, self.max_num)
xs = xs.view(
batch, self.max_num,
1) * self.out_size_factor * self.voxel_size[0] + self.pc_range[0]
ys = ys.view(
batch, self.max_num,
1) * self.out_size_factor * self.voxel_size[1] + self.pc_range[1]
if vel is None: # KITTI FORMAT
final_box_preds = torch.cat([xs, ys, hei, dim, rot], dim=2)
else: # exist velocity, nuscene format
vel = self._transpose_and_gather_feat(vel, inds)
vel = vel.view(batch, self.max_num, 2)
final_box_preds = torch.cat([xs, ys, hei, dim, rot, vel], dim=2)
if iou is not None:
iou = self._transpose_and_gather_feat(iou, inds).view(
batch, self.max_num)
final_scores = scores
final_preds = clses
# use score threshold
if self.score_threshold is not None:
thresh_mask = final_scores > self.score_threshold
if self.post_center_range is not None:
self.post_center_range = torch.tensor(
self.post_center_range, device=heat.device)
mask = (final_box_preds[..., :3] >=
self.post_center_range[:3]).all(2)
mask &= (final_box_preds[..., :3] <=
self.post_center_range[3:]).all(2)
predictions_dicts = []
for i in range(batch):
cmask = mask[i, :]
if self.score_threshold:
cmask &= thresh_mask[i]
boxes3d = final_box_preds[i, cmask]
scores = final_scores[i, cmask]
labels = final_preds[i, cmask]
predictions_dict = {
'bboxes': boxes3d,
'scores': scores,
'labels': labels,
}
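            # Attach the per-box IoU prediction when available; downstream
            # post-processing can use it to rectify the confidence scores.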
if iou is not None:
pred_iou = iou[i, cmask]
predictions_dict['iou'] = pred_iou
predictions_dicts.append(predictions_dict)
else:
raise NotImplementedError(
'Need to reorganize output as a batch, only '
'support post_center_range is not None for now!')
return predictions_dicts
import os
from setuptools import setup
import torch
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
CUDAExtension)
def make_cuda_ext(name,
module,
sources,
sources_cuda=[],
extra_args=[],
extra_include_path=[]):
define_macros = []
extra_compile_args = {'cxx': [] + extra_args}
if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('WITH_CUDA', None)]
extension = CUDAExtension
extra_compile_args['nvcc'] = extra_args + [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_75,code=sm_75',
'-gencode=arch=compute_80,code=sm_80',
'-gencode=arch=compute_86,code=sm_86',
]
sources += sources_cuda
else:
print('Compiling {} without CUDA'.format(name))
extension = CppExtension
return extension(
name='{}.{}'.format(module, name),
sources=[os.path.join(*module.split('.'), p) for p in sources],
include_dirs=extra_include_path,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
if __name__ == '__main__':
setup(
name='dsvt',
ext_modules=[
make_cuda_ext(
name='ingroup_inds_cuda',
module='projects.DSVT.dsvt.ops.ingroup_inds',
sources=[
'src/ingroup_inds.cpp',
'src/ingroup_inds_kernel.cu',
]),
],
cmdclass={'build_ext': BuildExtension},
zip_safe=False,
)
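# Building the extension in place is typically done with something like the
# following (assumed invocation, run from the directory containing this file):
#   python setup.py develop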