Commit 635f1e94 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

all model codes refactoring for training/testing, support multiple 3D...

all model codes refactoring for training/testing, support multiple 3D detectors (SECOND, PartA2-Net (official release), PV-RCNN (official release))
parent 3fdecc87
import torch
import numpy as np
from collections import namedtuple
from .detectors import build_detector
def build_network(model_cfg, num_class, dataset):
model = build_detector(
model_cfg=model_cfg, num_class=num_class, dataset=dataset
)
return model
def model_fn_decorator():
ModelReturn = namedtuple('ModelReturn', ['loss', 'tb_dict', 'disp_dict'])
def model_func(model, batch_dict):
for key, val in batch_dict.items():
if not isinstance(val, np.ndarray):
continue
if key in ['frame_id']:
continue
batch_dict[key] = torch.from_numpy(val).float().cuda()
ret_dict, tb_dict, disp_dict = model(batch_dict)
loss = ret_dict['loss'].mean()
if hasattr(model, 'update_global_step'):
model.update_global_step()
else:
model.module.update_global_step()
return ModelReturn(loss, tb_dict, disp_dict)
return model_func
from .base_bev_backbone import BaseBEVBackbone
__all__ = {
'BaseBEVBackbone': BaseBEVBackbone
}
\ No newline at end of file
import torch
import torch.nn as nn
class BaseBEVBackbone(nn.Module):
def __init__(self, model_cfg, input_channels):
super().__init__()
self.model_cfg = model_cfg
assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == \
len(self.model_cfg.NUM_FILTERS) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
layer_nums = self.model_cfg.LAYER_NUMS
layer_strides = self.model_cfg.LAYER_STRIDES
num_filters = self.model_cfg.NUM_FILTERS
num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
num_levels = len(layer_nums)
c_in_list = [input_channels, *num_filters[:-1]]
self.blocks = nn.ModuleList()
self.deblocks = nn.ModuleList()
for idx in range(num_levels):
cur_layers = [
nn.ZeroPad2d(1),
nn.Conv2d(
c_in_list[idx], num_filters[idx], kernel_size=3,
stride=layer_strides[idx], padding=0, bias=False
),
nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
]
for k in range(layer_nums[idx]):
cur_layers.extend([
nn.Conv2d(num_filters[idx], num_filters[idx], kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
])
self.blocks.append(nn.Sequential(*cur_layers))
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
upsample_strides[idx],
stride=upsample_strides[idx], bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))
c_in = sum(num_upsample_filters)
if len(upsample_strides) > num_levels:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False),
nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
nn.ReLU(),
))
self.num_bev_features = c_in
def forward(self, data_dict):
"""
Args:
data_dict:
spatial_features
Returns:
"""
spatial_features = data_dict['spatial_features']
ups = []
ret_dict = {}
x = spatial_features
for i in range(len(self.blocks)):
x = self.blocks[i](x)
stride = int(spatial_features.shape[2] / x.shape[2])
ret_dict['spatial_features_%dx' % stride] = x
ups.append(self.deblocks[i](x))
if len(ups) > 1:
x = torch.cat(ups, dim=1)
else:
x = ups[0]
if len(self.deblocks) > len(self.blocks):
x = self.deblocks[-1](x)
data_dict['spatial_features_2d'] = x
return data_dict
from .height_compression import HeightCompression
__all__ = {
'HeightCompression': HeightCompression
}
import torch.nn as nn
class HeightCompression(nn.Module):
def __init__(self, model_cfg):
super().__init__()
self.model_cfg = model_cfg
self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
def forward(self, batch_dict):
"""
Args:
batch_dict:
encoded_spconv_tensor: sparse tensor
Returns:
batch_dict:
spatial_features:
"""
encoded_spconv_tensor = batch_dict['encoded_spconv_tensor']
spatial_features = encoded_spconv_tensor.dense()
N, C, D, H, W = spatial_features.shape
spatial_features = spatial_features.view(N, C * D, H, W)
batch_dict['spatial_features'] = spatial_features
batch_dict['spatial_features_stride'] = batch_dict['encoded_spconv_tensor_stride']
return batch_dict
from .spconv_backbone import VoxelBackBone8x
from .spconv_unet import UNetV2
__all__ = {
'VoxelBackBone8x': VoxelBackBone8x,
'UNetV2': UNetV2
}
from .voxel_set_abstraction import VoxelSetAbstraction
__all__ = {
'VoxelSetAbstraction': VoxelSetAbstraction
}
import torch
import torch.nn as nn
from ....utils import common_utils
from ....ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_stack_modules
from ....ops.pointnet2.pointnet2_stack import pointnet2_utils as pointnet2_stack_utils
def bilinear_interpolate_torch(im, x, y):
"""
Args:
im: (H, W, C) [y, x]
x: (N)
y: (N)
Returns:
"""
x0 = torch.floor(x).long()
x1 = x0 + 1
y0 = torch.floor(y).long()
y1 = y0 + 1
x0 = torch.clamp(x0, 0, im.shape[1] - 1)
x1 = torch.clamp(x1, 0, im.shape[1] - 1)
y0 = torch.clamp(y0, 0, im.shape[0] - 1)
y1 = torch.clamp(y1, 0, im.shape[0] - 1)
Ia = im[y0, x0]
Ib = im[y1, x0]
Ic = im[y0, x1]
Id = im[y1, x1]
wa = (x1.type_as(x) - x) * (y1.type_as(y) - y)
wb = (x1.type_as(x) - x) * (y - y0.type_as(y))
wc = (x - x0.type_as(x)) * (y1.type_as(y) - y)
wd = (x - x0.type_as(x)) * (y - y0.type_as(y))
ans = torch.t((torch.t(Ia) * wa)) + torch.t(torch.t(Ib) * wb) + torch.t(torch.t(Ic) * wc) + torch.t(torch.t(Id) * wd)
return ans
class VoxelSetAbstraction(nn.Module):
def __init__(self, model_cfg, voxel_size, point_cloud_range, num_bev_features=None,
num_rawpoint_features=None, **kwargs):
super().__init__()
self.model_cfg = model_cfg
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
SA_cfg = self.model_cfg.SA_LAYER
self.SA_layers = nn.ModuleList()
self.SA_layer_names = []
self.downsample_times_map = {}
c_in = 0
for src_name in self.model_cfg.FEATURES_SOURCE:
if src_name in ['bev', 'raw_points']:
continue
self.downsample_times_map[src_name] = SA_cfg[src_name].DOWNSAMPLE_FACTOR
mlps = SA_cfg[src_name].MLPS
for k in range(len(mlps)):
mlps[k] = [mlps[k][0]] + mlps[k]
cur_layer = pointnet2_stack_modules.StackSAModuleMSG(
radii=SA_cfg[src_name].POOL_RADIUS,
nsamples=SA_cfg[src_name].NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method='max_pool',
)
self.SA_layers.append(cur_layer)
self.SA_layer_names.append(src_name)
c_in += sum([x[-1] for x in mlps])
if 'bev' in self.model_cfg.FEATURES_SOURCE:
c_bev = num_bev_features
c_in += c_bev
if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
mlps = SA_cfg['raw_points'].MLPS
for k in range(len(mlps)):
mlps[k] = [num_rawpoint_features - 3] + mlps[k]
self.SA_rawpoints = pointnet2_stack_modules.StackSAModuleMSG(
radii=SA_cfg['raw_points'].POOL_RADIUS,
nsamples=SA_cfg['raw_points'].NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method='max_pool'
)
c_in += sum([x[-1] for x in mlps])
self.vsa_point_feature_fusion = nn.Sequential(
nn.Linear(c_in, self.model_cfg.NUM_OUTPUT_FEATURES, bias=False),
nn.BatchNorm1d(self.model_cfg.NUM_OUTPUT_FEATURES),
nn.ReLU(),
)
self.num_point_features = self.model_cfg.NUM_OUTPUT_FEATURES
self.num_point_features_before_fusion = c_in
def interpolate_from_bev_features(self, keypoints, bev_features, batch_size, bev_stride):
x_idxs = (keypoints[:, :, 0] - self.point_cloud_range[0]) / self.voxel_size[0]
y_idxs = (keypoints[:, :, 1] - self.point_cloud_range[1]) / self.voxel_size[1]
x_idxs = x_idxs / bev_stride
y_idxs = y_idxs / bev_stride
point_bev_features_list = []
for k in range(batch_size):
cur_x_idxs = x_idxs[k]
cur_y_idxs = y_idxs[k]
cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C)
point_bev_features = bilinear_interpolate_torch(cur_bev_features, cur_x_idxs, cur_y_idxs)
point_bev_features_list.append(point_bev_features.unsqueeze(dim=0))
point_bev_features = torch.cat(point_bev_features_list, dim=0) # (B, N, C0)
return point_bev_features
def get_sampled_points(self, batch_dict):
batch_size = batch_dict['batch_size']
if self.model_cfg.POINT_SOURCE == 'raw_points':
src_points = batch_dict['points'][:, 1:4]
batch_indices = batch_dict['points'][:, 0].long()
elif self.model_cfg.POINT_SOURCE == 'voxel_centers':
src_points = common_utils.get_voxel_centers(
batch_dict['voxel_coords'][:, 1:4],
downsample_times=1,
voxel_size=self.voxel_size,
point_cloud_range=self.point_cloud_range
)
batch_indices = batch_dict['voxel_coords'][:, 0].long()
else:
raise NotImplementedError
keypoints_list = []
for bs_idx in range(batch_size):
bs_mask = (batch_indices == bs_idx)
sampled_points = src_points[bs_mask].unsqueeze(dim=0) # (1, N, 3)
if self.model_cfg.SAMPLE_METHOD == 'FPS':
cur_pt_idxs = pointnet2_stack_utils.furthest_point_sample(
sampled_points[:, :, 0:3].contiguous(), self.model_cfg.NUM_KEYPOINTS
).long()
if sampled_points.shape[1] < self.model_cfg.NUM_KEYPOINTS:
empty_num = self.model_cfg.NUM_KEYPOINTS - sampled_points.shape[1]
cur_pt_idxs[0, -empty_num:] = cur_pt_idxs[0, :empty_num]
keypoints = sampled_points[0][cur_pt_idxs[0]].unsqueeze(dim=0)
elif self.model_cfg.SAMPLE_METHOD == 'FastFPS':
raise NotImplementedError
else:
raise NotImplementedError
keypoints_list.append(keypoints)
keypoints = torch.cat(keypoints_list, dim=0) # (B, M, 3)
return keypoints
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
keypoints: (B, num_keypoints, 3)
multi_scale_3d_features: {
'x_conv4': ...
}
points: optional (N, 1 + 3 + C) [bs_idx, x, y, z, ...]
spatial_features: optional
spatial_features_stride: optional
Returns:
point_features: (N, C)
point_coords: (N, 4)
"""
keypoints = self.get_sampled_points(batch_dict)
point_features_list = []
if 'bev' in self.model_cfg.FEATURES_SOURCE:
point_bev_features = self.interpolate_from_bev_features(
keypoints, batch_dict['spatial_features'], batch_dict['batch_size'],
bev_stride=batch_dict['spatial_features_stride']
)
point_features_list.append(point_bev_features)
batch_size, num_keypoints, _ = keypoints.shape
new_xyz = keypoints.view(-1, 3)
new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int().fill_(num_keypoints)
if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
raw_points = batch_dict['points']
xyz = raw_points[:, 1:4]
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (raw_points[:, 0] == bs_idx).sum()
point_features = raw_points[:, 4:].contiguous() if len(raw_points) > 4 else None
pooled_points, pooled_features = self.SA_rawpoints(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz,
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=point_features,
)
point_features_list.append(pooled_features.view(batch_size, num_keypoints, -1))
for k, src_name in enumerate(self.SA_layer_names):
cur_coords = batch_dict['multi_scale_3d_features'][src_name].indices
xyz = common_utils.get_voxel_centers(
cur_coords[:, 1:4],
downsample_times=self.downsample_times_map[src_name],
voxel_size=self.voxel_size,
point_cloud_range=self.point_cloud_range
)
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum()
pooled_points, pooled_features = self.SA_layers[k](
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz,
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=batch_dict['multi_scale_3d_features'][src_name].features.contiguous(),
)
point_features_list.append(pooled_features.view(batch_size, num_keypoints, -1))
point_features = torch.cat(point_features_list, dim=2)
batch_idx = torch.arange(batch_size, device=keypoints.device).view(-1, 1).repeat(1, keypoints.shape[1]).view(-1)
point_coords = torch.cat((batch_idx.view(-1, 1).float(), keypoints.view(-1, 3)), dim=1)
batch_dict['point_features_before_fusion'] = point_features.view(-1, point_features.shape[-1])
point_features = self.vsa_point_feature_fusion(point_features.view(-1, point_features.shape[-1]))
batch_dict['point_features'] = point_features # (BxN, C)
batch_dict['point_coords'] = point_coords # (BxN, 4)
return batch_dict
import torch.nn as nn
import spconv
from functools import partial
def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0,
conv_type='subm', norm_fn=None):
if conv_type == 'subm':
conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, bias=False, indice_key=indice_key)
elif conv_type == 'spconv':
conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
bias=False, indice_key=indice_key)
elif conv_type == 'inverseconv':
conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, indice_key=indice_key, bias=False)
else:
raise NotImplementedError
m = spconv.SparseSequential(
conv,
norm_fn(out_channels),
nn.ReLU(),
)
return m
class VoxelBackBone8x(nn.Module):
def __init__(self, model_cfg, input_channels, grid_size, **kwargs):
super().__init__()
self.model_cfg = model_cfg
norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
self.sparse_shape = grid_size[::-1] + [1, 0, 0]
self.conv_input = spconv.SparseSequential(
spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'),
norm_fn(16),
nn.ReLU(),
)
block = post_act_block
self.conv1 = spconv.SparseSequential(
block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'),
)
self.conv2 = spconv.SparseSequential(
# [1600, 1408, 41] <- [800, 704, 21]
block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'),
block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'),
block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'),
)
self.conv3 = spconv.SparseSequential(
# [800, 704, 21] <- [400, 352, 11]
block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'),
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'),
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'),
)
self.conv4 = spconv.SparseSequential(
# [400, 352, 11] <- [200, 176, 5]
block(64, 64, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'),
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
)
last_pad = 0
last_pad = self.model_cfg.get('last_pad', last_pad)
self.conv_out = spconv.SparseSequential(
# [200, 150, 5] -> [200, 150, 2]
spconv.SparseConv3d(64, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad,
bias=False, indice_key='spconv_down2'),
norm_fn(128),
nn.ReLU(),
)
self.num_point_features = 128
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size: int
vfe_features: (num_voxels, C)
voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx]
Returns:
batch_dict:
encoded_spconv_tensor: sparse tensor
"""
voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords']
batch_size = batch_dict['batch_size']
input_sp_tensor = spconv.SparseConvTensor(
features=voxel_features,
indices=voxel_coords.int(),
spatial_shape=self.sparse_shape,
batch_size=batch_size
)
x = self.conv_input(input_sp_tensor)
x_conv1 = self.conv1(x)
x_conv2 = self.conv2(x_conv1)
x_conv3 = self.conv3(x_conv2)
x_conv4 = self.conv4(x_conv3)
# for detection head
# [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4)
batch_dict.update({
'encoded_spconv_tensor': out,
'encoded_spconv_tensor_stride': 8
})
batch_dict.update({
'multi_scale_3d_features': {
'x_conv1': x_conv1,
'x_conv2': x_conv2,
'x_conv3': x_conv3,
'x_conv4': x_conv4,
}
})
return batch_dict
import torch
import torch.nn as nn
import spconv
from functools import partial
from .spconv_backbone import post_act_block
from ...utils import common_utils
class SparseBasicBlock(spconv.SparseModule):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, indice_key=None, norm_fn=None):
super(SparseBasicBlock, self).__init__()
self.conv1 = spconv.SubMConv3d(
inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False, indice_key=indice_key
)
self.bn1 = norm_fn(planes)
self.relu = nn.ReLU()
self.conv2 = spconv.SubMConv3d(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False, indice_key=indice_key
)
self.bn2 = norm_fn(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x.features
assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim()
out = self.conv1(x)
out.features = self.bn1(out.features)
out.features = self.relu(out.features)
out = self.conv2(out)
out.features = self.bn2(out.features)
if self.downsample is not None:
identity = self.downsample(x)
out.features += identity
out.features = self.relu(out.features)
return out
class UNetV2(nn.Module):
"""
Sparse Convolution based UNet for point-wise feature learning.
Reference Paper: https://arxiv.org/abs/1907.03670 (Shaoshuai Shi, et. al)
From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network
"""
def __init__(self, model_cfg, input_channels, grid_size, voxel_size, point_cloud_range, **kwargs):
super().__init__()
self.model_cfg = model_cfg
self.sparse_shape = grid_size[::-1] + [1, 0, 0]
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
self.conv_input = spconv.SparseSequential(
spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'),
norm_fn(16),
nn.ReLU(),
)
block = post_act_block
self.conv1 = spconv.SparseSequential(
block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'),
)
self.conv2 = spconv.SparseSequential(
# [1600, 1408, 41] <- [800, 704, 21]
block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'),
block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'),
block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'),
)
self.conv3 = spconv.SparseSequential(
# [800, 704, 21] <- [400, 352, 11]
block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'),
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'),
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'),
)
self.conv4 = spconv.SparseSequential(
# [400, 352, 11] <- [200, 176, 5]
block(64, 64, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'),
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
)
last_pad = 0
last_pad = self.model_cfg.get('last_pad', last_pad)
self.conv_out = spconv.SparseSequential(
# [200, 150, 5] -> [200, 150, 2]
spconv.SparseConv3d(64, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad,
bias=False, indice_key='spconv_down2'),
norm_fn(128),
nn.ReLU(),
)
# decoder
# [400, 352, 11] <- [200, 176, 5]
self.conv_up_t4 = SparseBasicBlock(64, 64, indice_key='subm4', norm_fn=norm_fn)
self.conv_up_m4 = block(128, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4')
self.inv_conv4 = block(64, 64, 3, norm_fn=norm_fn, indice_key='spconv4', conv_type='inverseconv')
# [800, 704, 21] <- [400, 352, 11]
self.conv_up_t3 = SparseBasicBlock(64, 64, indice_key='subm3', norm_fn=norm_fn)
self.conv_up_m3 = block(128, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3')
self.inv_conv3 = block(64, 32, 3, norm_fn=norm_fn, indice_key='spconv3', conv_type='inverseconv')
# [1600, 1408, 41] <- [800, 704, 21]
self.conv_up_t2 = SparseBasicBlock(32, 32, indice_key='subm2', norm_fn=norm_fn)
self.conv_up_m2 = block(64, 32, 3, norm_fn=norm_fn, indice_key='subm2')
self.inv_conv2 = block(32, 16, 3, norm_fn=norm_fn, indice_key='spconv2', conv_type='inverseconv')
# [1600, 1408, 41] <- [1600, 1408, 41]
self.conv_up_t1 = SparseBasicBlock(16, 16, indice_key='subm1', norm_fn=norm_fn)
self.conv_up_m1 = block(32, 16, 3, norm_fn=norm_fn, indice_key='subm1')
self.conv5 = spconv.SparseSequential(
block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1')
)
self.num_point_features = 16
def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv):
x_trans = conv_t(x_lateral)
x = x_trans
x.features = torch.cat((x_bottom.features, x_trans.features), dim=1)
x_m = conv_m(x)
x = self.channel_reduction(x, x_m.features.shape[1])
x.features = x_m.features + x.features
x = conv_inv(x)
return x
@staticmethod
def channel_reduction(x, out_channels):
"""
Args:
x: x.features (N, C1)
out_channels: C2
Returns:
"""
features = x.features
n, in_channels = features.shape
assert (in_channels % out_channels == 0) and (in_channels >= out_channels)
x.features = features.view(n, out_channels, -1).sum(dim=2)
return x
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size: int
vfe_features: (num_voxels, C)
voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx]
Returns:
batch_dict:
encoded_spconv_tensor: sparse tensor
point_features: (N, C)
"""
voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords']
batch_size = batch_dict['batch_size']
input_sp_tensor = spconv.SparseConvTensor(
features=voxel_features,
indices=voxel_coords.int(),
spatial_shape=self.sparse_shape,
batch_size=batch_size
)
x = self.conv_input(input_sp_tensor)
x_conv1 = self.conv1(x)
x_conv2 = self.conv2(x_conv1)
x_conv3 = self.conv3(x_conv2)
x_conv4 = self.conv4(x_conv3)
# for detection head
# [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4)
# for segmentation head
# [400, 352, 11] <- [200, 176, 5]
x_up4 = self.UR_block_forward(x_conv4, x_conv4, self.conv_up_t4, self.conv_up_m4, self.inv_conv4)
# [800, 704, 21] <- [400, 352, 11]
x_up3 = self.UR_block_forward(x_conv3, x_up4, self.conv_up_t3, self.conv_up_m3, self.inv_conv3)
# [1600, 1408, 41] <- [800, 704, 21]
x_up2 = self.UR_block_forward(x_conv2, x_up3, self.conv_up_t2, self.conv_up_m2, self.inv_conv2)
# [1600, 1408, 41] <- [1600, 1408, 41]
x_up1 = self.UR_block_forward(x_conv1, x_up2, self.conv_up_t1, self.conv_up_m1, self.conv5)
batch_dict['point_features'] = x_up1.features
point_coords = common_utils.get_voxel_centers(
x_up1.indices[:, 1:], downsample_times=1, voxel_size=self.voxel_size,
point_cloud_range=self.point_cloud_range
)
batch_dict['point_coords'] = torch.cat((x_up1.indices[:, 0:1].float(), point_coords), dim=1)
batch_dict['encoded_spconv_tensor'] = out
batch_dict['encoded_spconv_tensor_stride'] = 8
return batch_dict
from .vfe_template import VFETemplate
from .mean_vfe import MeanVFE
__all__ = {
'VFETemplate': VFETemplate,
'MeanVFE': MeanVFE
}
import torch
from .vfe_template import VFETemplate
class MeanVFE(VFETemplate):
def __init__(self, model_cfg, num_point_features, **kwargs):
super().__init__(model_cfg=model_cfg)
self.num_point_features = num_point_features
def get_output_feature_dim(self):
return self.num_point_features
def forward(self, batch_dict, **kwargs):
"""
Args:
batch_dict:
voxels: (num_voxels, max_points_per_voxel, C)
voxel_num_points: optional (num_voxels)
**kwargs:
Returns:
vfe_features: (num_voxels, C)
"""
voxel_features, voxel_num_points = batch_dict['voxels'], batch_dict['voxel_num_points']
points_mean = voxel_features[:, :, :].sum(dim=1, keepdim=False)
normalizer = torch.clamp_min(voxel_num_points.view(-1, 1), min=1.0).type_as(voxel_features)
points_mean = points_mean / normalizer
batch_dict['voxel_features'] = points_mean.contiguous()
return batch_dict
import torch.nn as nn
class VFETemplate(nn.Module):
def __init__(self, model_cfg, **kwargs):
super().__init__()
self.model_cfg = model_cfg
def get_output_feature_dim(self):
raise NotImplementedError
def forward(self, **kwargs):
"""
Args:
**kwargs:
Returns:
batch_dict:
...
vfe_features: (num_voxels, C)
"""
raise NotImplementedError
from .anchor_head_template import AnchorHeadTemplate
from .anchor_head_single import AnchorHeadSingle
from .point_intra_part_head import PointIntraPartOffsetHead
from .point_head_simple import PointHeadSimple
__all__ = {
'AnchorHeadTemplate': AnchorHeadTemplate,
'AnchorHeadSingle': AnchorHeadSingle,
'PointIntraPartOffsetHead': PointIntraPartOffsetHead,
'PointHeadSimple': PointHeadSimple
}
import numpy as np
import torch.nn as nn
from .anchor_head_template import AnchorHeadTemplate
class AnchorHeadSingle(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, grid_size, point_cloud_range,
predict_boxes_when_training=True):
super().__init__(
model_cfg=model_cfg, num_class=num_class, grid_size=grid_size, point_cloud_range=point_cloud_range,
predict_boxes_when_training=predict_boxes_when_training
)
self.num_anchors_per_location = sum(self.num_anchors_per_location)
self.conv_cls = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.num_class,
kernel_size=1
)
self.conv_box = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.box_coder.code_size,
kernel_size=1
)
if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', None) is not None:
self.conv_dir_cls = nn.Conv2d(
input_channels,
self.num_anchors_per_location * self.model_cfg.NUM_DIR_BINS,
kernel_size=1
)
else:
self.conv_dir_cls = None
self.init_weights()
def init_weights(self):
pi = 0.01
nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi))
nn.init.normal_(self.conv_box.weight, mean=0, std=0.001)
def forward(self, data_dict):
spatial_features_2d = data_dict['spatial_features_2d']
cls_preds = self.conv_cls(spatial_features_2d)
box_preds = self.conv_box(spatial_features_2d)
cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous() # [N, H, W, C]
box_preds = box_preds.permute(0, 2, 3, 1).contiguous() # [N, H, W, C]
self.forward_ret_dict['cls_preds'] = cls_preds
self.forward_ret_dict['box_preds'] = box_preds
if self.conv_dir_cls is not None:
dir_cls_preds = self.conv_dir_cls(spatial_features_2d)
dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
self.forward_ret_dict['dir_cls_preds'] = dir_cls_preds
else:
dir_cls_preds = None
if self.training:
targets_dict = self.assign_targets(
gt_boxes=data_dict['gt_boxes']
)
self.forward_ret_dict.update(targets_dict)
if not self.training or self.predict_boxes_when_training:
batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
batch_size=data_dict['batch_size'],
cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
)
data_dict['batch_cls_preds'] = batch_cls_preds
data_dict['batch_box_preds'] = batch_box_preds
data_dict['cls_preds_normalized'] = False
return data_dict
import numpy as np
import torch
import torch.nn as nn
from .target_assigner.anchor_generator import AnchorGenerator
from .target_assigner.atss_target_assigner import ATSSTargetAssigner
from ...utils import box_coder_utils, loss_utils, common_utils
class AnchorHeadTemplate(nn.Module):
def __init__(self, model_cfg, num_class, grid_size, point_cloud_range, predict_boxes_when_training):
super().__init__()
self.model_cfg = model_cfg
self.num_class = num_class
self.predict_boxes_when_training = predict_boxes_when_training
self.use_multihead = self.model_cfg.get('USE_MULTI_HEAD', False)
anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)(
num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6)
)
anchors, self.num_anchors_per_location = self.generate_anchors(
self.model_cfg.ANCHOR_GENERATOR_CONFIG, grid_size=grid_size, point_cloud_range=point_cloud_range
)
self.anchors = [x.cuda() for x in anchors]
self.target_assigner = self.get_target_assigner(anchor_target_cfg)
self.forward_ret_dict = {}
self.build_losses(self.model_cfg.LOSS_CONFIG)
@staticmethod
def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range):
anchor_generator = AnchorGenerator(
anchor_range=point_cloud_range,
anchor_generator_config=anchor_generator_cfg
)
feature_map_size = [grid_size[:2] // config['feature_map_stride'] for config in anchor_generator_cfg]
anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size)
return anchors_list, num_anchors_per_location_list
def get_target_assigner(self, anchor_target_cfg):
if anchor_target_cfg.NAME == 'ATSS':
target_assigner = ATSSTargetAssigner(
topk=anchor_target_cfg.TOPK,
box_coder=self.box_coder,
match_height=anchor_target_cfg.MATCH_HEIGHT
)
elif anchor_target_cfg.NAME == 'Second':
target_assigner = SecondTargetAssigner(
anchor_target_cfg=anchor_target_cfg,
box_coder=self.box_coder,
match_height=anchor_target_cfg.MATCH_HEIGHT
)
else:
raise NotImplementedError
return target_assigner
def build_losses(self, losses_cfg):
self.add_module(
'cls_loss_func',
loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
)
self.add_module(
'reg_loss_func',
loss_utils.WeightedSmoothL1Loss(code_weights=losses_cfg.LOSS_WEIGHTS['code_weights'])
)
self.add_module(
'dir_loss_func',
loss_utils.WeightedCrossEntropyLoss()
)
def assign_targets(self, gt_boxes):
"""
Args:
gt_boxes: (B, M, 8)
Returns:
"""
targets_dict = self.target_assigner.assign_targets(
self.anchors, gt_boxes, self.use_multihead
)
return targets_dict
def get_cls_layer_loss(self):
cls_preds = self.forward_ret_dict['cls_preds']
box_cls_labels = self.forward_ret_dict['box_cls_labels']
batch_size = int(cls_preds.shape[0])
cared = box_cls_labels >= 0 # [N, num_anchors]
positives = box_cls_labels > 0
negatives = box_cls_labels == 0
negative_cls_weights = negatives * 1.0
cls_weights = (negative_cls_weights + 1.0 * positives).float()
reg_weights = positives.float()
if self.num_class == 1:
# class agnostic
box_cls_labels[positives] = 1
pos_normalizer = positives.sum(1, keepdim=True).float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0)
cls_weights /= torch.clamp(pos_normalizer, min=1.0)
cls_targets = box_cls_labels * cared.type_as(box_cls_labels)
cls_targets = cls_targets.unsqueeze(dim=-1)
cls_targets = cls_targets.squeeze(dim=-1)
one_hot_targets = torch.zeros(
*list(cls_targets.shape), self.num_class + 1, dtype=cls_preds.dtype, device=cls_targets.device
)
one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
cls_preds = cls_preds.view(batch_size, -1, self.num_class)
one_hot_targets = one_hot_targets[..., 1:]
# import pdb
# pdb.set_trace()
cls_loss_src = self.cls_loss_func(cls_preds, one_hot_targets, weights=cls_weights) # [N, M]
cls_loss = cls_loss_src.sum() / batch_size
cls_loss = cls_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight']
tb_dict = {
'rpn_loss_cls': cls_loss.item()
}
return cls_loss, tb_dict
@staticmethod
def add_sin_difference(boxes1, boxes2, dim=6):
assert dim != -1
rad_pred_encoding = torch.sin(boxes1[..., dim:dim + 1]) * torch.cos(boxes2[..., dim:dim + 1])
rad_tg_encoding = torch.cos(boxes1[..., dim:dim + 1]) * torch.sin(boxes2[..., dim:dim + 1])
boxes1 = torch.cat([boxes1[..., :dim], rad_pred_encoding, boxes1[..., dim + 1:]], dim=-1)
boxes2 = torch.cat([boxes2[..., :dim], rad_tg_encoding, boxes2[..., dim + 1:]], dim=-1)
return boxes1, boxes2
@staticmethod
def get_direction_target(anchors, reg_targets, one_hot=True, dir_offset=0, num_bins=2):
batch_size = reg_targets.shape[0]
anchors = anchors.view(batch_size, -1, anchors.shape[-1])
rot_gt = reg_targets[..., 6] + anchors[..., 6]
offset_rot = common_utils.limit_period(rot_gt - dir_offset, 0, 2 * np.pi)
dir_cls_targets = torch.floor(offset_rot / (2 * np.pi / num_bins)).long()
dir_cls_targets = torch.clamp(dir_cls_targets, min=0, max=num_bins - 1)
if one_hot:
dir_targets = torch.zeros(*list(dir_cls_targets.shape), num_bins, dtype=anchors.dtype,
device=dir_cls_targets.device)
dir_targets.scatter_(-1, dir_cls_targets.unsqueeze(dim=-1).long(), 1.0)
dir_cls_targets = dir_targets
return dir_cls_targets
def get_box_reg_layer_loss(self):
box_preds = self.forward_ret_dict['box_preds']
box_dir_cls_preds = self.forward_ret_dict.get('dir_cls_preds', None)
box_reg_targets = self.forward_ret_dict['box_reg_targets']
box_cls_labels = self.forward_ret_dict['box_cls_labels']
batch_size = int(box_preds.shape[0])
positives = box_cls_labels > 0
reg_weights = positives.float()
pos_normalizer = positives.sum(1, keepdim=True).float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0)
if isinstance(self.anchors, list):
if self.use_multihead:
anchors = torch.cat(
[anchor.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchor.shape[-1]) for anchor in
self.anchors], dim=0)
else:
anchors = torch.cat(self.anchors, dim=-3)
else:
anchors = self.anchors
anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
box_preds = box_preds.view(batch_size, -1,
box_preds.shape[-1] // self.num_anchors_per_location if not self.use_multihead else
box_preds.shape[-1])
# sin(a - b) = sinacosb-cosasinb
box_preds_sin, reg_targets_sin = self.add_sin_difference(box_preds, box_reg_targets)
loc_loss_src = self.reg_loss_func(box_preds_sin, reg_targets_sin, weights=reg_weights) # [N, M]
loc_loss = loc_loss_src.sum() / batch_size
loc_loss = loc_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['loc_weight']
box_loss = loc_loss
tb_dict = {
'rpn_loss_loc': loc_loss.item()
}
if box_dir_cls_preds is not None:
dir_targets = self.get_direction_target(
anchors, box_reg_targets,
dir_offset=self.model_cfg.DIR_OFFSET,
num_bins=self.model_cfg.NUM_DIR_BINS
)
dir_logits = box_dir_cls_preds.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
weights = positives.type_as(dir_logits)
weights /= torch.clamp(weights.sum(-1, keepdim=True), min=1.0)
dir_loss = self.dir_loss_func(dir_logits, dir_targets, weights=weights)
dir_loss = dir_loss.sum() / batch_size
dir_loss = dir_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['dir_weight']
box_loss += dir_loss
tb_dict['rpn_loss_dir'] = dir_loss.item()
return box_loss, tb_dict
def get_loss(self):
cls_loss, tb_dict = self.get_cls_layer_loss()
box_loss, tb_dict_box = self.get_box_reg_layer_loss()
tb_dict.update(tb_dict_box)
rpn_loss = cls_loss + box_loss
tb_dict['rpn_loss'] = rpn_loss.item()
return rpn_loss, tb_dict
def generate_predicted_boxes(self, batch_size, cls_preds, box_preds, dir_cls_preds=None):
"""
Args:
batch_size:
cls_preds: (N, H, W, C1)
box_preds: (N, H, W, C2)
dir_cls_preds: (N, H, W, C3)
Returns:
batch_cls_preds: (B, num_boxes, num_classes)
batch_box_preds: (B, num_boxes, 7+C)
"""
if isinstance(self.anchors, list):
if self.use_multihead:
anchors = torch.cat([anchor.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchor.shape[-1])
for anchor in self.anchors], dim=0)
else:
anchors = torch.cat(self.anchors, dim=-3)
else:
anchors = self.anchors
num_anchors = anchors.view(-1, anchors.shape[-1]).shape[0]
batch_anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
batch_cls_preds = cls_preds.view(batch_size, num_anchors, -1).float()
batch_box_preds = box_preds.view(batch_size, num_anchors, -1)
batch_box_preds = self.box_coder.decode_torch(batch_box_preds, batch_anchors)
if dir_cls_preds is not None:
dir_offset = self.model_cfg.DIR_OFFSET
dir_limit_offset = self.model_cfg.DIR_LIMIT_OFFSET
dir_cls_preds = dir_cls_preds.view(batch_size, num_anchors, -1)
dir_labels = torch.max(dir_cls_preds, dim=-1)[1]
period = (2 * np.pi / self.model_cfg.NUM_DIR_BINS)
dir_rot = common_utils.limit_period(
batch_box_preds[..., 6] - dir_offset, dir_limit_offset, period
)
batch_box_preds[..., 6] = dir_rot + dir_offset + period * dir_labels.to(batch_box_preds.dtype)
if isinstance(self.box_coder, box_coder_utils.PreviousResidualDecoder):
batch_box_preds[..., 6] = common_utils.limit_period(
-(batch_box_preds[..., 6] + np.pi / 2), offset=0.5, period=np.pi * 2
)
return batch_cls_preds, batch_box_preds
def forward(self, **kwargs):
raise NotImplementedError
import torch
from .point_head_template import PointHeadTemplate
from ...utils import box_utils
class PointHeadSimple(PointHeadTemplate):
"""
A simple point-based segmentation head, which are used for PV-RCNN keypoint segmentaion.
Reference Paper: https://arxiv.org/abs/1912.13192
PV-RCNN: Point-Voxel Feature Set Abstraction for 3D Object Detection
"""
def __init__(self, num_class, input_channels, model_cfg, **kwargs):
super().__init__(model_cfg=model_cfg, num_class=num_class)
self.cls_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.CLS_FC,
input_channels=input_channels,
output_channels=num_class
)
def assign_targets(self, input_dict):
"""
Args:
input_dict:
point_features: (N1 + N2 + N3 + ..., C)
batch_size:
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
gt_boxes (optional): (B, M, 8)
Returns:
point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored
point_part_labels: (N1 + N2 + N3 + ..., 3)
"""
point_coords = input_dict['point_coords']
gt_boxes = input_dict['gt_boxes']
assert gt_boxes.shape.__len__() == 3, 'gt_boxes.shape=%s' % str(gt_boxes.shape)
assert point_coords.shape.__len__() in [2], 'points.shape=%s' % str(point_coords.shape)
batch_size = gt_boxes.shape[0]
extend_gt_boxes = box_utils.enlarge_box3d(
gt_boxes.view(-1, gt_boxes.shape[-1]), extra_width=self.model_cfg.TARGET_CONFIG.GT_EXTRA_WIDTH
).view(batch_size, -1, gt_boxes.shape[-1])
targets_dict = self.assign_stack_targets(
points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes,
set_ignore_flag=True, use_ball_constraint=False,
ret_part_labels=False
)
return targets_dict
def get_loss(self, tb_dict=None):
tb_dict = {} if tb_dict is None else tb_dict
point_loss_cls, tb_dict_1 = self.get_cls_layer_loss()
point_loss = point_loss_cls
tb_dict.update(tb_dict_1)
return point_loss, tb_dict
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
point_features: (N1 + N2 + N3 + ..., C) or (B, N, C)
point_features_before_fusion: (N1 + N2 + N3 + ..., C)
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
point_labels (optional): (N1 + N2 + N3 + ...)
gt_boxes (optional): (B, M, 8)
Returns:
batch_dict:
point_cls_scores: (N1 + N2 + N3 + ..., 1)
point_part_offset: (N1 + N2 + N3 + ..., 3)
"""
if self.model_cfg.get('USE_POINT_FEATURES_BEFORE_FUSION', False):
point_features = batch_dict['point_features_before_fusion']
else:
point_features = batch_dict['point_features']
point_cls_preds = self.cls_layers(point_features) # (total_points, num_class)
ret_dict = {
'point_cls_preds': point_cls_preds,
}
point_cls_scores = torch.sigmoid(point_cls_preds)
batch_dict['point_cls_scores'], _ = point_cls_scores.max(dim=-1)
if self.training:
targets_dict = self.assign_targets(batch_dict)
ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']
self.forward_ret_dict = ret_dict
return batch_dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...utils import loss_utils, common_utils
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
class PointHeadTemplate(nn.Module):
def __init__(self, model_cfg, num_class):
super().__init__()
self.model_cfg = model_cfg
self.num_class = num_class
self.build_losses(self.model_cfg.LOSS_CONFIG)
self.forward_ret_dict = None
def build_losses(self, losses_cfg):
self.add_module(
'cls_loss_func',
loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
)
self.reg_loss_func = F.smooth_l1_loss if losses_cfg.get('LOSS_REG', None) == 'smooth-l1' else F.l1_loss
@staticmethod
def make_fc_layers(fc_cfg, input_channels, output_channels):
fc_layers = []
c_in = input_channels
for k in range(0, fc_cfg.__len__()):
fc_layers.extend([
nn.Linear(c_in, fc_cfg[k], bias=False),
nn.BatchNorm1d(fc_cfg[k]),
nn.ReLU(),
])
c_in = fc_cfg[k]
fc_layers.append(nn.Linear(c_in, output_channels, bias=True))
return nn.Sequential(*fc_layers)
def assign_stack_targets(self, points, gt_boxes, extend_gt_boxes=None,
ret_box_labels=False, ret_part_labels=False,
set_ignore_flag=True, use_ball_constraint=False, central_radius=2.0):
"""
Args:
points: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
gt_boxes: (B, M, 8)
extend_gt_boxes: [B, M, 8]
ret_box_labels:
ret_part_labels:
set_ignore_flag:
use_ball_constraint:
central_radius:
Returns:
point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored
point_box_labels: (N1 + N2 + N3 + ..., code_size)
"""
assert len(points.shape) == 2 and points.shape[1] == 4, 'points.shape=%s' % str(points.shape)
assert len(gt_boxes.shape) == 3 and gt_boxes.shape[2] == 8, 'gt_boxes.shape=%s' % str(gt_boxes.shape)
assert extend_gt_boxes is None or len(extend_gt_boxes.shape) == 3 and extend_gt_boxes.shape[2] == 8, \
'extend_gt_boxes.shape=%s' % str(extend_gt_boxes.shape)
assert set_ignore_flag != use_ball_constraint, 'Choose one only!'
batch_size = gt_boxes.shape[0]
bs_idx = points[:, 0]
point_cls_labels = points.new_zeros(points.shape[0]).long()
point_box_labels = gt_boxes.new_zeros((points.shape[0], 8)) if ret_box_labels else None
point_part_labels = gt_boxes.new_zeros((points.shape[0], 3)) if ret_part_labels else None
for k in range(batch_size):
bs_mask = (bs_idx == k)
points_single = points[bs_mask][:, 1:4]
point_cls_labels_single = point_cls_labels.new_zeros(bs_mask.sum())
box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
points_single.unsqueeze(dim=0), gt_boxes[k:k + 1, :, 0:7].contiguous()
).long().squeeze(dim=0)
box_fg_flag = (box_idxs_of_pts >= 0)
if set_ignore_flag:
extend_box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
points_single.unsqueeze(dim=0), extend_gt_boxes[k:k+1, :, 0:7].contiguous()
).long().squeeze(dim=0)
fg_flag = box_fg_flag
ignore_flag = fg_flag ^ (extend_box_idxs_of_pts >= 0)
point_cls_labels_single[ignore_flag] = -1
elif use_ball_constraint:
box_centers = gt_boxes[k][box_idxs_of_pts][:, 0:3].clone()
box_centers[:, 2] += gt_boxes[k][box_idxs_of_pts][:, 5] / 2
ball_flag = ((box_centers - points_single).norm(dim=1) < central_radius)
fg_flag = box_fg_flag & ball_flag
else:
raise NotImplementedError
gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]]
point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, 7].long()
point_cls_labels[bs_mask] = point_cls_labels_single
if ret_box_labels:
point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8))
fg_point_box_labels = self.box_coder.encode_torch(points_single[fg_flag], gt_box_of_fg_points)
point_box_labels_single[fg_flag] = fg_point_box_labels
point_box_labels[bs_mask] = point_box_labels_single
if ret_part_labels:
point_part_labels_single = point_part_labels.new_zeros((bs_mask.sum(), 3))
transformed_points = points_single[fg_flag] - gt_box_of_fg_points[:, 0:3]
transformed_points = common_utils.rotate_points_along_z(
transformed_points.view(-1, 1, 3), -gt_box_of_fg_points[:, 6]
).view(-1, 3)
offset = torch.tensor([0.5, 0.5, 0.5]).view(1, 3).type_as(transformed_points)
point_part_labels_single[fg_flag] = (transformed_points / gt_box_of_fg_points[:, 3:6]) + offset
point_part_labels[bs_mask] = point_part_labels_single
targets_dict = {
'point_cls_labels': point_cls_labels,
'point_box_labels': point_box_labels,
'point_part_labels': point_part_labels
}
return targets_dict
def get_cls_layer_loss(self):
point_cls_labels = self.forward_ret_dict['point_cls_labels'].view(-1)
point_cls_preds = self.forward_ret_dict['point_cls_preds'].view(-1, self.num_class)
positives = (point_cls_labels > 0)
negative_cls_weights = (point_cls_labels == 0) * 1.0
cls_weights = (negative_cls_weights + 1.0 * positives).float()
pos_normalizer = positives.sum(dim=0).float()
cls_weights /= torch.clamp(pos_normalizer, min=1.0)
one_hot_targets = point_cls_preds.new_zeros(*list(point_cls_labels.shape), self.num_class + 1)
one_hot_targets.scatter_(-1, (point_cls_labels * (point_cls_labels >= 0).long()).unsqueeze(dim=-1).long(), 1.0)
one_hot_targets = one_hot_targets[..., 1:]
cls_loss_src = self.cls_loss_func(point_cls_preds, one_hot_targets, weights=cls_weights)
point_loss_cls = cls_loss_src.sum()
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_cls = point_loss_cls * loss_weights_dict['point_cls_weight']
tb_dict = {
'point_loss_cls': point_loss_cls.item(),
'point_pos_num': pos_normalizer.item()
}
return point_loss_cls, tb_dict
def get_part_layer_loss(self):
pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
pos_normalizer = max(1, (pos_mask > 0).sum().item())
point_part_labels = self.forward_ret_dict['point_part_labels']
point_part_preds = self.forward_ret_dict['point_part_preds']
point_loss_part = F.binary_cross_entropy(torch.sigmoid(point_part_preds), point_part_labels, reduction='none')
point_loss_part = (point_loss_part.sum(dim=-1) * pos_mask.float()).sum() / (3 * pos_normalizer)
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_part = point_loss_part * loss_weights_dict['point_part_weight']
return point_loss_part, {'point_loss_part': point_loss_part.item()}
def forward(self, **kwargs):
raise NotImplementedError
import torch
from .point_head_template import PointHeadTemplate
from ...utils import box_utils
class PointIntraPartOffsetHead(PointHeadTemplate):
"""
Point-based head for predicting the intra-object part locations.
Reference Paper: https://arxiv.org/abs/1907.03670
From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network
"""
def __init__(self, num_class, input_channels, model_cfg, **kwargs):
super().__init__(model_cfg=model_cfg, num_class=num_class)
self.cls_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.CLS_FC,
input_channels=input_channels,
output_channels=num_class
)
self.part_reg_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.PART_FC,
input_channels=input_channels,
output_channels=3
)
def assign_targets(self, input_dict):
"""
Args:
input_dict:
point_features: (N1 + N2 + N3 + ..., C)
batch_size:
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
gt_boxes (optional): (B, M, 8)
Returns:
point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored
point_part_labels: (N1 + N2 + N3 + ..., 3)
"""
point_coords = input_dict['point_coords']
gt_boxes = input_dict['gt_boxes']
assert gt_boxes.shape.__len__() == 3, 'gt_boxes.shape=%s' % str(gt_boxes.shape)
assert point_coords.shape.__len__() in [2], 'points.shape=%s' % str(point_coords.shape)
batch_size = gt_boxes.shape[0]
extend_gt_boxes = box_utils.enlarge_box3d(
gt_boxes.view(-1, gt_boxes.shape[-1]), extra_width=self.model_cfg.TARGET_CONFIG.GT_EXTRA_WIDTH
).view(batch_size, -1, gt_boxes.shape[-1])
targets_dict = self.assign_stack_targets(
points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes,
set_ignore_flag=True, use_ball_constraint=False,
ret_part_labels=True
)
return targets_dict
def get_loss(self, tb_dict=None):
tb_dict = {} if tb_dict is None else tb_dict
point_loss_cls, tb_dict_1 = self.get_cls_layer_loss()
point_loss_part, tb_dict_2 = self.get_part_layer_loss()
point_loss = point_loss_cls + point_loss_part
tb_dict.update(tb_dict_1)
tb_dict.update(tb_dict_2)
return point_loss, tb_dict
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
point_features: (N1 + N2 + N3 + ..., C) or (B, N, C)
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
point_labels (optional): (N1 + N2 + N3 + ...)
gt_boxes (optional): (B, M, 8)
Returns:
batch_dict:
point_cls_scores: (N1 + N2 + N3 + ..., 1)
point_part_offset: (N1 + N2 + N3 + ..., 3)
"""
point_features = batch_dict['point_features']
point_cls_preds = self.cls_layers(point_features) # (total_points, num_class)
point_part_preds = self.part_reg_layers(point_features)
ret_dict = {
'point_cls_preds': point_cls_preds,
'point_part_preds': point_part_preds,
}
point_cls_scores = torch.sigmoid(point_cls_preds)
point_part_offset = torch.sigmoid(point_part_preds)
batch_dict['point_cls_scores'], _ = point_cls_scores.max(dim=-1)
batch_dict['point_part_offset'] = point_part_offset
if self.training:
targets_dict = self.assign_targets(batch_dict)
ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']
ret_dict['point_part_labels'] = targets_dict.get('point_part_labels')
self.forward_ret_dict = ret_dict
return batch_dict
import torch
class AnchorGenerator(object):
def __init__(self, anchor_range, anchor_generator_config):
super().__init__()
self.anchor_generator_cfg = anchor_generator_config
self.anchor_range = anchor_range
self.anchor_sizes = [config['anchor_sizes'] for config in anchor_generator_config]
self.anchor_rotations = [config['anchor_rotations'] for config in anchor_generator_config]
self.anchor_heights = [config['anchor_bottom_heights'] for config in anchor_generator_config]
self.align_center = [config.get('align_center', False) for config in anchor_generator_config]
assert len(self.anchor_sizes) == len(self.anchor_rotations) == len(self.anchor_heights)
self.num_of_anchor_sets = len(self.anchor_sizes)
def generate_anchors(self, grid_sizes):
assert len(grid_sizes) == self.num_of_anchor_sets
all_anchors = []
num_anchors_per_location = []
for grid_size, anchor_size, anchor_rotation, anchor_height, align_center in zip(
grid_sizes, self.anchor_sizes, self.anchor_rotations, self.anchor_heights, self.align_center):
num_anchors_per_location.append(len(anchor_rotation) * len(anchor_size) * len(anchor_height))
if align_center:
x_stride = (self.anchor_range[3] - self.anchor_range[0]) / grid_size[0]
y_stride = (self.anchor_range[4] - self.anchor_range[1]) / grid_size[1]
x_offset, y_offset = x_stride / 2, y_stride / 2
else:
x_stride = (self.anchor_range[3] - self.anchor_range[0]) / (grid_size[0] - 1)
y_stride = (self.anchor_range[4] - self.anchor_range[1]) / (grid_size[1] - 1)
x_offset, y_offset = 0, 0
x_shifts = torch.arange(
self.anchor_range[0] + x_offset, self.anchor_range[3] + 1e-5, step=x_stride, dtype=torch.float32,
).cuda()
y_shifts = torch.arange(
self.anchor_range[1] + y_offset, self.anchor_range[4] + 1e-5, step=y_stride, dtype=torch.float32,
).cuda()
z_shifts = x_shifts.new_tensor(anchor_height)
num_anchor_size, num_anchor_rotation = anchor_size.__len__(), anchor_rotation.__len__()
anchor_rotation = x_shifts.new_tensor(anchor_rotation)
anchor_size = x_shifts.new_tensor(anchor_size)
x_shifts, y_shifts, z_shifts = torch.meshgrid([
x_shifts, y_shifts, z_shifts
]) # [x_grid, y_grid, z_grid]
anchors = torch.stack((x_shifts, y_shifts, z_shifts), dim=-1) # [x, y, z, 3]
anchors = anchors[:, :, :, None, :].repeat(1, 1, 1, anchor_size.shape[0], 1)
anchor_size = anchor_size.view(1, 1, 1, -1, 3).repeat([*anchors.shape[0:3], 1, 1])
anchors = torch.cat((anchors, anchor_size), dim=-1)
anchors = anchors[:, :, :, :, None, :].repeat(1, 1, 1, 1, num_anchor_rotation, 1)
anchor_rotation = anchor_rotation.view(1, 1, 1, 1, -1, 1).repeat([*anchors.shape[0:3], num_anchor_size, 1, 1])
anchors = torch.cat((anchors, anchor_rotation), dim=-1) # [x, y, z, num_size, num_rot, 7]
anchors = anchors.permute(2, 1, 0, 3, 4, 5).contiguous()
#anchors = anchors.view(-1, anchors.shape[-1])
anchors[..., 2] += anchors[..., 5] / 2 # shift to box centers
all_anchors.append(anchors)
return all_anchors, num_anchors_per_location
if __name__ == '__main__':
from easydict import EasyDict
config = [
EasyDict({
'anchor_sizes': [[2.1, 4.7, 1.7], [0.86, 0.91, 1.73], [0.84, 1.78, 1.78]],
'anchor_rotations': [0, 1.57],
'anchor_heights': [0, 0.5]
})
]
A = AnchorGenerator(
anchor_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
anchor_generator_config=config
)
import pdb
pdb.set_trace()
A.generate_anchors([[188, 188]])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment