Commit 72c608ce authored by chenshi3's avatar chenshi3
Browse files

Add support for DSVT

parent 02ac3e17
...@@ -23,6 +23,8 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18 ...@@ -23,6 +23,8 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18
## Changelog ## Changelog
[2023-06-xx] **NEW:** Added support for [`DSVT`](https://arxiv.org/abs/2301.06051), which achieves state-of-the-art performance on the large-scale Waymo Open Dataset with real-time inference speed (27Hz with TensorRT).
[2023-05-13] **NEW:** Added support for the multi-modal 3D object detection models on Nuscenes dataset. [2023-05-13] **NEW:** Added support for the multi-modal 3D object detection models on Nuscenes dataset.
* Support multi-modal Nuscenes detection (See the [GETTING_STARTED.md](docs/GETTING_STARTED.md) to process data). * Support multi-modal Nuscenes detection (See the [GETTING_STARTED.md](docs/GETTING_STARTED.md) to process data).
* Support [TransFusion-Lidar](https://arxiv.org/abs/2203.11496) head, which achieves 69.43% NDS on Nuscenes validation dataset. * Support [TransFusion-Lidar](https://arxiv.org/abs/2203.11496) head, which achieves 69.43% NDS on Nuscenes validation dataset.
...@@ -192,6 +194,8 @@ Here we also provide the performance of several models trained on the full train ...@@ -192,6 +194,8 @@ Here we also provide the performance of several models trained on the full train
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 78.00/77.50 | 69.43/68.98 | 79.21/73.03 | 70.42/64.72 | 71.46/70.27 | 68.95/67.79 | | [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 78.00/77.50 | 69.43/68.98 | 79.21/73.03 | 70.42/64.72 | 71.46/70.27 | 68.95/67.79 |
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 79.10/78.63 | 70.34/69.91 | 80.62/74.62 | 71.86/66.30 | 73.49/72.38 | 70.70/69.62 | | [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 79.10/78.63 | 70.34/69.91 | 80.62/74.62 | 71.86/66.30 | 73.49/72.38 | 70.70/69.62 |
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) | 79.25/78.78 | 70.61/70.18 | 81.83/76.28 | 73.17/68.00 | 73.72/72.66 | 71.21/70.19 | | [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) | 79.25/78.78 | 70.61/70.18 | 81.83/76.28 | 73.17/68.00 | 73.72/72.66 | 71.21/70.19 |
| [DSVT-Pillar](tools/cfgs/waymo_models/dsvt_pillar.yaml) | 79.44/78.97 | 71.24/70.81 | 83.00/77.22 | 75.45/69.95 | 76.70/75.70 | 73.83/72.86 |
| [DSVT-Voxel](tools/cfgs/waymo_models/dsvt_voxel.yaml) | 79.77/79.31 | 71.67/71.25 | 83.75/78.92 | 76.21/71.57 | 77.57/76.58 | 74.70/73.73 |
| [PV-RCNN++ (ResNet, 2 frames)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet_2frames.yaml) | 80.17/79.70 | 72.14/71.70 | 83.48/80.42 | 75.54/72.61 | 74.63/73.75 | 72.35/71.50 | | [PV-RCNN++ (ResNet, 2 frames)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet_2frames.yaml) | 80.17/79.70 | 72.14/71.70 | 83.48/80.42 | 75.54/72.61 | 74.63/73.75 | 72.35/71.50 |
| [MPPNet (4 frames)](docs/guidelines_of_approaches/mppnet.md) | 81.54/81.06 | 74.07/73.61 | 84.56/81.94 | 77.20/74.67 | 77.15/76.50 | 75.01/74.38 | | [MPPNet (4 frames)](docs/guidelines_of_approaches/mppnet.md) | 81.54/81.06 | 74.07/73.61 | 84.56/81.94 | 77.20/74.67 | 77.15/76.50 | 75.01/74.38 |
| [MPPNet (16 frames)](docs/guidelines_of_approaches/mppnet.md) | 82.74/82.28 | 75.41/74.96 | 84.69/82.25 | 77.43/75.06 | 77.28/76.66 | 75.13/74.52 | | [MPPNet (16 frames)](docs/guidelines_of_approaches/mppnet.md) | 82.74/82.28 | 75.41/74.96 | 84.69/82.25 | 77.43/75.06 | 77.28/76.66 | 75.13/74.52 |
...@@ -201,6 +205,7 @@ Here we also provide the performance of several models trained on the full train ...@@ -201,6 +205,7 @@ Here we also provide the performance of several models trained on the full train
We could not provide the above pretrained models due to [Waymo Dataset License Agreement](https://waymo.com/open/terms/), We could not provide the above pretrained models due to [Waymo Dataset License Agreement](https://waymo.com/open/terms/),
but you could easily achieve similar performance by training with the default configs. but you could easily achieve similar performance by training with the default configs.
......
...@@ -200,7 +200,11 @@ class WaymoDataset(DatasetTemplate): ...@@ -200,7 +200,11 @@ class WaymoDataset(DatasetTemplate):
points_all, NLZ_flag = point_features[:, 0:5], point_features[:, 5] points_all, NLZ_flag = point_features[:, 0:5], point_features[:, 5]
if not self.dataset_cfg.get('DISABLE_NLZ_FLAG_ON_POINTS', False): if not self.dataset_cfg.get('DISABLE_NLZ_FLAG_ON_POINTS', False):
points_all = points_all[NLZ_flag == -1] points_all = points_all[NLZ_flag == -1]
if self.dataset_cfg.get('POINTS_TANH_DIM', None) is None:
points_all[:, 3] = np.tanh(points_all[:, 3]) points_all[:, 3] = np.tanh(points_all[:, 3])
else:
for dim_idx in self.dataset_cfg.POINTS_TANH_DIM:
points_all[:, dim_idx] = np.tanh(points_all[:, dim_idx])
return points_all return points_all
@staticmethod @staticmethod
......
from .base_bev_backbone import BaseBEVBackbone, BaseBEVBackboneV1 from .base_bev_backbone import BaseBEVBackbone, BaseBEVBackboneV1, BaseBEVResBackbone
__all__ = { __all__ = {
'BaseBEVBackbone': BaseBEVBackbone, 'BaseBEVBackbone': BaseBEVBackbone,
'BaseBEVBackboneV1': BaseBEVBackboneV1 'BaseBEVBackboneV1': BaseBEVBackboneV1,
'BaseBEVResBackbone': BaseBEVResBackbone,
} }
...@@ -202,3 +202,150 @@ class BaseBEVBackboneV1(nn.Module): ...@@ -202,3 +202,150 @@ class BaseBEVBackboneV1(nn.Module):
data_dict['spatial_features_2d'] = x data_dict['spatial_features_2d'] = x
return data_dict return data_dict
class BasicBlock(nn.Module):
    """ResNet-style residual block of two 3x3 convolutions (conv-BN-ReLU,
    conv-BN) with an identity skip connection.

    When ``downsample`` is True the skip path is a 1x1 strided conv + BN so
    that its output matches the main path's channels/resolution.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first conv (and of the 1x1 projection).
        padding: padding of the first conv.
        downsample: whether to project the skip connection.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        padding: int = 1,
        downsample: bool = False,
    ) -> None:
        super().__init__()
        # main path: conv-bn-relu -> conv-bn (BN stats match the rest of OpenPCDet: eps=1e-3, momentum=0.01)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=padding, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01)
        self.relu2 = nn.ReLU()
        self.downsample = downsample
        if downsample:
            # 1x1 projection so the residual sum is shape-compatible
            self.downsample_layer = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01),
            )
        self.stride = stride

    def forward(self, x):
        shortcut = self.downsample_layer(x) if self.downsample else x
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu2(out + shortcut)
class BaseBEVResBackbone(nn.Module):
    """Residual BEV backbone: per-stage stacks of BasicBlocks, each stage
    optionally followed by an upsampling "deblock"; the per-stage maps are
    concatenated along channels to form the 2D BEV feature.

    Config keys (model_cfg):
        LAYER_NUMS (list[int]): number of extra BasicBlocks per stage (each
            stage always starts with one strided, projected BasicBlock).
        LAYER_STRIDES (list[int]): stride of each stage's first block.
        NUM_FILTERS (list[int]): output channels of each stage.
        UPSAMPLE_STRIDES (list[float]): per-stage resampling factor; values
            >= 1 use ConvTranspose2d, values < 1 use a strided Conv2d.
        NUM_UPSAMPLE_FILTERS (list[int]): channels of each resampled map.
    """

    def __init__(self, model_cfg, input_channels):
        """
        Args:
            model_cfg: config node supporting attribute access and .get().
            input_channels (int): channels of the input BEV feature map.
        """
        super().__init__()
        self.model_cfg = model_cfg

        if self.model_cfg.get('LAYER_NUMS', None) is not None:
            assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
            layer_nums = self.model_cfg.LAYER_NUMS
            layer_strides = self.model_cfg.LAYER_STRIDES
            num_filters = self.model_cfg.NUM_FILTERS
        else:
            layer_nums = layer_strides = num_filters = []

        if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:
            assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
            num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
            upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
        else:
            upsample_strides = num_upsample_filters = []

        num_levels = len(layer_nums)
        # input channels of each stage: previous stage's output (or the map itself)
        c_in_list = [input_channels, *num_filters[:-1]]
        self.blocks = nn.ModuleList()
        self.deblocks = nn.ModuleList()
        for idx in range(num_levels):
            cur_layers = [
                # nn.ZeroPad2d(1),
                # first block of the stage: strided + projected skip connection
                BasicBlock(c_in_list[idx], num_filters[idx], layer_strides[idx], 1, True)
            ]
            for k in range(layer_nums[idx]):
                cur_layers.extend([
                    BasicBlock(num_filters[idx], num_filters[idx])
                ])
            self.blocks.append(nn.Sequential(*cur_layers))
            if len(upsample_strides) > 0:
                stride = upsample_strides[idx]
                if stride >= 1:
                    self.deblocks.append(nn.Sequential(
                        nn.ConvTranspose2d(
                            num_filters[idx], num_upsample_filters[idx],
                            upsample_strides[idx],
                            stride=upsample_strides[idx], bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))
                else:
                    # fractional "upsample" stride => downsample with a strided conv
                    # Bug fix: np.int was removed in NumPy >= 1.24; use builtin int.
                    stride = int(np.round(1 / stride))
                    self.deblocks.append(nn.Sequential(
                        nn.Conv2d(
                            num_filters[idx], num_upsample_filters[idx],
                            stride,
                            stride=stride, bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))

        c_in = sum(num_upsample_filters) if len(num_upsample_filters) > 0 else sum(num_filters)
        if len(upsample_strides) > num_levels:
            # extra final deblock applied to the concatenated feature map
            self.deblocks.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False),
                nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ))

        self.num_bev_features = c_in

    def forward(self, data_dict):
        """
        Args:
            data_dict: dict containing 'spatial_features' (B, C, H, W).
        Returns:
            data_dict with 'spatial_features_2d' (and per-stride intermediates
            'spatial_features_{s}x') added.
        """
        spatial_features = data_dict['spatial_features']
        ups = []
        ret_dict = {}
        x = spatial_features
        for i in range(len(self.blocks)):
            x = self.blocks[i](x)

            stride = int(spatial_features.shape[2] / x.shape[2])
            ret_dict['spatial_features_%dx' % stride] = x
            if len(self.deblocks) > 0:
                ups.append(self.deblocks[i](x))
            else:
                ups.append(x)

        # fuse all stages by channel concatenation
        if len(ups) > 1:
            x = torch.cat(ups, dim=1)
        elif len(ups) == 1:
            x = ups[0]

        if len(self.deblocks) > len(self.blocks):
            x = self.deblocks[-1](x)

        data_dict['spatial_features_2d'] = x

        return data_dict
from .height_compression import HeightCompression from .height_compression import HeightCompression
from .pointpillar_scatter import PointPillarScatter from .pointpillar_scatter import PointPillarScatter, PointPillarScatter3d
from .conv2d_collapse import Conv2DCollapse from .conv2d_collapse import Conv2DCollapse
__all__ = { __all__ = {
'HeightCompression': HeightCompression, 'HeightCompression': HeightCompression,
'PointPillarScatter': PointPillarScatter, 'PointPillarScatter': PointPillarScatter,
'Conv2DCollapse': Conv2DCollapse 'Conv2DCollapse': Conv2DCollapse,
'PointPillarScatter3d': PointPillarScatter3d,
} }
...@@ -35,3 +35,39 @@ class PointPillarScatter(nn.Module): ...@@ -35,3 +35,39 @@ class PointPillarScatter(nn.Module):
batch_spatial_features = batch_spatial_features.view(batch_size, self.num_bev_features * self.nz, self.ny, self.nx) batch_spatial_features = batch_spatial_features.view(batch_size, self.num_bev_features * self.nz, self.ny, self.nx)
batch_dict['spatial_features'] = batch_spatial_features batch_dict['spatial_features'] = batch_spatial_features
return batch_dict return batch_dict
class PointPillarScatter3d(nn.Module):
    """Scatters per-pillar features back onto a dense BEV canvas.

    INPUT_SHAPE gives (nx, ny, nz); NUM_BEV_FEATURES is the channel count of
    the dense output, split evenly across the nz vertical slots.
    """

    def __init__(self, model_cfg, grid_size, **kwargs):
        super().__init__()

        self.model_cfg = model_cfg
        self.nx, self.ny, self.nz = self.model_cfg.INPUT_SHAPE
        self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
        # channels carried per voxel before the nz slots are folded into channels
        self.num_bev_features_before_compression = self.model_cfg.NUM_BEV_FEATURES // self.nz

    def forward(self, batch_dict, **kwargs):
        """Builds 'spatial_features' of shape (B, C*nz, ny, nx) from sparse
        pillar features and their (batch_id, z, y, x) coordinates."""
        pillar_features, coords = batch_dict['pillar_features'], batch_dict['voxel_coords']

        batch_size = coords[:, 0].max().int().item() + 1
        per_sample_canvases = []
        for sample_idx in range(batch_size):
            # flat canvas: one column per (z, y, x) cell
            canvas = torch.zeros(
                self.num_bev_features_before_compression,
                self.nz * self.nx * self.ny,
                dtype=pillar_features.dtype,
                device=pillar_features.device)

            sample_mask = coords[:, 0] == sample_idx
            sample_coords = coords[sample_mask, :]
            # flatten (z, y, x) -> z*ny*nx + y*nx + x
            flat_inds = (sample_coords[:, 1] * self.ny * self.nx
                         + sample_coords[:, 2] * self.nx
                         + sample_coords[:, 3]).type(torch.long)
            canvas[:, flat_inds] = pillar_features[sample_mask, :].t()
            per_sample_canvases.append(canvas)

        dense = torch.stack(per_sample_canvases, 0)
        dense = dense.view(batch_size, self.num_bev_features_before_compression * self.nz, self.ny, self.nx)
        batch_dict['spatial_features'] = dense
        return batch_dict
\ No newline at end of file
...@@ -5,6 +5,7 @@ from .spconv_backbone_focal import VoxelBackBone8xFocal ...@@ -5,6 +5,7 @@ from .spconv_backbone_focal import VoxelBackBone8xFocal
from .spconv_backbone_voxelnext import VoxelResBackBone8xVoxelNeXt from .spconv_backbone_voxelnext import VoxelResBackBone8xVoxelNeXt
from .spconv_backbone_voxelnext2d import VoxelResBackBone8xVoxelNeXt2D from .spconv_backbone_voxelnext2d import VoxelResBackBone8xVoxelNeXt2D
from .spconv_unet import UNetV2 from .spconv_unet import UNetV2
from .dsvt import DSVT
__all__ = { __all__ = {
'VoxelBackBone8x': VoxelBackBone8x, 'VoxelBackBone8x': VoxelBackBone8x,
...@@ -16,5 +17,6 @@ __all__ = { ...@@ -16,5 +17,6 @@ __all__ = {
'VoxelResBackBone8xVoxelNeXt': VoxelResBackBone8xVoxelNeXt, 'VoxelResBackBone8xVoxelNeXt': VoxelResBackBone8xVoxelNeXt,
'VoxelResBackBone8xVoxelNeXt2D': VoxelResBackBone8xVoxelNeXt2D, 'VoxelResBackBone8xVoxelNeXt2D': VoxelResBackBone8xVoxelNeXt2D,
'PillarBackBone8x': PillarBackBone8x, 'PillarBackBone8x': PillarBackBone8x,
'PillarRes18BackBone8x': PillarRes18BackBone8x 'PillarRes18BackBone8x': PillarRes18BackBone8x,
'DSVT': DSVT,
} }
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from math import ceil
from pcdet.models.model_utils.dsvt_utils import get_window_coors, get_inner_win_inds_cuda, get_pooling_index, get_continous_inds
from pcdet.models.model_utils.dsvt_utils import PositionEmbeddingLearned
class DSVT(nn.Module):
    '''Dynamic Sparse Voxel Transformer Backbone.
    Args:
        INPUT_LAYER: Config of input layer, which converts the output of vfe to dsvt input.
        block_name (list[string]): Name of blocks for each stage. Length: stage_num.
        set_info (list[list[int, int]]): A list of set config for each stage. Element i contains
            [set_size, block_num], where set_size is the number of voxels in a set and block_num is the
            number of blocks for stage i. Length: stage_num.
        d_model (list[int]): Number of input channels for each stage. Length: stage_num.
        nhead (list[int]): Number of attention heads for each stage. Length: stage_num.
        dim_feedforward (list[int]): Dimensions of the feedforward network in set attention for each stage.
            Length: stage num.
        dropout (float): Drop rate of set attention.
        activation (string): Name of activation layer in set attention.
        reduction_type (string): Pooling method between stages. One of: "attention", "maxpool", "linear".
        output_shape (tuple[int, int]): Shape of output bev feature.
        conv_out_channel (int): Number of output channels.
    '''

    def __init__(self, model_cfg, **kwargs):
        super().__init__()
        self.model_cfg = model_cfg
        # converts VFE output into window/set partitions + position embeddings
        self.input_layer = DSVTInputLayer(self.model_cfg.INPUT_LAYER)
        block_name = self.model_cfg.block_name
        set_info = self.model_cfg.set_info
        d_model = self.model_cfg.d_model
        nhead = self.model_cfg.nhead
        dim_feedforward = self.model_cfg.dim_feedforward
        dropout = self.model_cfg.dropout
        activation = self.model_cfg.activation
        self.reduction_type = self.model_cfg.get('reduction_type', 'attention')
        # gradient checkpointing: trade extra compute for lower GPU memory
        self.use_torch_ckpt = self.model_cfg.get('USE_CHECKPOINT', False)

        # Sparse Regional Attention Blocks: per stage, a ModuleList of attention
        # blocks plus a matching ModuleList of residual LayerNorms, registered
        # under dynamic names 'stage_{i}' / 'residual_norm_stage_{i}'.
        stage_num = len(block_name)
        for stage_id in range(stage_num):
            num_blocks_this_stage = set_info[stage_id][-1]
            dmodel_this_stage = d_model[stage_id]
            dfeed_this_stage = dim_feedforward[stage_id]
            num_head_this_stage = nhead[stage_id]
            block_name_this_stage = block_name[stage_id]
            block_module = _get_block_module(block_name_this_stage)
            block_list=[]
            norm_list=[]
            for i in range(num_blocks_this_stage):
                block_list.append(
                    block_module(dmodel_this_stage, num_head_this_stage, dfeed_this_stage,
                                 dropout, activation, batch_first=True)
                )
                norm_list.append(nn.LayerNorm(dmodel_this_stage))
            self.__setattr__(f'stage_{stage_id}', nn.ModuleList(block_list))
            self.__setattr__(f'residual_norm_stage_{stage_id}', nn.ModuleList(norm_list))

            # apply pooling after every stage except the last
            if stage_id < stage_num-1:
                downsample_window = self.model_cfg.INPUT_LAYER.downsample_stride[stage_id]
                dmodel_next_stage = d_model[stage_id+1]
                pool_volume = torch.IntTensor(downsample_window).prod().item()
                if self.reduction_type == 'linear':
                    # concatenate all voxels of a pooling window, then project
                    cat_feat_dim = dmodel_this_stage * torch.IntTensor(downsample_window).prod().item()
                    self.__setattr__(f'stage_{stage_id}_reduction', Stage_Reduction_Block(cat_feat_dim, dmodel_next_stage))
                elif self.reduction_type == 'maxpool':
                    self.__setattr__(f'stage_{stage_id}_reduction', torch.nn.MaxPool1d(pool_volume))
                elif self.reduction_type == 'attention':
                    self.__setattr__(f'stage_{stage_id}_reduction', Stage_ReductionAtt_Block(dmodel_this_stage, pool_volume))
                else:
                    raise NotImplementedError

        self.num_shifts = [2] * stage_num
        self.output_shape = self.model_cfg.output_shape
        self.stage_num = stage_num
        self.set_info = set_info
        self.num_point_features = self.model_cfg.conv_out_channel

        self._reset_parameters()

    def forward(self, batch_dict):
        '''
        Args:
            batch_dict (dict):
                The dict contains the following keys
                - voxel_features (Tensor[float]): Voxel features after VFE. Shape of (N, d_model[0]),
                    where N is the number of input voxels.
                - voxel_coords (Tensor[int]): Shape of (N, 4), corresponding voxel coordinates of each voxels.
                    Each row is (batch_id, z, y, x).
                - ...

        Returns:
            batch_dict (dict):
                The dict contains the following keys
                - pillar_features (Tensor[float]):
                - voxel_coords (Tensor[int]):
                - ...
        '''
        # partition voxels into windows/sets and pre-compute position embeddings
        voxel_info = self.input_layer(batch_dict)

        voxel_feat = voxel_info['voxel_feats_stage0']
        set_voxel_inds_list = [[voxel_info[f'set_voxel_inds_stage{s}_shift{i}'] for i in range(self.num_shifts[s])] for s in range(self.stage_num)]
        set_voxel_masks_list = [[voxel_info[f'set_voxel_mask_stage{s}_shift{i}'] for i in range(self.num_shifts[s])] for s in range(self.stage_num)]
        pos_embed_list = [[[voxel_info[f'pos_embed_stage{s}_block{b}_shift{i}'] for i in range(self.num_shifts[s])] for b in range(self.set_info[s][1])] for s in range(self.stage_num)]
        pooling_mapping_index = [voxel_info[f'pooling_mapping_index_stage{s+1}'] for s in range(self.stage_num-1)]
        pooling_index_in_pool = [voxel_info[f'pooling_index_in_pool_stage{s+1}'] for s in range(self.stage_num-1)]
        pooling_preholder_feats = [voxel_info[f'pooling_preholder_feats_stage{s+1}'] for s in range(self.stage_num-1)]

        output = voxel_feat
        block_id = 0
        for stage_id in range(self.stage_num):
            block_layers = self.__getattr__(f'stage_{stage_id}')
            residual_norm_layers = self.__getattr__(f'residual_norm_stage_{stage_id}')
            for i in range(len(block_layers)):
                block = block_layers[i]
                # residual connection around every attention block
                residual = output.clone()
                if self.use_torch_ckpt==False:
                    output = block(output, set_voxel_inds_list[stage_id], set_voxel_masks_list[stage_id], pos_embed_list[stage_id][i], \
                        block_id=block_id)
                else:
                    # checkpointing re-computes activations in backward to save memory
                    output = checkpoint(block, output, set_voxel_inds_list[stage_id], set_voxel_masks_list[stage_id], pos_embed_list[stage_id][i], block_id)
                output = residual_norm_layers[i](output + residual)
                block_id += 1
            if stage_id < self.stage_num - 1:
                # pooling: scatter each voxel into its slot of its pooling
                # window, then reduce each window to one voxel of the next stage
                prepool_features = pooling_preholder_feats[stage_id].type_as(output)
                pooled_voxel_num = prepool_features.shape[0]
                pool_volume = prepool_features.shape[1]
                prepool_features[pooling_mapping_index[stage_id], pooling_index_in_pool[stage_id]] = output
                prepool_features = prepool_features.view(prepool_features.shape[0], -1)

                if self.reduction_type == 'linear':
                    output = self.__getattr__(f'stage_{stage_id}_reduction')(prepool_features)
                elif self.reduction_type == 'maxpool':
                    prepool_features = prepool_features.view(pooled_voxel_num, pool_volume, -1).permute(0, 2, 1)
                    output = self.__getattr__(f'stage_{stage_id}_reduction')(prepool_features).squeeze(-1)
                elif self.reduction_type == 'attention':
                    prepool_features = prepool_features.view(pooled_voxel_num, pool_volume, -1).permute(0, 2, 1)
                    # all-zero mask: no window slot is excluded from attention pooling
                    key_padding_mask = torch.zeros((pooled_voxel_num, pool_volume)).to(prepool_features.device).int()
                    output = self.__getattr__(f'stage_{stage_id}_reduction')(prepool_features, key_padding_mask)
                else:
                    raise NotImplementedError

        batch_dict['pillar_features'] = batch_dict['voxel_features'] = output
        batch_dict['voxel_coords'] = voxel_info[f'voxel_coors_stage{self.stage_num - 1}']
        return batch_dict

    def _reset_parameters(self):
        # Xavier init for weight matrices; 1-D params (biases, norms) and any
        # parameter whose name contains 'scaler' keep their default init.
        for name, p in self.named_parameters():
            if p.dim() > 1 and 'scaler' not in name:
                nn.init.xavier_uniform_(p)
class DSVTBlock(nn.Module):
    ''' Consist of two encoder layer, shift and shift back.
    '''

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", batch_first=True):
        super().__init__()

        self.encoder_list = nn.ModuleList([
            DSVT_EncoderLayer(d_model, nhead, dim_feedforward, dropout,
                              activation, batch_first)
            for _ in range(2)
        ])

    def forward(
            self,
            src,
            set_voxel_inds_list,
            set_voxel_masks_list,
            pos_embed_list,
            block_id,
    ):
        # Even blocks use shift 0, odd blocks use shift 1.
        shift_id = block_id % 2
        output = src
        # TODO: bug to be fixed, mismatch of pos_embed (kept as upstream)
        for layer_idx, layer in enumerate(self.encoder_list):
            inds = set_voxel_inds_list[shift_id][layer_idx]
            masks = set_voxel_masks_list[shift_id][layer_idx]
            output = layer(output, inds, masks, pos_embed_list[layer_idx])

        return output
class DSVT_EncoderLayer(nn.Module):
    """One DSVT encoder layer: set attention with a residual connection,
    followed by LayerNorm."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", batch_first=True, mlp_dropout=0):
        super().__init__()
        self.win_attn = SetAttention(d_model, nhead, dropout, dim_feedforward, activation, batch_first, mlp_dropout)
        self.norm = nn.LayerNorm(d_model)
        self.d_model = d_model

    def forward(self, src, set_voxel_inds, set_voxel_masks, pos=None):
        # residual around the whole set-attention sub-layer
        attended = self.win_attn(src, pos, set_voxel_masks, set_voxel_inds)
        return self.norm(attended + src)
class SetAttention(nn.Module):
    """Multi-head self-attention applied within voxel sets, followed by an FFN.

    Voxels are gathered into fixed-size sets via ``voxel_inds``, attended
    per set, then scattered back to per-voxel order.
    """

    def __init__(self, d_model, nhead, dropout, dim_feedforward=2048, activation="relu", batch_first=True, mlp_dropout=0):
        super().__init__()
        self.nhead = nhead
        if batch_first:
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        else:
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(mlp_dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.d_model = d_model
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # residual-path dropouts are disabled (Identity) in this implementation
        self.dropout1 = nn.Identity()
        self.dropout2 = nn.Identity()

        self.activation = _get_activation_fn(activation)

    def forward(self, src, pos=None, key_padding_mask=None, voxel_inds=None):
        '''
        Args:
            src (Tensor[float]): Voxel features with shape (N, C), where N is the number of voxels.
            pos (Tensor[float]): Position embedding vectors with shape (N, C), or None.
            key_padding_mask (Tensor[bool]): Mask for redundant voxels within set. Shape of (set_num, set_size).
            voxel_inds (Tensor[int]): Voxel indexes for each set. Shape of (set_num, set_size).
        Returns:
            src (Tensor[float]): Voxel features.
        '''
        set_features = src[voxel_inds]
        if pos is not None:
            set_pos = pos[voxel_inds]
        else:
            set_pos = None
        if set_pos is not None:
            query = set_features + set_pos
            key = set_features + set_pos
            value = set_features
        else:
            # Bug fix: previously query/key/value were left undefined when
            # pos was None, raising NameError. Without position embeddings,
            # attend on the raw set features.
            query = set_features
            key = set_features
            value = set_features

        if key_padding_mask is not None:
            src2 = self.self_attn(query, key, value, key_padding_mask)[0]
        else:
            src2 = self.self_attn(query, key, value)[0]

        # map voxel features from set space back to voxel space:
        # (set_num, set_size, C) --> (N, C). Duplicated voxel indices (padding)
        # are resolved by keeping the FIRST occurrence: the flip + scatter_
        # trick makes earlier positions overwrite later ones.
        flatten_inds = voxel_inds.reshape(-1)
        unique_flatten_inds, inverse = torch.unique(flatten_inds, return_inverse=True)
        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
        inverse, perm = inverse.flip([0]), perm.flip([0])
        perm = inverse.new_empty(unique_flatten_inds.size(0)).scatter_(0, inverse, perm)
        src2 = src2.reshape(-1, self.d_model)[perm]

        # FFN layer with pre-attention residual
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        return src
class Stage_Reduction_Block(nn.Module):
    """Linear (bias-free) channel projection followed by LayerNorm, used to
    reduce concatenated pooling-window features to the next stage's width."""

    def __init__(self, input_channel, output_channel):
        super().__init__()
        self.linear1 = nn.Linear(input_channel, output_channel, bias=False)
        self.norm = nn.LayerNorm(output_channel)

    def forward(self, x):
        return self.norm(self.linear1(x))
class Stage_ReductionAtt_Block(nn.Module):
    """Attention-based pooling over a window of ``pool_volume`` voxels.

    A max-pooled summary of the window acts as the query; the window's voxels
    (plus a learned position embedding on the keys) act as keys/values.
    """

    def __init__(self, input_channel, pool_volume):
        super().__init__()
        self.pool_volume = pool_volume
        self.query_func = torch.nn.MaxPool1d(pool_volume)
        self.norm = nn.LayerNorm(input_channel)
        self.self_attn = nn.MultiheadAttention(input_channel, 8, batch_first=True)
        self.pos_embedding = nn.Parameter(torch.randn(pool_volume, input_channel))
        nn.init.normal_(self.pos_embedding, std=.01)

    def forward(self, x, key_padding_mask):
        # x: [voxel_num, c_dim, pool_volume]
        pooled = self.query_func(x).permute(0, 2, 1)  # [voxel_num, 1, c_dim]
        neighbors = x.permute(0, 2, 1)                # [voxel_num, pool_volume, c_dim]
        keys = neighbors + self.pos_embedding.unsqueeze(0).repeat(pooled.shape[0], 1, 1)
        attn_out = self.self_attn(pooled.clone(), keys, neighbors, key_padding_mask)[0]
        # residual on the pooled query, then collapse the singleton query dim
        return self.norm(attn_out + pooled).squeeze(1)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return torch.nn.functional.relu
if activation == "gelu":
return torch.nn.functional.gelu
if activation == "glu":
return torch.nn.functional.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
def _get_block_module(name):
"""Return an block module given a string"""
if name == "DSVTBlock":
return DSVTBlock
raise RuntimeError(F"This Block not exist.")
class DSVTInputLayer(nn.Module):
'''
This class converts the output of vfe to dsvt input.
We do in this class:
1. Window partition: partition voxels to non-overlapping windows.
2. Set partition: generate non-overlapped and size-equivalent local sets within each window.
3. Pre-compute the downsample infomation between two consecutive stages.
4. Pre-compute the position embedding vectors.
Args:
sparse_shape (tuple[int, int, int]): Shape of input space (xdim, ydim, zdim).
window_shape (list[list[int, int, int]]): Window shapes (winx, winy, winz) in different stages. Length: stage_num.
downsample_stride (list[list[int, int, int]]): Downsample strides between two consecutive stages.
Element i is [ds_x, ds_y, ds_z], which is used between stage_i and stage_{i+1}. Length: stage_num - 1.
d_model (list[int]): Number of input channels for each stage. Length: stage_num.
set_info (list[list[int, int]]): A list of set config for each stage. Eelement i contains
[set_size, block_num], where set_size is the number of voxel in a set and block_num is the
number of blocks for stage i. Length: stage_num.
hybrid_factor (list[int, int, int]): Control the window shape in different blocks.
e.g. for block_{0} and block_{1} in stage_0, window shapes are [win_x, win_y, win_z] and
[win_x * h[0], win_y * h[1], win_z * h[2]] respectively.
shift_list (list): Shift window. Length: stage_num.
normalize_pos (bool): Whether to normalize coordinates in position embedding.
'''
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.sparse_shape = self.model_cfg.sparse_shape
        self.window_shape = self.model_cfg.window_shape
        self.downsample_stride = self.model_cfg.downsample_stride
        self.d_model = self.model_cfg.d_model
        self.set_info = self.model_cfg.set_info
        self.stage_num = len(self.d_model)

        self.hybrid_factor = self.model_cfg.hybrid_factor
        # Per stage, keep a pair of window shapes: the configured one and a
        # hybrid-scaled one (each coordinate multiplied by hybrid_factor),
        # used by alternating blocks within the stage.
        self.window_shape = [[self.window_shape[s_id], [self.window_shape[s_id][coord_id] * self.hybrid_factor[coord_id] \
            for coord_id in range(3)]] for s_id in range(self.stage_num)]
        # NOTE: config key is 'shifts_list' (plural) but the attribute is shift_list
        self.shift_list = self.model_cfg.shifts_list
        self.normalize_pos = self.model_cfg.normalize_pos

        # two shifted window partitions per stage
        self.num_shifts = [2,] * len(self.window_shape)

        self.sparse_shape_list = [self.sparse_shape]
        # compute sparse shapes for each stage (ceil-divide by downsample stride)
        for ds_stride in self.downsample_stride:
            last_sparse_shape = self.sparse_shape_list[-1]
            self.sparse_shape_list.append((ceil(last_sparse_shape[0]/ds_stride[0]), ceil(last_sparse_shape[1]/ds_stride[1]), ceil(last_sparse_shape[2]/ds_stride[2])))

        # position embedding layers: one per (stage, block, shift).
        # 2-D embeddings when the stage's sparse shape is flat in z, else 3-D.
        self.posembed_layers = nn.ModuleList()
        for i in range(len(self.set_info)):
            input_dim = 3 if self.sparse_shape_list[i][-1] > 1 else 2
            stage_posembed_layers = nn.ModuleList()
            for j in range(self.set_info[i][1]):
                block_posembed_layers = nn.ModuleList()
                for s in range(self.num_shifts[i]):
                    block_posembed_layers.append(PositionEmbeddingLearned(input_dim, self.d_model[i]))
                stage_posembed_layers.append(block_posembed_layers)
            self.posembed_layers.append(stage_posembed_layers)
    def forward(self, batch_dict):
        '''
        Args:
            batch_dict (dict):
                The dict contains the following keys
                - voxel_features (Tensor[float]): Voxel features after VFE with shape (N, d_model[0]),
                    where N is the number of input voxels.
                - voxel_coords (Tensor[int]): Shape of (N, 4), corresponding voxel coordinates of each voxels.
                    Each row is (batch_id, z, y, x).
                - ...

        Returns:
            voxel_info (dict):
                The dict contains the following keys
                - voxel_coors_stage{i} (Tensor[int]): Shape of (N_i, 4). N_i is the number of voxels in stage_i.
                    Each row is (batch_id, z, y, x).
                - set_voxel_inds_stage{i}_shift{j} (Tensor[int]): Set partition index with shape (2, set_num, set_info[i][0]).
                    2 indicates x-axis partition and y-axis partition.
                - set_voxel_mask_stage{i}_shift{j} (Tensor[bool]): Key mask used in set attention with shape (2, set_num, set_info[i][0]).
                - pos_embed_stage{i}_block{b}_shift{j} (Tensor[float]): Position embedding vectors with shape (N_i, d_model[i]). N_i is the
                    number of remaining voxels in stage_i.
                - pooling_mapping_index_stage{i} (Tensor[int]): Pooling region index used in the pooling operation between stage_{i-1} and stage_{i}
                    with shape (N_{i-1}).
                - pooling_index_in_pool_stage{i} (Tensor[int]): Index within the pooling region with shape (N_{i-1}). Combined with pooling_mapping_index_stage{i},
                    we can map each voxel in stage_{i-1} to pooling_preholder_feats_stage{i}, which is the input of the downsample operation.
                - pooling_preholder_feats_stage{i} (Tensor[float]): Placeholder features initialized with value 0.
                    Shape of (N_{i}, downsample_stride[i-1].prod(), d_model[i-1]), where prod() returns the product of all elements.
                - ...
        '''
        voxel_feats = batch_dict['voxel_features']
        voxel_coors = batch_dict['voxel_coords'].long()
        voxel_info = {}
        voxel_info['voxel_feats_stage0'] = voxel_feats.clone()
        voxel_info['voxel_coors_stage0'] = voxel_coors.clone()

        for stage_id in range(self.stage_num):
            # window partition of the corresponding stage map
            voxel_info = self.window_partition(voxel_info, stage_id)
            # generate set ids of the corresponding stage map
            voxel_info = self.get_set(voxel_info, stage_id)
            # one learned position embedding per (block, shift) pair
            for block_id in range(self.set_info[stage_id][1]):
                for shift_id in range(self.num_shifts[stage_id]):
                    voxel_info[f'pos_embed_stage{stage_id}_block{block_id}_shift{shift_id}'] = \
                        self.get_pos_embed(voxel_info[f'coors_in_win_stage{stage_id}_shift{shift_id}'], stage_id, block_id, shift_id)

            # pre-compute pooling information for every stage except the last
            if stage_id < self.stage_num - 1:
                voxel_info = self.subm_pooling(voxel_info, stage_id)

        return voxel_info
    @torch.no_grad()
    def subm_pooling(self, voxel_info, stage_id):
        """Pre-computes the voxel-to-pooling-window mapping between stage_id
        and stage_id + 1 (indices only; no features are pooled here)."""
        # x, y, z strides of the pooling window
        cur_stage_downsample = self.downsample_stride[stage_id]
        # batch_win_coords is from 1 of x, y
        # NOTE(review): get_pooling_index presumably returns, per voxel, its flat
        # pooling-window id, its slot inside that window, and the window's
        # coordinates — confirm against dsvt_utils.
        batch_win_inds, _, index_in_win, batch_win_coors = get_pooling_index(voxel_info[f'voxel_coors_stage{stage_id}'], self.sparse_shape_list[stage_id], cur_stage_downsample)
        # compute pooling mapping index: contiguous ids of the occupied windows
        unique_batch_win_inds, contiguous_batch_win_inds = torch.unique(batch_win_inds, return_inverse=True)
        voxel_info[f'pooling_mapping_index_stage{stage_id+1}'] = contiguous_batch_win_inds

        # generate empty placeholder features: (num_windows, window volume, d_model)
        placeholder_prepool_feats = voxel_info[f'voxel_feats_stage0'].new_zeros((len(unique_batch_win_inds),
                                        torch.prod(torch.IntTensor(cur_stage_downsample)).item(), self.d_model[stage_id]))
        voxel_info[f'pooling_index_in_pool_stage{stage_id+1}'] = index_in_win
        voxel_info[f'pooling_preholder_feats_stage{stage_id+1}'] = placeholder_prepool_feats

        # compute pooling coordinates: for each unique window, take the window
        # coords of its first voxel (the flip + scatter_ trick makes earlier
        # positions overwrite later ones, i.e. first occurrence wins)
        unique, inverse = unique_batch_win_inds.clone(), contiguous_batch_win_inds.clone()
        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
        inverse, perm = inverse.flip([0]), perm.flip([0])
        perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
        pool_coors = batch_win_coors[perm]
        voxel_info[f'voxel_coors_stage{stage_id+1}'] = pool_coors

        return voxel_info
def get_set(self, voxel_info, stage_id):
'''
This is one of the core operation of DSVT.
Given voxels' window ids and relative-coords inner window, we partition them into window-bounded and size-equivalent local sets.
To make it clear and easy to follow, we do not use loop to process two shifts.
Args:
voxel_info (dict):
The dict contains the following keys
- batch_win_inds_s{i} (Tensor[float]): Windows indexs of each voxel with shape (N), computed by 'window_partition'.
- coors_in_win_shift{i} (Tensor[int]): Relative-coords inner window of each voxel with shape (N, 3), computed by 'window_partition'.
Each row is (z, y, x).
- ...
Returns:
See from 'forward' function.
'''
batch_win_inds_shift0 = voxel_info[f'batch_win_inds_stage{stage_id}_shift0']
coors_in_win_shift0 = voxel_info[f'coors_in_win_stage{stage_id}_shift0']
set_voxel_inds_shift0 = self.get_set_single_shift(batch_win_inds_shift0, stage_id, shift_id=0, coors_in_win=coors_in_win_shift0)
voxel_info[f'set_voxel_inds_stage{stage_id}_shift0'] = set_voxel_inds_shift0
# compute key masks, voxel duplication must happen continuously
prefix_set_voxel_inds_s0 = torch.roll(set_voxel_inds_shift0.clone(), shifts=1, dims=-1)
prefix_set_voxel_inds_s0[ :, :, 0] = -1
set_voxel_mask_s0 = (set_voxel_inds_shift0 == prefix_set_voxel_inds_s0)
voxel_info[f'set_voxel_mask_stage{stage_id}_shift0'] = set_voxel_mask_s0
batch_win_inds_shift1 = voxel_info[f'batch_win_inds_stage{stage_id}_shift1']
coors_in_win_shift1 = voxel_info[f'coors_in_win_stage{stage_id}_shift1']
set_voxel_inds_shift1 = self.get_set_single_shift(batch_win_inds_shift1, stage_id, shift_id=1, coors_in_win=coors_in_win_shift1)
voxel_info[f'set_voxel_inds_stage{stage_id}_shift1'] = set_voxel_inds_shift1
# compute key masks, voxel duplication must happen continuously
prefix_set_voxel_inds_s1 = torch.roll(set_voxel_inds_shift1.clone(), shifts=1, dims=-1)
prefix_set_voxel_inds_s1[ :, :, 0] = -1
set_voxel_mask_s1 = (set_voxel_inds_shift1 == prefix_set_voxel_inds_s1)
voxel_info[f'set_voxel_mask_stage{stage_id}_shift1'] = set_voxel_mask_s1
return voxel_info
    def get_set_single_shift(self, batch_win_inds, stage_id, shift_id=None, coors_in_win=None):
        '''
        Partition the voxels of one shift into fixed-size local sets and return the
        voxel indices of every set under both a y-major and an x-major ordering.

        Args:
            batch_win_inds (Tensor): window index of each voxel, shape (N,).
            stage_id (int): index of the current stage.
            shift_id (int): 0 for the unshifted partition, 1 for the shifted one.
            coors_in_win (Tensor[int]): relative (z, y, x) coords inside the window, shape (N, 3).
        Returns:
            all_set_voxel_inds (Tensor[long]): shape (2, set_num, set_size); dim 0 stacks the
                y-axis-ordered and x-axis-ordered results. Windows whose voxel count is not a
                multiple of set_size produce sets with repeated voxel indices.
        '''
        device = batch_win_inds.device
        # the number of voxels assigned to a set
        voxel_num_set = self.set_info[stage_id][0]
        # max number of voxels in a window
        max_voxel = self.window_shape[stage_id][shift_id][0] * self.window_shape[stage_id][shift_id][1] * self.window_shape[stage_id][shift_id][2]

        # get unique set indexs
        contiguous_win_inds = torch.unique(batch_win_inds, return_inverse=True)[1]
        voxelnum_per_win = torch.bincount(contiguous_win_inds)
        win_num = voxelnum_per_win.shape[0]
        # ceil so every voxel of a window is covered by some set
        setnum_per_win_float = voxelnum_per_win / voxel_num_set
        setnum_per_win = torch.ceil(setnum_per_win_float).long()
        set_win_inds, set_inds_in_win = get_continous_inds(setnum_per_win)

        # compution of Eq.3 in 'DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets' - https://arxiv.org/abs/2301.06051,
        # for each window, we can get voxel indexs belong to different sets.
        offset_idx = set_inds_in_win[:,None].repeat(1, voxel_num_set) * voxel_num_set
        base_idx = torch.arange(0, voxel_num_set, 1, device=device)
        base_select_idx = offset_idx + base_idx
        # rescale the (setnum * set_size) slots so they map evenly onto the window's voxel count
        base_select_idx = base_select_idx * voxelnum_per_win[set_win_inds][:,None]
        base_select_idx = base_select_idx.double() / (setnum_per_win[set_win_inds] * voxel_num_set)[:,None].double()
        base_select_idx = torch.floor(base_select_idx)
        # obtain unique indexs in whole space
        select_idx = base_select_idx
        select_idx = select_idx + set_win_inds.view(-1, 1) * max_voxel

        # this function will return unordered inner window indexs of each voxel
        inner_voxel_inds = get_inner_win_inds_cuda(contiguous_win_inds)
        global_voxel_inds = contiguous_win_inds * max_voxel + inner_voxel_inds
        _, order1 = torch.sort(global_voxel_inds)

        # get y-axis partition results
        global_voxel_inds_sorty = contiguous_win_inds * max_voxel + \
                                    coors_in_win[:,1] * self.window_shape[stage_id][shift_id][0] * self.window_shape[stage_id][shift_id][2] + \
                                    coors_in_win[:,2] * self.window_shape[stage_id][shift_id][2] + \
                                    coors_in_win[:,0]
        _, order2 = torch.sort(global_voxel_inds_sorty)
        inner_voxel_inds_sorty = -torch.ones_like(inner_voxel_inds)
        inner_voxel_inds_sorty.scatter_(dim=0, index=order2, src=inner_voxel_inds[order1]) # get y-axis ordered inner window indexs of each voxel
        voxel_inds_in_batch_sorty = inner_voxel_inds_sorty + max_voxel * contiguous_win_inds
        # scatter the dense slot space back to actual voxel row numbers (-1 marks empty slots)
        voxel_inds_padding_sorty = -1 * torch.ones((win_num * max_voxel), dtype=torch.long, device=device)
        voxel_inds_padding_sorty[voxel_inds_in_batch_sorty] = torch.arange(0, voxel_inds_in_batch_sorty.shape[0], dtype=torch.long, device=device)
        set_voxel_inds_sorty = voxel_inds_padding_sorty[select_idx.long()]

        # get x-axis partition results
        global_voxel_inds_sortx = contiguous_win_inds * max_voxel + \
                                    coors_in_win[:,2] * self.window_shape[stage_id][shift_id][1] * self.window_shape[stage_id][shift_id][2] + \
                                    coors_in_win[:,1] * self.window_shape[stage_id][shift_id][2] + \
                                    coors_in_win[:,0]
        _, order2 = torch.sort(global_voxel_inds_sortx)
        inner_voxel_inds_sortx = -torch.ones_like(inner_voxel_inds)
        inner_voxel_inds_sortx.scatter_(dim=0,index=order2, src=inner_voxel_inds[order1]) # get x-axis ordered inner window indexs of each voxel
        voxel_inds_in_batch_sortx = inner_voxel_inds_sortx + max_voxel * contiguous_win_inds
        voxel_inds_padding_sortx = -1 * torch.ones((win_num * max_voxel), dtype=torch.long, device=device)
        voxel_inds_padding_sortx[voxel_inds_in_batch_sortx] = torch.arange(0, voxel_inds_in_batch_sortx.shape[0], dtype=torch.long, device=device)
        set_voxel_inds_sortx = voxel_inds_padding_sortx[select_idx.long()]

        all_set_voxel_inds = torch.stack((set_voxel_inds_sorty, set_voxel_inds_sortx), dim=0)
        return all_set_voxel_inds
@torch.no_grad()
def window_partition(self, voxel_info, stage_id):
for i in range(2):
batch_win_inds, coors_in_win = get_window_coors(voxel_info[f'voxel_coors_stage{stage_id}'],
self.sparse_shape_list[stage_id], self.window_shape[stage_id][i], i == 1, self.shift_list[stage_id][i])
voxel_info[f'batch_win_inds_stage{stage_id}_shift{i}'] = batch_win_inds
voxel_info[f'coors_in_win_stage{stage_id}_shift{i}'] = coors_in_win
return voxel_info
def get_pos_embed(self, coors_in_win, stage_id, block_id, shift_id):
'''
Args:
coors_in_win: shape=[N, 3], order: z, y, x
'''
# [N,]
window_shape = self.window_shape[stage_id][shift_id]
embed_layer = self.posembed_layers[stage_id][block_id][shift_id]
if len(window_shape) == 2:
ndim = 2
win_x, win_y = window_shape
win_z = 0
elif window_shape[-1] == 1:
ndim = 2
win_x, win_y = window_shape[:2]
win_z = 0
else:
win_x, win_y, win_z = window_shape
ndim = 3
assert coors_in_win.size(1) == 3
z, y, x = coors_in_win[:, 0] - win_z/2, coors_in_win[:, 1] - win_y/2, coors_in_win[:, 2] - win_x/2
if self.normalize_pos:
x = x / win_x * 2 * 3.1415 #[-pi, pi]
y = y / win_y * 2 * 3.1415 #[-pi, pi]
z = z / win_z * 2 * 3.1415 #[-pi, pi]
if ndim==2:
location = torch.stack((x, y), dim=-1)
else:
location = torch.stack((x, y, z), dim=-1)
pos_embed = embed_layer(location)
return pos_embed
...@@ -2,6 +2,7 @@ from .mean_vfe import MeanVFE ...@@ -2,6 +2,7 @@ from .mean_vfe import MeanVFE
from .pillar_vfe import PillarVFE from .pillar_vfe import PillarVFE
from .dynamic_mean_vfe import DynamicMeanVFE from .dynamic_mean_vfe import DynamicMeanVFE
from .dynamic_pillar_vfe import DynamicPillarVFE, DynamicPillarVFESimple2D from .dynamic_pillar_vfe import DynamicPillarVFE, DynamicPillarVFESimple2D
from .dynamic_voxel_vfe import DynamicVoxelVFE
from .image_vfe import ImageVFE from .image_vfe import ImageVFE
from .vfe_template import VFETemplate from .vfe_template import VFETemplate
...@@ -12,5 +13,6 @@ __all__ = { ...@@ -12,5 +13,6 @@ __all__ = {
'ImageVFE': ImageVFE, 'ImageVFE': ImageVFE,
'DynMeanVFE': DynamicMeanVFE, 'DynMeanVFE': DynamicMeanVFE,
'DynPillarVFE': DynamicPillarVFE, 'DynPillarVFE': DynamicPillarVFE,
'DynamicPillarVFESimple2D': DynamicPillarVFESimple2D 'DynamicPillarVFESimple2D': DynamicPillarVFESimple2D,
'DynamicVoxelVFE': DynamicVoxelVFE,
} }
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
import torch_scatter
except Exception as e:
# Incase someone doesn't want to use dynamic pillar vfe and hasn't installed torch_scatter
pass
from .vfe_template import VFETemplate
from .dynamic_pillar_vfe import PFNLayerV2
class DynamicVoxelVFE(VFETemplate):
    """
    Dynamic voxel feature encoder.

    Groups raw points into voxels on the fly (no fixed max-points-per-voxel buffer)
    and encodes each voxel with a stack of PFN layers, aggregating point features
    with torch_scatter instead of padding.
    """
    def __init__(self, model_cfg, num_point_features, voxel_size, grid_size, point_cloud_range, **kwargs):
        """
        Args:
            model_cfg: config node providing USE_NORM, WITH_DISTANCE, USE_ABSLOTE_XYZ, NUM_FILTERS.
            num_point_features: number of raw features per point.
            voxel_size: (vx, vy, vz) voxel edge lengths.
            grid_size: (nx, ny, nz) number of voxels along each axis.
            point_cloud_range: (x_min, y_min, z_min, x_max, y_max, z_max).
        """
        super().__init__(model_cfg=model_cfg)

        self.use_norm = self.model_cfg.USE_NORM
        self.with_distance = self.model_cfg.WITH_DISTANCE
        self.use_absolute_xyz = self.model_cfg.USE_ABSLOTE_XYZ  # NOTE: 'ABSLOTE' is the (misspelled) key used by the configs
        # each point is augmented with cluster offsets (3) and voxel-center offsets (3);
        # absolute xyz additionally keeps the raw coordinates
        num_point_features += 6 if self.use_absolute_xyz else 3
        if self.with_distance:
            num_point_features += 1

        self.num_filters = self.model_cfg.NUM_FILTERS
        assert len(self.num_filters) > 0
        num_filters = [num_point_features] + list(self.num_filters)

        pfn_layers = []
        for i in range(len(num_filters) - 1):
            in_filters = num_filters[i]
            out_filters = num_filters[i + 1]
            pfn_layers.append(
                PFNLayerV2(in_filters, out_filters, self.use_norm, last_layer=(i >= len(num_filters) - 2))
            )
        self.pfn_layers = nn.ModuleList(pfn_layers)

        self.voxel_x = voxel_size[0]
        self.voxel_y = voxel_size[1]
        self.voxel_z = voxel_size[2]
        # offset from a voxel's integer index to the metric center of that voxel
        self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
        self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
        self.z_offset = self.voxel_z / 2 + point_cloud_range[2]

        # scale factors to merge (batch, x, y, z) voxel indices into one integer key
        self.scale_xyz = grid_size[0] * grid_size[1] * grid_size[2]
        self.scale_yz = grid_size[1] * grid_size[2]
        self.scale_z = grid_size[2]

        self.grid_size = torch.tensor(grid_size).cuda()
        self.voxel_size = torch.tensor(voxel_size).cuda()
        self.point_cloud_range = torch.tensor(point_cloud_range).cuda()

    def get_output_feature_dim(self):
        # feature width produced by the last PFN layer
        return self.num_filters[-1]

    def forward(self, batch_dict, **kwargs):
        """
        Args:
            batch_dict: must contain 'points' of shape (P, 1 + C), each row (batch_idx, x, y, z, i, e).
        Returns:
            batch_dict with
                'voxel_features' / 'pillar_features': (num_voxels, C_out) encoded voxel features
                'voxel_coords': (num_voxels, 4) in (batch_idx, z, y, x) order
        """
        points = batch_dict['points'] # (batch_idx, x, y, z, i, e)

        # voxel index of every point; drop points falling outside the grid
        points_coords = torch.floor((points[:, [1,2,3]] - self.point_cloud_range[[0,1,2]]) / self.voxel_size[[0,1,2]]).int()
        mask = ((points_coords >= 0) & (points_coords < self.grid_size[[0,1,2]])).all(dim=1)
        points = points[mask]
        points_coords = points_coords[mask]
        points_xyz = points[:, [1, 2, 3]].contiguous()

        # merge (batch, x, y, z) into a single integer so shared voxels are found via torch.unique
        merge_coords = points[:, 0].int() * self.scale_xyz + \
                       points_coords[:, 0] * self.scale_yz + \
                       points_coords[:, 1] * self.scale_z + \
                       points_coords[:, 2]

        unq_coords, unq_inv, unq_cnt = torch.unique(merge_coords, return_inverse=True, return_counts=True, dim=0)

        # offset of each point from the mean of all points in its voxel
        points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0)
        f_cluster = points_xyz - points_mean[unq_inv, :]

        # offset of each point from the geometric center of its voxel
        f_center = torch.zeros_like(points_xyz)
        f_center[:, 0] = points_xyz[:, 0] - (points_coords[:, 0].to(points_xyz.dtype) * self.voxel_x + self.x_offset)
        f_center[:, 1] = points_xyz[:, 1] - (points_coords[:, 1].to(points_xyz.dtype) * self.voxel_y + self.y_offset)
        # f_center[:, 2] = points_xyz[:, 2] - self.z_offset
        f_center[:, 2] = points_xyz[:, 2] - (points_coords[:, 2].to(points_xyz.dtype) * self.voxel_z + self.z_offset)

        if self.use_absolute_xyz:
            features = [points[:, 1:], f_cluster, f_center]
        else:
            features = [points[:, 4:], f_cluster, f_center]

        if self.with_distance:
            points_dist = torch.norm(points[:, 1:4], 2, dim=1, keepdim=True)
            features.append(points_dist)
        features = torch.cat(features, dim=-1)

        for pfn in self.pfn_layers:
            features = pfn(features, unq_inv)

        # generate voxel coordinates: decode the merged integer key back into (batch, x, y, z)
        unq_coords = unq_coords.int()
        voxel_coords = torch.stack((unq_coords // self.scale_xyz,
                                    (unq_coords % self.scale_xyz) // self.scale_yz,
                                    (unq_coords % self.scale_yz) // self.scale_z,
                                    unq_coords % self.scale_z), dim=1)
        # reorder columns to the (batch_idx, z, y, x) convention
        voxel_coords = voxel_coords[:, [0, 3, 2, 1]]

        batch_dict['pillar_features'] = batch_dict['voxel_features'] = features
        batch_dict['voxel_coords'] = voxel_coords
        return batch_dict
...@@ -6,10 +6,11 @@ from torch.nn.init import kaiming_normal_ ...@@ -6,10 +6,11 @@ from torch.nn.init import kaiming_normal_
from ..model_utils import model_nms_utils from ..model_utils import model_nms_utils
from ..model_utils import centernet_utils from ..model_utils import centernet_utils
from ...utils import loss_utils from ...utils import loss_utils
from functools import partial
class SeparateHead(nn.Module): class SeparateHead(nn.Module):
def __init__(self, input_channels, sep_head_dict, init_bias=-2.19, use_bias=False): def __init__(self, input_channels, sep_head_dict, init_bias=-2.19, use_bias=False, norm_func=None):
super().__init__() super().__init__()
self.sep_head_dict = sep_head_dict self.sep_head_dict = sep_head_dict
...@@ -21,7 +22,7 @@ class SeparateHead(nn.Module): ...@@ -21,7 +22,7 @@ class SeparateHead(nn.Module):
for k in range(num_conv - 1): for k in range(num_conv - 1):
fc_list.append(nn.Sequential( fc_list.append(nn.Sequential(
nn.Conv2d(input_channels, input_channels, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.Conv2d(input_channels, input_channels, kernel_size=3, stride=1, padding=1, bias=use_bias),
nn.BatchNorm2d(input_channels), nn.BatchNorm2d(input_channels) if norm_func is None else norm_func(input_channels),
nn.ReLU() nn.ReLU()
)) ))
fc_list.append(nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1, bias=True)) fc_list.append(nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1, bias=True))
...@@ -70,12 +71,13 @@ class CenterHead(nn.Module): ...@@ -70,12 +71,13 @@ class CenterHead(nn.Module):
total_classes = sum([len(x) for x in self.class_names_each_head]) total_classes = sum([len(x) for x in self.class_names_each_head])
assert total_classes == len(self.class_names), f'class_names_each_head={self.class_names_each_head}' assert total_classes == len(self.class_names), f'class_names_each_head={self.class_names_each_head}'
norm_func = partial(nn.BatchNorm2d, eps=self.model_cfg.get('BN_EPS', 1e-5), momentum=self.model_cfg.get('BN_MOM', 0.1))
self.shared_conv = nn.Sequential( self.shared_conv = nn.Sequential(
nn.Conv2d( nn.Conv2d(
input_channels, self.model_cfg.SHARED_CONV_CHANNEL, 3, stride=1, padding=1, input_channels, self.model_cfg.SHARED_CONV_CHANNEL, 3, stride=1, padding=1,
bias=self.model_cfg.get('USE_BIAS_BEFORE_NORM', False) bias=self.model_cfg.get('USE_BIAS_BEFORE_NORM', False)
), ),
nn.BatchNorm2d(self.model_cfg.SHARED_CONV_CHANNEL), norm_func(self.model_cfg.SHARED_CONV_CHANNEL),
nn.ReLU(), nn.ReLU(),
) )
...@@ -89,7 +91,8 @@ class CenterHead(nn.Module): ...@@ -89,7 +91,8 @@ class CenterHead(nn.Module):
input_channels=self.model_cfg.SHARED_CONV_CHANNEL, input_channels=self.model_cfg.SHARED_CONV_CHANNEL,
sep_head_dict=cur_head_dict, sep_head_dict=cur_head_dict,
init_bias=-2.19, init_bias=-2.19,
use_bias=self.model_cfg.get('USE_BIAS_BEFORE_NORM', False) use_bias=self.model_cfg.get('USE_BIAS_BEFORE_NORM', False),
norm_func=norm_func
) )
) )
self.predict_boxes_when_training = predict_boxes_when_training self.predict_boxes_when_training = predict_boxes_when_training
...@@ -116,6 +119,8 @@ class CenterHead(nn.Module): ...@@ -116,6 +119,8 @@ class CenterHead(nn.Module):
ret_boxes = gt_boxes.new_zeros((num_max_objs, gt_boxes.shape[-1] - 1 + 1)) ret_boxes = gt_boxes.new_zeros((num_max_objs, gt_boxes.shape[-1] - 1 + 1))
inds = gt_boxes.new_zeros(num_max_objs).long() inds = gt_boxes.new_zeros(num_max_objs).long()
mask = gt_boxes.new_zeros(num_max_objs).long() mask = gt_boxes.new_zeros(num_max_objs).long()
ret_boxes_src = gt_boxes.new_zeros(num_max_objs, gt_boxes.shape[-1])
ret_boxes_src[:gt_boxes.shape[0]] = gt_boxes
x, y, z = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2] x, y, z = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2]
coord_x = (x - self.point_cloud_range[0]) / self.voxel_size[0] / feature_map_stride coord_x = (x - self.point_cloud_range[0]) / self.voxel_size[0] / feature_map_stride
...@@ -154,7 +159,7 @@ class CenterHead(nn.Module): ...@@ -154,7 +159,7 @@ class CenterHead(nn.Module):
if gt_boxes.shape[1] > 8: if gt_boxes.shape[1] > 8:
ret_boxes[k, 8:] = gt_boxes[k, 7:-1] ret_boxes[k, 8:] = gt_boxes[k, 7:-1]
return heatmap, ret_boxes, inds, mask return heatmap, ret_boxes, inds, mask, ret_boxes_src
def assign_targets(self, gt_boxes, feature_map_size=None, **kwargs): def assign_targets(self, gt_boxes, feature_map_size=None, **kwargs):
""" """
...@@ -176,12 +181,13 @@ class CenterHead(nn.Module): ...@@ -176,12 +181,13 @@ class CenterHead(nn.Module):
'target_boxes': [], 'target_boxes': [],
'inds': [], 'inds': [],
'masks': [], 'masks': [],
'heatmap_masks': [] 'heatmap_masks': [],
'target_boxes_src': [],
} }
all_names = np.array(['bg', *self.class_names]) all_names = np.array(['bg', *self.class_names])
for idx, cur_class_names in enumerate(self.class_names_each_head): for idx, cur_class_names in enumerate(self.class_names_each_head):
heatmap_list, target_boxes_list, inds_list, masks_list = [], [], [], [] heatmap_list, target_boxes_list, inds_list, masks_list, target_boxes_src_list = [], [], [], [], []
for bs_idx in range(batch_size): for bs_idx in range(batch_size):
cur_gt_boxes = gt_boxes[bs_idx] cur_gt_boxes = gt_boxes[bs_idx]
gt_class_names = all_names[cur_gt_boxes[:, -1].cpu().long().numpy()] gt_class_names = all_names[cur_gt_boxes[:, -1].cpu().long().numpy()]
...@@ -200,7 +206,7 @@ class CenterHead(nn.Module): ...@@ -200,7 +206,7 @@ class CenterHead(nn.Module):
else: else:
gt_boxes_single_head = torch.cat(gt_boxes_single_head, dim=0) gt_boxes_single_head = torch.cat(gt_boxes_single_head, dim=0)
heatmap, ret_boxes, inds, mask = self.assign_target_of_single_head( heatmap, ret_boxes, inds, mask, ret_boxes_src = self.assign_target_of_single_head(
num_classes=len(cur_class_names), gt_boxes=gt_boxes_single_head.cpu(), num_classes=len(cur_class_names), gt_boxes=gt_boxes_single_head.cpu(),
feature_map_size=feature_map_size, feature_map_stride=target_assigner_cfg.FEATURE_MAP_STRIDE, feature_map_size=feature_map_size, feature_map_stride=target_assigner_cfg.FEATURE_MAP_STRIDE,
num_max_objs=target_assigner_cfg.NUM_MAX_OBJS, num_max_objs=target_assigner_cfg.NUM_MAX_OBJS,
...@@ -211,11 +217,13 @@ class CenterHead(nn.Module): ...@@ -211,11 +217,13 @@ class CenterHead(nn.Module):
target_boxes_list.append(ret_boxes.to(gt_boxes_single_head.device)) target_boxes_list.append(ret_boxes.to(gt_boxes_single_head.device))
inds_list.append(inds.to(gt_boxes_single_head.device)) inds_list.append(inds.to(gt_boxes_single_head.device))
masks_list.append(mask.to(gt_boxes_single_head.device)) masks_list.append(mask.to(gt_boxes_single_head.device))
target_boxes_src_list.append(ret_boxes_src.to(gt_boxes_single_head.device))
ret_dict['heatmaps'].append(torch.stack(heatmap_list, dim=0)) ret_dict['heatmaps'].append(torch.stack(heatmap_list, dim=0))
ret_dict['target_boxes'].append(torch.stack(target_boxes_list, dim=0)) ret_dict['target_boxes'].append(torch.stack(target_boxes_list, dim=0))
ret_dict['inds'].append(torch.stack(inds_list, dim=0)) ret_dict['inds'].append(torch.stack(inds_list, dim=0))
ret_dict['masks'].append(torch.stack(masks_list, dim=0)) ret_dict['masks'].append(torch.stack(masks_list, dim=0))
ret_dict['target_boxes_src'].append(torch.stack(target_boxes_src_list, dim=0))
return ret_dict return ret_dict
def sigmoid(self, x): def sigmoid(self, x):
...@@ -247,6 +255,42 @@ class CenterHead(nn.Module): ...@@ -247,6 +255,42 @@ class CenterHead(nn.Module):
tb_dict['hm_loss_head_%d' % idx] = hm_loss.item() tb_dict['hm_loss_head_%d' % idx] = hm_loss.item()
tb_dict['loc_loss_head_%d' % idx] = loc_loss.item() tb_dict['loc_loss_head_%d' % idx] = loc_loss.item()
if 'iou' in pred_dict or self.model_cfg.get('IOU_REG_LOSS', False):
batch_box_preds = centernet_utils.decode_bbox_from_pred_dicts(
pred_dict=pred_dict,
point_cloud_range=self.point_cloud_range, voxel_size=self.voxel_size,
feature_map_stride=self.feature_map_stride
) # (B, H, W, 7 or 9)
if 'iou' in pred_dict:
batch_box_preds_for_iou = batch_box_preds.permute(0, 3, 1, 2) # (B, 7 or 9, H, W)
iou_loss = loss_utils.calculate_iou_loss_centerhead(
iou_preds=pred_dict['iou'],
batch_box_preds=batch_box_preds_for_iou.clone().detach(),
mask=target_dicts['masks'][idx],
ind=target_dicts['inds'][idx], gt_boxes=target_dicts['target_boxes_src'][idx]
)
loss += iou_loss
tb_dict['iou_loss_head_%d' % idx] = iou_loss.item()
if self.model_cfg.get('IOU_REG_LOSS', False):
iou_reg_loss = loss_utils.calculate_iou_reg_loss_centerhead(
batch_box_preds=batch_box_preds_for_iou,
mask=target_dicts['masks'][idx],
ind=target_dicts['inds'][idx], gt_boxes=target_dicts['target_boxes_src'][idx]
)
if target_dicts['masks'][idx].sum().item() != 0:
iou_reg_loss = iou_reg_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['loc_weight']
loss += iou_reg_loss
tb_dict['iou_reg_loss_head_%d' % idx] = iou_reg_loss.item()
else:
loss += (batch_box_preds_for_iou * 0.).sum()
tb_dict['iou_reg_loss_head_%d' % idx] = (batch_box_preds_for_iou * 0.).sum()
tb_dict['rpn_loss'] = loss.item() tb_dict['rpn_loss'] = loss.item()
return loss, tb_dict return loss, tb_dict
...@@ -268,9 +312,11 @@ class CenterHead(nn.Module): ...@@ -268,9 +312,11 @@ class CenterHead(nn.Module):
batch_rot_sin = pred_dict['rot'][:, 1].unsqueeze(dim=1) batch_rot_sin = pred_dict['rot'][:, 1].unsqueeze(dim=1)
batch_vel = pred_dict['vel'] if 'vel' in self.separate_head_cfg.HEAD_ORDER else None batch_vel = pred_dict['vel'] if 'vel' in self.separate_head_cfg.HEAD_ORDER else None
batch_iou = (pred_dict['iou'] + 1) * 0.5 if 'iou' in pred_dict else None
final_pred_dicts = centernet_utils.decode_bbox_from_heatmap( final_pred_dicts = centernet_utils.decode_bbox_from_heatmap(
heatmap=batch_hm, rot_cos=batch_rot_cos, rot_sin=batch_rot_sin, heatmap=batch_hm, rot_cos=batch_rot_cos, rot_sin=batch_rot_sin,
center=batch_center, center_z=batch_center_z, dim=batch_dim, vel=batch_vel, center=batch_center, center_z=batch_center_z, dim=batch_dim, vel=batch_vel, iou=batch_iou,
point_cloud_range=self.point_cloud_range, voxel_size=self.voxel_size, point_cloud_range=self.point_cloud_range, voxel_size=self.voxel_size,
feature_map_stride=self.feature_map_stride, feature_map_stride=self.feature_map_stride,
K=post_process_cfg.MAX_OBJ_PER_SAMPLE, K=post_process_cfg.MAX_OBJ_PER_SAMPLE,
...@@ -281,13 +327,28 @@ class CenterHead(nn.Module): ...@@ -281,13 +327,28 @@ class CenterHead(nn.Module):
for k, final_dict in enumerate(final_pred_dicts): for k, final_dict in enumerate(final_pred_dicts):
final_dict['pred_labels'] = self.class_id_mapping_each_head[idx][final_dict['pred_labels'].long()] final_dict['pred_labels'] = self.class_id_mapping_each_head[idx][final_dict['pred_labels'].long()]
if post_process_cfg.NMS_CONFIG.NMS_TYPE != 'circle_nms':
if post_process_cfg.get('USE_IOU_TO_RECTIFY_SCORE', False) and 'pred_iou' in final_dict:
pred_iou = torch.clamp(final_dict['pred_iou'], min=0, max=1.0)
IOU_RECTIFIER = final_dict['pred_scores'].new_tensor(post_process_cfg.IOU_RECTIFIER)
final_dict['pred_scores'] = torch.pow(final_dict['pred_scores'], 1 - IOU_RECTIFIER[final_dict['pred_labels']]) * torch.pow(pred_iou, IOU_RECTIFIER[final_dict['pred_labels']])
if post_process_cfg.NMS_CONFIG.NMS_TYPE not in ['circle_nms', 'class_specific_nms']:
selected, selected_scores = model_nms_utils.class_agnostic_nms( selected, selected_scores = model_nms_utils.class_agnostic_nms(
box_scores=final_dict['pred_scores'], box_preds=final_dict['pred_boxes'], box_scores=final_dict['pred_scores'], box_preds=final_dict['pred_boxes'],
nms_config=post_process_cfg.NMS_CONFIG, nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=None score_thresh=None
) )
elif post_process_cfg.NMS_CONFIG.NMS_TYPE == 'class_specific_nms':
selected, selected_scores = model_nms_utils.class_specific_nms(
box_scores=final_dict['pred_scores'], box_preds=final_dict['pred_boxes'],
box_labels=final_dict['pred_labels'], nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=post_process_cfg.NMS_CONFIG.get('SCORE_THRESH', None)
)
elif post_process_cfg.NMS_CONFIG.NMS_TYPE == 'circle_nms':
raise NotImplementedError
final_dict['pred_boxes'] = final_dict['pred_boxes'][selected] final_dict['pred_boxes'] = final_dict['pred_boxes'][selected]
final_dict['pred_scores'] = selected_scores final_dict['pred_scores'] = selected_scores
final_dict['pred_labels'] = final_dict['pred_labels'][selected] final_dict['pred_labels'] = final_dict['pred_labels'][selected]
......
...@@ -171,7 +171,7 @@ def _topk(scores, K=40): ...@@ -171,7 +171,7 @@ def _topk(scores, K=40):
def decode_bbox_from_heatmap(heatmap, rot_cos, rot_sin, center, center_z, dim, def decode_bbox_from_heatmap(heatmap, rot_cos, rot_sin, center, center_z, dim,
point_cloud_range=None, voxel_size=None, feature_map_stride=None, vel=None, K=100, point_cloud_range=None, voxel_size=None, feature_map_stride=None, vel=None, iou=None, K=100,
circle_nms=False, score_thresh=None, post_center_limit_range=None): circle_nms=False, score_thresh=None, post_center_limit_range=None):
batch_size, num_class, _, _ = heatmap.size() batch_size, num_class, _, _ = heatmap.size()
...@@ -199,6 +199,9 @@ def decode_bbox_from_heatmap(heatmap, rot_cos, rot_sin, center, center_z, dim, ...@@ -199,6 +199,9 @@ def decode_bbox_from_heatmap(heatmap, rot_cos, rot_sin, center, center_z, dim,
vel = _transpose_and_gather_feat(vel, inds).view(batch_size, K, 2) vel = _transpose_and_gather_feat(vel, inds).view(batch_size, K, 2)
box_part_list.append(vel) box_part_list.append(vel)
if iou is not None:
iou = _transpose_and_gather_feat(iou, inds).view(batch_size, K)
final_box_preds = torch.cat((box_part_list), dim=-1) final_box_preds = torch.cat((box_part_list), dim=-1)
final_scores = scores.view(batch_size, K) final_scores = scores.view(batch_size, K)
final_class_ids = class_ids.view(batch_size, K) final_class_ids = class_ids.view(batch_size, K)
...@@ -232,6 +235,9 @@ def decode_bbox_from_heatmap(heatmap, rot_cos, rot_sin, center, center_z, dim, ...@@ -232,6 +235,9 @@ def decode_bbox_from_heatmap(heatmap, rot_cos, rot_sin, center, center_z, dim,
'pred_scores': cur_scores, 'pred_scores': cur_scores,
'pred_labels': cur_labels 'pred_labels': cur_labels
}) })
if iou is not None:
ret_pred_dicts[-1]['pred_iou'] = iou[k, cur_mask]
return ret_pred_dicts return ret_pred_dicts
def _topk_1d(scores, batch_size, batch_idx, obj, K=40, nuscenes=False): def _topk_1d(scores, batch_size, batch_idx, obj, K=40, nuscenes=False):
...@@ -346,3 +352,34 @@ def decode_bbox_from_voxels_nuscenes(batch_size, indices, obj, rot_cos, rot_sin, ...@@ -346,3 +352,34 @@ def decode_bbox_from_voxels_nuscenes(batch_size, indices, obj, rot_cos, rot_sin,
'add_features': cur_add_features, 'add_features': cur_add_features,
}) })
return ret_pred_dicts return ret_pred_dicts
def decode_bbox_from_pred_dicts(pred_dict, point_cloud_range=None, voxel_size=None, feature_map_stride=None):
    """
    Decode dense (per-pixel) 3D boxes from a CenterHead prediction dict.

    Args:
        pred_dict: dict of BEV maps with keys 'center' (B,2,H,W), 'center_z' (B,1,H,W),
            'dim' (B,3,H,W, log scale), 'rot' (B,2,H,W, cos/sin) and optionally 'vel' (B,2,H,W).
        point_cloud_range: (x_min, y_min, ...) used to shift grid coords to metric space.
        voxel_size: (vx, vy, ...) metric size of one voxel.
        feature_map_stride: downsample factor between voxel grid and feature map.
    Returns:
        box_preds: (B, H, W, 7 or 9) decoded boxes (x, y, z, dx, dy, dz, heading[, vx, vy]).
    """
    batch_size, _, H, W = pred_dict['center'].shape

    def _to_rows(t, channels):
        # (B, C, H, W) -> (B, H*W, C)
        return t.permute(0, 2, 3, 1).contiguous().view(batch_size, H * W, channels)

    center = _to_rows(pred_dict['center'], 2)
    center_z = _to_rows(pred_dict['center_z'], 1)
    dim = _to_rows(pred_dict['dim'].exp(), 3)                       # sizes are regressed in log space
    rot_cos = _to_rows(pred_dict['rot'][:, 0].unsqueeze(dim=1), 1)
    rot_sin = _to_rows(pred_dict['rot'][:, 1].unsqueeze(dim=1), 1)
    vel = _to_rows(pred_dict['vel'], 2) if 'vel' in pred_dict.keys() else None

    angle = torch.atan2(rot_sin, rot_cos)  # (B, H*W, 1)

    ys, xs = torch.meshgrid([torch.arange(0, H, device=center.device, dtype=center.dtype),
                             torch.arange(0, W, device=center.device, dtype=center.dtype)])
    xs = xs.reshape(1, H * W, 1).repeat(batch_size, 1, 1) + center[:, :, 0:1]
    ys = ys.reshape(1, H * W, 1).repeat(batch_size, 1, 1) + center[:, :, 1:2]

    # grid coords -> metric coords
    xs = xs * feature_map_stride * voxel_size[0] + point_cloud_range[0]
    ys = ys * feature_map_stride * voxel_size[1] + point_cloud_range[1]

    parts = [xs, ys, center_z, dim, angle]
    if vel is not None:
        parts.append(vel)

    return torch.cat(parts, dim=-1).view(batch_size, H, W, -1)
import torch
import torch.nn as nn
import numpy as np
from pcdet.ops.ingroup_inds.ingroup_inds_op import ingroup_inds
get_inner_win_inds_cuda = ingroup_inds
class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.

    Maps raw coordinates to an embedding via a small MLP: Linear -> BN -> ReLU -> Linear.
    """

    def __init__(self, input_channel, num_pos_feats):
        super().__init__()
        # keep the attribute name for checkpoint/state-dict compatibility
        self.position_embedding_head = nn.Sequential(
            nn.Linear(input_channel, num_pos_feats),
            nn.BatchNorm1d(num_pos_feats),
            nn.ReLU(inplace=True),
            nn.Linear(num_pos_feats, num_pos_feats),
        )

    def forward(self, xyz):
        # (N, input_channel) -> (N, num_pos_feats)
        return self.position_embedding_head(xyz)
@torch.no_grad()
def get_window_coors(coors, sparse_shape, window_shape, do_shift, shift_list=None, return_win_coors=False):
    '''
    Assign each voxel to a (possibly shifted) window and compute its coords inside that window.

    Args:
        coors (Tensor[int]): (N, 4) voxel coords, each row (batch_idx, z, y, x).
        sparse_shape: (x, y, z) extent of the sparse voxel grid.
        window_shape: (win_x, win_y) for 2D windows or (win_x, win_y, win_z) for 3D.
        do_shift (bool): use the shifted window partition.
        shift_list: optional explicit (shift_x, shift_y, shift_z); overrides the defaults.
        return_win_coors (bool): also return per-voxel window coordinates.
    Returns:
        batch_win_inds (Tensor): (N,) flat window index, unique across the batch.
        coors_in_win (Tensor): (N, 3) coords inside the window, ordered (z, y, x).
        batch_win_coords (Tensor): (N, 3) window coords (z, y, x), only if return_win_coors.
    '''
    if len(window_shape) == 2:
        win_shape_x, win_shape_y = window_shape
        win_shape_z = sparse_shape[-1]
    else:
        win_shape_x, win_shape_y, win_shape_z = window_shape

    sparse_shape_x, sparse_shape_y, sparse_shape_z = sparse_shape
    assert sparse_shape_z < sparse_shape_x, 'Usually holds... in case of wrong order'

    # plus one along every axis to meet the needs of shift
    max_num_win_x = int(np.ceil((sparse_shape_x / win_shape_x)) + 1)
    max_num_win_y = int(np.ceil((sparse_shape_y / win_shape_y)) + 1)
    max_num_win_z = int(np.ceil((sparse_shape_z / win_shape_z)) + 1)
    max_num_win_per_sample = max_num_win_x * max_num_win_y * max_num_win_z

    # an explicit shift_list overrides both defaults; otherwise half a window when
    # shifting, a full window (i.e. no effective shift after modulo) when not
    if shift_list is not None:
        shift_x, shift_y, shift_z = shift_list[0], shift_list[1], shift_list[2]
    elif do_shift:
        shift_x, shift_y, shift_z = win_shape_x // 2, win_shape_y // 2, win_shape_z // 2
    else:
        shift_x, shift_y, shift_z = win_shape_x, win_shape_y, win_shape_z

    # compatibility between 2D window and 3D window
    if sparse_shape_z == win_shape_z:
        shift_z = 0

    shifted_x = coors[:, 3] + shift_x
    shifted_y = coors[:, 2] + shift_y
    shifted_z = coors[:, 1] + shift_z

    win_coors_x = shifted_x // win_shape_x
    win_coors_y = shifted_y // win_shape_y
    win_coors_z = shifted_z // win_shape_z

    if len(window_shape) == 2:
        assert (win_coors_z == 0).all()

    batch_win_inds = coors[:, 0] * max_num_win_per_sample + \
                     win_coors_x * max_num_win_y * max_num_win_z + \
                     win_coors_y * max_num_win_z + \
                     win_coors_z

    coors_in_win = torch.stack([shifted_z % win_shape_z,
                                shifted_y % win_shape_y,
                                shifted_x % win_shape_x], dim=-1)

    if return_win_coors:
        batch_win_coords = torch.stack([win_coors_z, win_coors_y, win_coors_x], dim=-1)
        return batch_win_inds, coors_in_win, batch_win_coords

    return batch_win_inds, coors_in_win
def get_pooling_index(coors, sparse_shape, window_shape):
    '''
    Compute, for each voxel, the pooling region it belongs to and its slot inside that region.

    Args:
        coors (Tensor[int]): (N, 4) voxel coords, each row (batch_idx, z, y, x).
        sparse_shape: (x, y, z) extent of the sparse voxel grid.
        window_shape: (win_x, win_y, win_z) size of one pooling region.
    Returns:
        batch_win_inds (Tensor): (N,) flat pooling-region index, unique across the batch.
        coors_in_win (Tensor): (N, 3) coords inside the region, ordered (z, y, x).
        index_in_win (Tensor): (N,) flattened slot index inside the region.
        batch_win_coords (Tensor): (N, 4) region coords (batch_idx, z, y, x).
    '''
    win_shape_x, win_shape_y, win_shape_z = window_shape
    sparse_shape_x, sparse_shape_y, sparse_shape_z = sparse_shape

    num_win_x = int(np.ceil((sparse_shape_x / win_shape_x)))
    num_win_y = int(np.ceil((sparse_shape_y / win_shape_y)))
    num_win_z = int(np.ceil((sparse_shape_z / win_shape_z)))
    wins_per_sample = num_win_x * num_win_y * num_win_z

    coors_x, coors_y, coors_z = coors[:, 3], coors[:, 2], coors[:, 1]

    win_coors_x = coors_x // win_shape_x
    win_coors_y = coors_y // win_shape_y
    win_coors_z = coors_z // win_shape_z

    batch_win_inds = coors[:, 0] * wins_per_sample + \
                     win_coors_x * num_win_y * num_win_z + \
                     win_coors_y * num_win_z + \
                     win_coors_z

    in_x = coors_x % win_shape_x
    in_y = coors_y % win_shape_y
    in_z = coors_z % win_shape_z
    coors_in_win = torch.stack([in_z, in_y, in_x], dim=-1)
    # x-major flattening of the (x, y, z) slot inside the region
    index_in_win = in_x * win_shape_y * win_shape_z + in_y * win_shape_z + in_z

    batch_win_coords = torch.stack([coors[:, 0], win_coors_z, win_coors_y, win_coors_x], dim=-1)
    return batch_win_inds, coors_in_win, index_in_win, batch_win_coords
def get_continous_inds(setnum_per_win):
    '''
    Args:
        setnum_per_win (Tensor[int]): Number of sets assigned to each window with shape (win_num).
    Returns:
        set_win_inds (Tensor[int]): Window indexs of each set with shape (set_num).
        set_inds_in_win (Tensor[int]): Set indexs inner window with shape (set_num).
    Examples:
        setnum_per_win = torch.tensor([1, 2, 1, 3])
        set_win_inds, set_inds_in_win = get_continous_inds(setnum_per_win)
        # set_win_inds    = tensor([0, 1, 1, 2, 3, 3, 3])
        # set_inds_in_win = tensor([0, 0, 1, 0, 0, 1, 2])
    '''
    device = setnum_per_win.device
    win_num = setnum_per_win.shape[0]
    set_num = int(setnum_per_win.sum())

    # Window id of each set: window k appears setnum_per_win[k] times in a row.
    set_win_inds = torch.repeat_interleave(
        torch.arange(win_num, device=device), setnum_per_win)

    # Global index of each window's first set (exclusive prefix sum); the rank
    # of a set inside its window is its global index minus that start offset.
    win_starts = torch.cumsum(setnum_per_win, dim=0) - setnum_per_win
    set_inds_in_win = torch.arange(set_num, device=device) - win_starts[set_win_inds]

    return set_win_inds, set_inds_in_win
\ No newline at end of file
...@@ -64,3 +64,44 @@ def multi_classes_nms(cls_scores, box_preds, nms_config, score_thresh=None): ...@@ -64,3 +64,44 @@ def multi_classes_nms(cls_scores, box_preds, nms_config, score_thresh=None):
pred_boxes = torch.cat(pred_boxes, dim=0) pred_boxes = torch.cat(pred_boxes, dim=0)
return pred_scores, pred_labels, pred_boxes return pred_scores, pred_labels, pred_boxes
def class_specific_nms(box_scores, box_preds, box_labels, nms_config, score_thresh=None):
    """
    Run NMS separately for each class, with per-class thresholds.

    Args:
        box_scores: (N,) confidence score of each box
        box_preds: (N, 7 + C) boxes [x, y, z, dx, dy, dz, heading, ...]
        box_labels: (N,) integer class label of each box
        nms_config: config providing the per-class lists NMS_THRESH,
            NMS_PRE_MAXSIZE and NMS_POST_MAXSIZE
        score_thresh: optional score filter applied before NMS; a single
            float used for all classes, or a list with one value per class
    Returns:
        selected: (M,) indices into the input arrays of the kept boxes
        selected_scores: (M,) scores of the kept boxes
    """
    selected = []
    for k in range(len(nms_config.NMS_THRESH)):
        curr_mask = box_labels == k
        # Optionally drop low-score boxes of this class before running NMS.
        if isinstance(score_thresh, float):
            curr_mask = curr_mask & (box_scores > score_thresh)
        elif isinstance(score_thresh, list):
            curr_mask = curr_mask & (box_scores > score_thresh[k])
        curr_idx = torch.nonzero(curr_mask)[:, 0]
        curr_box_scores = box_scores[curr_mask]
        cur_box_preds = box_preds[curr_mask]

        if curr_box_scores.shape[0] > 0:
            # Direct attribute access instead of the original
            # getattr(iou3d_nms_utils, 'nms_gpu') indirection.
            keep_idx, _ = iou3d_nms_utils.nms_gpu(
                cur_box_preds, curr_box_scores,
                thresh=nms_config.NMS_THRESH[k],
                pre_maxsize=nms_config.NMS_PRE_MAXSIZE[k],
                post_max_size=nms_config.NMS_POST_MAXSIZE[k]
            )
            # Map indices from the per-class subset back to the full array.
            selected.append(curr_idx[keep_idx])

    if len(selected) > 0:
        selected = torch.cat(selected)
    else:
        # Nothing survived: return an empty index tensor instead of a Python
        # list so the return type is consistent for callers.
        selected = box_scores.new_zeros(0, dtype=torch.long)
    return selected, box_scores[selected]
import torch

# Compiled CUDA extension that assigns each element its running index within
# its group. Keep the alias assignment INSIDE the try-block: the original code
# assigned `ingroup_indices = ingroup_inds_cuda` unconditionally after the
# except branch, which raised NameError whenever the import had failed and
# defeated the None fallback below.
try:
    from . import ingroup_inds_cuda
    ingroup_indices = ingroup_inds_cuda
except ImportError:
    ingroup_indices = None
    print('Can not import ingroup indices')

from torch.autograd import Function
class IngroupIndicesFunction(Function):
    """Autograd wrapper around the CUDA in-group index op (non-differentiable)."""

    @staticmethod
    def forward(ctx, group_inds):
        # Allocate the output pre-filled with -1; the CUDA op overwrites each
        # entry with the element's running rank inside its group.
        out_inds = torch.full_like(group_inds, -1)
        ingroup_indices.forward(group_inds, out_inds)
        ctx.mark_non_differentiable(out_inds)
        return out_inds

    @staticmethod
    def backward(ctx, g):
        # Integer indices carry no gradient.
        return None


ingroup_inds = IngroupIndicesFunction.apply
\ No newline at end of file
#pragma once

#include <stdio.h>
// Needed for cudaError_t / cudaSuccess / cudaGetErrorString used below; the
// original header relied on the including .cu file to provide them.
#include <cuda_runtime.h>

// Abort the process with a readable file/line report when a CUDA runtime call
// fails. Wrap every runtime API call: CHECK_CALL(cudaMalloc(...)).
#define CHECK_CALL(call) \
do \
{ \
    const cudaError_t error_code = call; \
    if (error_code != cudaSuccess) \
    { \
        printf("CUDA Error:\n"); \
        printf("    File: %s\n", __FILE__); \
        printf("    Line: %d\n", __LINE__); \
        printf("    Error code: %d\n", error_code); \
        printf("    Error text: %s\n", \
            cudaGetErrorString(error_code)); \
        exit(1); \
    } \
} while (0)
// Host-side torch-extension binding for the in-group index CUDA op.
#include <assert.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>

// Validate that a tensor lives on the GPU.
#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
// Validate that a tensor is contiguous in memory.
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
// Combined check applied to every tensor argument.
#define CHECK_INPUT(x) \
  CHECK_CUDA(x); \
  CHECK_CONTIGUOUS(x)

// Implemented in the companion .cu file: launches the kernel that writes, for
// each element, its running index within its group.
void ingroup_inds_launcher(
  const long *group_inds_data,
  long *out_inds_data,
  int N,
  int max_group_id
);

// Forward declaration of the binding defined below (redundant with the
// definition that immediately follows, kept as-is).
void ingroup_inds_gpu(
  at::Tensor group_inds,
  at::Tensor out_inds
);
// Compute, for every element of `group_inds` (int64, CUDA, contiguous), a
// running index within its group and write it into `out_inds` (same shape).
// The per-group ordering depends on GPU thread scheduling (atomicAdd in the
// kernel), so only the set of indices {0..count-1} per group is guaranteed.
void ingroup_inds_gpu(
  at::Tensor group_inds,
  at::Tensor out_inds
) {

  CHECK_INPUT(group_inds);
  CHECK_INPUT(out_inds);
  TORCH_CHECK(out_inds.size(0) == group_inds.size(0),
              "out_inds must have the same length as group_inds");

  int N = group_inds.size(0);
  // Tensor::max() throws on an empty tensor; there is nothing to do anyway.
  if (N == 0) return;

  int max_group_id = group_inds.max().item().toLong();
  const long *group_inds_data = group_inds.data_ptr<long>();
  long *out_inds_data = out_inds.data_ptr<long>();

  ingroup_inds_launcher(
    group_inds_data,
    out_inds_data,
    N,
    max_group_id
  );
}
// Expose the op to Python as `ingroup_inds_cuda.forward(group_inds, out_inds)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &ingroup_inds_gpu, "cuda version of get_inner_win_inds of SST");
}
\ No newline at end of file
#include <assert.h>
#include <vector>
#include <math.h>
#include <stdio.h>
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include <torch/types.h>
#include "cuda_fp16.h"
// #include "error.cuh"

// Abort with file/line context when a CUDA runtime call fails
// (local copy of the CHECK_CALL macro from error.cuh).
#define CHECK_CALL(call) \
do \
{ \
    const cudaError_t error_code = call; \
    if (error_code != cudaSuccess) \
    { \
        printf("CUDA Error:\n"); \
        printf("    File: %s\n", __FILE__); \
        printf("    Line: %d\n", __LINE__); \
        printf("    Error code: %d\n", error_code); \
        printf("    Error text: %s\n", \
            cudaGetErrorString(error_code)); \
        exit(1); \
    } \
} while (0)

// Launch configuration: one thread per input element.
#define THREADS_PER_BLOCK 256
// Ceiling division used to size the grid.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

// #define DEBUG
// #define ASSERTION
// One thread per element: each thread atomically claims the next free rank
// inside its element's group and records it as that element's in-group index.
__global__ void ingroup_inds_kernel(
  const long *group_inds,
  long *out_inds,
  int *ingroup_counter,
  int N
) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < N) {
    const long gid = group_inds[tid];
    out_inds[tid] = atomicAdd(&ingroup_counter[gid], 1);
  }
}
// Allocate a zeroed per-group counter array of size (max_group_id + 1),
// launch one thread per element to claim in-group ranks, then free the
// scratch counters. `group_inds` / `out_inds` are device pointers of length N.
void ingroup_inds_launcher(
  const long *group_inds,
  long *out_inds,
  int N,
  int max_group_id
) {

  int *ingroup_counter = NULL;
  CHECK_CALL(cudaMalloc(&ingroup_counter, (max_group_id + 1) * sizeof(int)));
  CHECK_CALL(cudaMemset(ingroup_counter, 0, (max_group_id + 1) * sizeof(int)));

  dim3 blocks(DIVUP(N, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  ingroup_inds_kernel<<<blocks, threads>>>(
    group_inds,
    out_inds,
    ingroup_counter,
    N
  );
  // Surface launch errors (e.g. invalid configuration) in every build, not
  // only when DEBUG is defined as in the original.
  CHECK_CALL(cudaGetLastError());

  cudaFree(ingroup_counter);

#ifdef DEBUG
  CHECK_CALL(cudaDeviceSynchronize());
#endif
}
\ No newline at end of file
...@@ -150,3 +150,40 @@ def nms_normal_gpu(boxes, scores, thresh, **kwargs): ...@@ -150,3 +150,40 @@ def nms_normal_gpu(boxes, scores, thresh, **kwargs):
keep = torch.LongTensor(boxes.size(0)) keep = torch.LongTensor(boxes.size(0))
num_out = iou3d_nms_cuda.nms_normal_gpu(boxes, keep, thresh) num_out = iou3d_nms_cuda.nms_normal_gpu(boxes, keep, thresh)
return order[keep[:num_out].cuda()].contiguous(), None return order[keep[:num_out].cuda()].contiguous(), None
def paired_boxes_iou3d_gpu(boxes_a, boxes_b):
    """
    Compute 3D IoU element-wise between paired boxes (i-th of boxes_a vs
    i-th of boxes_b), without building an N x N matrix.

    Args:
        boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading], CUDA tensor
        boxes_b: (N, 7) [x, y, z, dx, dy, dz, heading], CUDA tensor
    Returns:
        ans_iou: (N,)
    """
    assert boxes_a.shape[0] == boxes_b.shape[0]
    assert boxes_a.shape[1] == boxes_b.shape[1] == 7

    # height overlap
    boxes_a_height_max = (boxes_a[:, 2] + boxes_a[:, 5] / 2).view(-1, 1)
    boxes_a_height_min = (boxes_a[:, 2] - boxes_a[:, 5] / 2).view(-1, 1)
    boxes_b_height_max = (boxes_b[:, 2] + boxes_b[:, 5] / 2).view(-1, 1)
    boxes_b_height_min = (boxes_b[:, 2] - boxes_b[:, 5] / 2).view(-1, 1)

    # bev overlap, filled in by the CUDA extension. new_zeros keeps the
    # dtype/device of the input instead of the deprecated
    # torch.cuda.FloatTensor constructor (which also ignored the input's
    # device and always used the current one).
    overlaps_bev = boxes_a.new_zeros((boxes_a.shape[0], 1))  # (N, 1)
    iou3d_nms_cuda.paired_boxes_overlap_bev_gpu(boxes_a.contiguous(), boxes_b.contiguous(), overlaps_bev)

    max_of_min = torch.max(boxes_a_height_min, boxes_b_height_min)
    min_of_max = torch.min(boxes_a_height_max, boxes_b_height_max)
    overlaps_h = torch.clamp(min_of_max - max_of_min, min=0)

    # 3d iou
    overlaps_3d = overlaps_bev * overlaps_h

    vol_a = (boxes_a[:, 3] * boxes_a[:, 4] * boxes_a[:, 5]).view(-1, 1)
    vol_b = (boxes_b[:, 3] * boxes_b[:, 4] * boxes_b[:, 5]).view(-1, 1)

    # Clamp the union to avoid division by zero for degenerate boxes.
    iou3d = overlaps_3d / torch.clamp(vol_a + vol_b - overlaps_3d, min=1e-6)

    return iou3d.view(-1)
\ No newline at end of file
...@@ -41,6 +41,7 @@ const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; ...@@ -41,6 +41,7 @@ const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
void boxesalignedoverlapLauncher(const int num_box, const float *boxes_a, const float *boxes_b, float *ans_overlap); void boxesalignedoverlapLauncher(const int num_box, const float *boxes_a, const float *boxes_b, float *ans_overlap);
void boxesoverlapLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_overlap); void boxesoverlapLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_overlap);
void PairedBoxesOverlapLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_overlap);
void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_iou); void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_iou);
void nmsLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh); void nmsLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh);
void nmsNormalLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh); void nmsNormalLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh);
...@@ -90,6 +91,29 @@ int boxes_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans ...@@ -90,6 +91,29 @@ int boxes_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans
return 1; return 1;
} }
int paired_boxes_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_overlap){
    // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
    // params boxes_b: (N, 7) [x, y, z, dx, dy, dz, heading]
    // params ans_overlap: (N, 1) element-wise BEV overlap, written by the launcher
    CHECK_INPUT(boxes_a);
    CHECK_INPUT(boxes_b);
    CHECK_INPUT(ans_overlap);

    int num_a = boxes_a.size(0);
    int num_b = boxes_b.size(0);
    // assert() disappears in NDEBUG builds; TORCH_CHECK validates always and
    // matches the error style of CHECK_INPUT above.
    TORCH_CHECK(num_a == num_b,
                "paired_boxes_overlap_bev_gpu: boxes_a and boxes_b must contain the same number of boxes");

    // data_ptr<T>() replaces the deprecated data<T>() accessor.
    const float * boxes_a_data = boxes_a.data_ptr<float>();
    const float * boxes_b_data = boxes_b.data_ptr<float>();
    float * ans_overlap_data = ans_overlap.data_ptr<float>();

    PairedBoxesOverlapLauncher(num_a, boxes_a_data, num_b, boxes_b_data, ans_overlap_data);

    return 1;
}
int boxes_iou_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_iou){ int boxes_iou_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_iou){
// params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading] // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading] // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment