Commit a90d9375 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support PointNet2MSG backbone_3d for PointRCNN

parent adbb322f
from .spconv_backbone import VoxelBackBone8x
from .spconv_unet import UNetV2
from .pointnet2_backbone import PointNet2Backbone
from .pointnet2_backbone import PointNet2Backbone, PointNet2MSG
__all__ = {
'VoxelBackBone8x': VoxelBackBone8x,
'UNetV2': UNetV2,
'PointNet2Backbone': PointNet2Backbone
'PointNet2Backbone': PointNet2Backbone,
'PointNet2MSG': PointNet2MSG
}
import torch
import torch.nn as nn
from ...ops.pointnet2.pointnet2_stack import pointnet2_modules, pointnet2_utils
from ...ops.pointnet2.pointnet2_batch import pointnet2_modules
from ...ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_modules_stack
from ...ops.pointnet2.pointnet2_stack import pointnet2_utils as pointnet2_utils_stack
class PointNet2MSG(nn.Module):
def __init__(self, model_cfg, input_channels, **kwargs):
super().__init__()
self.model_cfg = model_cfg
self.SA_modules = nn.ModuleList()
channel_in = input_channels - 3
self.num_points_each_layer = []
skip_channel_list = [input_channels - 3]
for k in range(self.model_cfg.SA_CONFIG.NPOINTS.__len__()):
mlps = self.model_cfg.SA_CONFIG.MLPS[k].copy()
channel_out = 0
for idx in range(mlps.__len__()):
mlps[idx] = [channel_in] + mlps[idx]
channel_out += mlps[idx][-1]
self.SA_modules.append(
pointnet2_modules.PointnetSAModuleMSG(
npoint=self.model_cfg.SA_CONFIG.NPOINTS[k],
radii=self.model_cfg.SA_CONFIG.RADIUS[k],
nsamples=self.model_cfg.SA_CONFIG.NSAMPLE[k],
mlps=mlps,
use_xyz=self.model_cfg.SA_CONFIG.get('USE_XYZ', True),
)
)
skip_channel_list.append(channel_out)
channel_in = channel_out
self.FP_modules = nn.ModuleList()
for k in range(self.model_cfg.FP_MLPS.__len__()):
pre_channel = self.model_cfg.FP_MLPS[k + 1][-1] if k + 1 < len(self.model_cfg.FP_MLPS) else channel_out
self.FP_modules.append(
pointnet2_modules.PointnetFPModule(
mlp=[pre_channel + skip_channel_list[k]] + self.model_cfg.FP_MLPS[k]
)
)
self.num_point_features = self.model_cfg.FP_MLPS[0][-1]
def break_up_pc(self, pc):
batch_idx = pc[:, 0]
xyz = pc[:, 1:4].contiguous()
features = (pc[:, 4:].contiguous() if pc.size(-1) > 4 else None)
return batch_idx, xyz, features
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size: int
vfe_features: (num_voxels, C)
points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
Returns:
batch_dict:
encoded_spconv_tensor: sparse tensor
point_features: (N, C)
"""
batch_size = batch_dict['batch_size']
points = batch_dict['points']
batch_idx, xyz, features = self.break_up_pc(points)
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()
assert xyz_batch_cnt.min() == xyz_batch_cnt.max()
xyz = xyz.view(batch_size, -1, 3)
features = features.view(batch_size, -1, features.shape[-1]).permute(0, 2, 1) if features is not None else None
l_xyz, l_features = [xyz], [features]
for i in range(len(self.SA_modules)):
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
l_xyz.append(li_xyz)
l_features.append(li_features)
for i in range(-1, -(len(self.FP_modules) + 1), -1):
l_features[i - 1] = self.FP_modules[i](
l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
) # (B, C, N)
point_features = l_features[0].permute(0, 2, 1).contiguous() # (B, N, C)
batch_dict['point_features'] = point_features.view(-1, point_features.shape[-1])
batch_dict['point_coords'] = torch.cat((batch_idx[:, None].float(), l_xyz[0].view(-1, 3)), dim=1)
return batch_dict
class PointNet2Backbone(nn.Module):
"""
DO NOT USE THIS CURRENTLY SINCE IT MAY HAVE POTENTIAL BUGS, 20200723
"""
def __init__(self, model_cfg, input_channels, **kwargs):
assert False, 'DO NOT USE THIS CURRENTLY SINCE IT MAY HAVE POTENTIAL BUGS, 20200723'
super().__init__()
self.model_cfg = model_cfg
......@@ -22,7 +116,7 @@ class PointNet2Backbone(nn.Module):
channel_out += mlps[idx][-1]
self.SA_modules.append(
pointnet2_modules.StackSAModuleMSG(
pointnet2_modules_stack.StackSAModuleMSG(
radii=self.model_cfg.SA_CONFIG.RADIUS[k],
nsamples=self.model_cfg.SA_CONFIG.NSAMPLE[k],
mlps=mlps,
......@@ -37,7 +131,7 @@ class PointNet2Backbone(nn.Module):
for k in range(self.model_cfg.FP_MLPS.__len__()):
pre_channel = self.model_cfg.FP_MLPS[k + 1][-1] if k + 1 < len(self.model_cfg.FP_MLPS) else channel_out
self.FP_modules.append(
pointnet2_modules.StackPointnetFPModule(
pointnet2_modules_stack.StackPointnetFPModule(
mlp=[pre_channel + skip_channel_list[k]] + self.model_cfg.FP_MLPS[k]
)
)
......@@ -79,7 +173,7 @@ class PointNet2Backbone(nn.Module):
else:
last_num_points = self.num_points_each_layer[i - 1]
cur_xyz = l_xyz[-1][k * last_num_points: (k + 1) * last_num_points]
cur_pt_idxs = pointnet2_utils.furthest_point_sample(
cur_pt_idxs = pointnet2_utils_stack.furthest_point_sample(
cur_xyz[None, :, :].contiguous(), self.num_points_each_layer[i]
).long()[0]
if cur_xyz.shape[0] < self.num_points_each_layer[i]:
......
/*
Stacked-batch-data version of point interpolation, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
......
/*
Stacked-batch-data version of point interpolation, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
......
......@@ -7,6 +7,12 @@ DATA_CONFIG:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: sample_points
NUM_POINTS: {
'train': 16384,
'test': 16384
}
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
......@@ -17,7 +23,7 @@ MODEL:
NAME: PointRCNN
BACKBONE_3D:
NAME: PointNet2Backbone
NAME: PointNet2MSG
SA_CONFIG:
NPOINTS: [4096, 1024, 256, 64]
RADIUS: [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment