Commit 43baf787 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

Merge branch 'dev_pointrcnn' into dev_v0.2.1

parents 85ff046d 8075b170
......@@ -62,6 +62,34 @@ class DataProcessor(object):
data_dict['voxel_num_points'] = num_points
return data_dict
def sample_points(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.sample_points, config=config)
num_points = config.NUM_POINTS[self.mode]
if num_points == -1:
return data_dict
points = data_dict['points']
if num_points < len(points):
pts_depth = np.linalg.norm(points[:, 0:3], axis=1)
pts_near_flag = pts_depth < 40.0
far_idxs_choice = np.where(pts_near_flag == 0)[0]
near_idxs = np.where(pts_near_flag == 1)[0]
near_idxs_choice = np.random.choice(near_idxs, num_points - len(far_idxs_choice), replace=False)
choice = np.concatenate((near_idxs_choice, far_idxs_choice), axis=0) \
if len(far_idxs_choice) > 0 else near_idxs_choice
np.random.shuffle(choice)
else:
choice = np.arange(0, len(points), dtype=np.int32)
if num_points > len(points):
extra_choice = np.random.choice(choice, num_points - len(points), replace=False)
choice = np.concatenate((choice, extra_choice), axis=0)
np.random.shuffle(choice)
data_dict['points'] = points[choice]
return data_dict
def forward(self, data_dict):
"""
Args:
......
from .spconv_backbone import VoxelBackBone8x
from .spconv_unet import UNetV2
from .pointnet2_backbone import PointNet2Backbone, PointNet2MSG
__all__ = {
'VoxelBackBone8x': VoxelBackBone8x,
'UNetV2': UNetV2
'UNetV2': UNetV2,
'PointNet2Backbone': PointNet2Backbone,
'PointNet2MSG': PointNet2MSG
}
import torch
import torch.nn as nn
from ...ops.pointnet2.pointnet2_batch import pointnet2_modules
from ...ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_modules_stack
from ...ops.pointnet2.pointnet2_stack import pointnet2_utils as pointnet2_utils_stack
class PointNet2MSG(nn.Module):
def __init__(self, model_cfg, input_channels, **kwargs):
super().__init__()
self.model_cfg = model_cfg
self.SA_modules = nn.ModuleList()
channel_in = input_channels - 3
self.num_points_each_layer = []
skip_channel_list = [input_channels - 3]
for k in range(self.model_cfg.SA_CONFIG.NPOINTS.__len__()):
mlps = self.model_cfg.SA_CONFIG.MLPS[k].copy()
channel_out = 0
for idx in range(mlps.__len__()):
mlps[idx] = [channel_in] + mlps[idx]
channel_out += mlps[idx][-1]
self.SA_modules.append(
pointnet2_modules.PointnetSAModuleMSG(
npoint=self.model_cfg.SA_CONFIG.NPOINTS[k],
radii=self.model_cfg.SA_CONFIG.RADIUS[k],
nsamples=self.model_cfg.SA_CONFIG.NSAMPLE[k],
mlps=mlps,
use_xyz=self.model_cfg.SA_CONFIG.get('USE_XYZ', True),
)
)
skip_channel_list.append(channel_out)
channel_in = channel_out
self.FP_modules = nn.ModuleList()
for k in range(self.model_cfg.FP_MLPS.__len__()):
pre_channel = self.model_cfg.FP_MLPS[k + 1][-1] if k + 1 < len(self.model_cfg.FP_MLPS) else channel_out
self.FP_modules.append(
pointnet2_modules.PointnetFPModule(
mlp=[pre_channel + skip_channel_list[k]] + self.model_cfg.FP_MLPS[k]
)
)
self.num_point_features = self.model_cfg.FP_MLPS[0][-1]
def break_up_pc(self, pc):
batch_idx = pc[:, 0]
xyz = pc[:, 1:4].contiguous()
features = (pc[:, 4:].contiguous() if pc.size(-1) > 4 else None)
return batch_idx, xyz, features
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size: int
vfe_features: (num_voxels, C)
points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
Returns:
batch_dict:
encoded_spconv_tensor: sparse tensor
point_features: (N, C)
"""
batch_size = batch_dict['batch_size']
points = batch_dict['points']
batch_idx, xyz, features = self.break_up_pc(points)
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()
assert xyz_batch_cnt.min() == xyz_batch_cnt.max()
xyz = xyz.view(batch_size, -1, 3)
features = features.view(batch_size, -1, features.shape[-1]).permute(0, 2, 1) if features is not None else None
l_xyz, l_features = [xyz], [features]
for i in range(len(self.SA_modules)):
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
l_xyz.append(li_xyz)
l_features.append(li_features)
for i in range(-1, -(len(self.FP_modules) + 1), -1):
l_features[i - 1] = self.FP_modules[i](
l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
) # (B, C, N)
point_features = l_features[0].permute(0, 2, 1).contiguous() # (B, N, C)
batch_dict['point_features'] = point_features.view(-1, point_features.shape[-1])
batch_dict['point_coords'] = torch.cat((batch_idx[:, None].float(), l_xyz[0].view(-1, 3)), dim=1)
return batch_dict
class PointNet2Backbone(nn.Module):
"""
DO NOT USE THIS CURRENTLY SINCE IT MAY HAVE POTENTIAL BUGS, 20200723
"""
def __init__(self, model_cfg, input_channels, **kwargs):
assert False, 'DO NOT USE THIS CURRENTLY SINCE IT MAY HAVE POTENTIAL BUGS, 20200723'
super().__init__()
self.model_cfg = model_cfg
self.SA_modules = nn.ModuleList()
channel_in = input_channels - 3
self.num_points_each_layer = []
skip_channel_list = [input_channels]
for k in range(self.model_cfg.SA_CONFIG.NPOINTS.__len__()):
self.num_points_each_layer.append(self.model_cfg.SA_CONFIG.NPOINTS[k])
mlps = self.model_cfg.SA_CONFIG.MLPS[k].copy()
channel_out = 0
for idx in range(mlps.__len__()):
mlps[idx] = [channel_in] + mlps[idx]
channel_out += mlps[idx][-1]
self.SA_modules.append(
pointnet2_modules_stack.StackSAModuleMSG(
radii=self.model_cfg.SA_CONFIG.RADIUS[k],
nsamples=self.model_cfg.SA_CONFIG.NSAMPLE[k],
mlps=mlps,
use_xyz=self.model_cfg.SA_CONFIG.get('USE_XYZ', True),
)
)
skip_channel_list.append(channel_out)
channel_in = channel_out
self.FP_modules = nn.ModuleList()
for k in range(self.model_cfg.FP_MLPS.__len__()):
pre_channel = self.model_cfg.FP_MLPS[k + 1][-1] if k + 1 < len(self.model_cfg.FP_MLPS) else channel_out
self.FP_modules.append(
pointnet2_modules_stack.StackPointnetFPModule(
mlp=[pre_channel + skip_channel_list[k]] + self.model_cfg.FP_MLPS[k]
)
)
self.num_point_features = self.model_cfg.FP_MLPS[0][-1]
def break_up_pc(self, pc):
batch_idx = pc[:, 0]
xyz = pc[:, 1:4].contiguous()
features = (pc[:, 4:].contiguous() if pc.size(-1) > 4 else None)
return batch_idx, xyz, features
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size: int
vfe_features: (num_voxels, C)
points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
Returns:
batch_dict:
encoded_spconv_tensor: sparse tensor
point_features: (N, C)
"""
batch_size = batch_dict['batch_size']
points = batch_dict['points']
batch_idx, xyz, features = self.break_up_pc(points)
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()
l_xyz, l_features, l_batch_cnt = [xyz], [features], [xyz_batch_cnt]
for i in range(len(self.SA_modules)):
new_xyz_list = []
for k in range(batch_size):
if len(l_xyz) == 1:
cur_xyz = l_xyz[0][batch_idx == k]
else:
last_num_points = self.num_points_each_layer[i - 1]
cur_xyz = l_xyz[-1][k * last_num_points: (k + 1) * last_num_points]
cur_pt_idxs = pointnet2_utils_stack.furthest_point_sample(
cur_xyz[None, :, :].contiguous(), self.num_points_each_layer[i]
).long()[0]
if cur_xyz.shape[0] < self.num_points_each_layer[i]:
empty_num = self.num_points_each_layer[i] - cur_xyz.shape[1]
cur_pt_idxs[0, -empty_num:] = cur_pt_idxs[0, :empty_num]
new_xyz_list.append(cur_xyz[cur_pt_idxs])
new_xyz = torch.cat(new_xyz_list, dim=0)
new_xyz_batch_cnt = xyz.new_zeros(batch_size).int().fill_(self.num_points_each_layer[i])
li_xyz, li_features = self.SA_modules[i](
xyz=l_xyz[i], features=l_features[i], xyz_batch_cnt=l_batch_cnt[i],
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
)
l_xyz.append(li_xyz)
l_features.append(li_features)
l_batch_cnt.append(new_xyz_batch_cnt)
l_features[0] = points[:, 1:]
for i in range(-1, -(len(self.FP_modules) + 1), -1):
l_features[i - 1] = self.FP_modules[i](
unknown=l_xyz[i - 1], unknown_batch_cnt=l_batch_cnt[i - 1],
known=l_xyz[i], known_batch_cnt=l_batch_cnt[i],
unknown_feats=l_features[i - 1], known_feats=l_features[i]
)
batch_dict['point_features'] = l_features[0]
batch_dict['point_coords'] = torch.cat((batch_idx[:, None].float(), l_xyz[0]), dim=1)
return batch_dict
......@@ -91,8 +91,8 @@ class UNetV2(nn.Module):
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
)
last_pad = 0
last_pad = self.model_cfg.get('last_pad', last_pad)
if self.model_cfg.get('RETURN_ENCODED_TENSOR', True):
last_pad = self.model_cfg.get('last_pad', 0)
self.conv_out = spconv.SparseSequential(
# [200, 150, 5] -> [200, 150, 2]
......@@ -101,6 +101,8 @@ class UNetV2(nn.Module):
norm_fn(128),
nn.ReLU(),
)
else:
self.conv_out = None
# decoder
# [400, 352, 11] <- [200, 176, 5]
......@@ -181,9 +183,12 @@ class UNetV2(nn.Module):
x_conv3 = self.conv3(x_conv2)
x_conv4 = self.conv4(x_conv3)
if self.conv_out is not None:
# for detection head
# [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4)
batch_dict['encoded_spconv_tensor'] = out
batch_dict['encoded_spconv_tensor_stride'] = 8
# for segmentation head
# [400, 352, 11] <- [200, 176, 5]
......@@ -201,6 +206,4 @@ class UNetV2(nn.Module):
point_cloud_range=self.point_cloud_range
)
batch_dict['point_coords'] = torch.cat((x_up1.indices[:, 0:1].float(), point_coords), dim=1)
batch_dict['encoded_spconv_tensor'] = out
batch_dict['encoded_spconv_tensor_stride'] = 8
return batch_dict
......@@ -2,6 +2,7 @@ from .anchor_head_template import AnchorHeadTemplate
from .anchor_head_single import AnchorHeadSingle
from .point_intra_part_head import PointIntraPartOffsetHead
from .point_head_simple import PointHeadSimple
from .point_head_box import PointHeadBox
from .anchor_head_multi import AnchorHeadMulti
__all__ = {
......@@ -9,5 +10,6 @@ __all__ = {
'AnchorHeadSingle': AnchorHeadSingle,
'PointIntraPartOffsetHead': PointIntraPartOffsetHead,
'PointHeadSimple': PointHeadSimple,
'PointHeadBox': PointHeadBox,
'AnchorHeadMulti': AnchorHeadMulti,
}
import torch
from .point_head_template import PointHeadTemplate
from ...utils import box_coder_utils, box_utils
class PointHeadBox(PointHeadTemplate):
"""
A simple point-based segmentation head, which are used for PointRCNN.
Reference Paper: https://arxiv.org/abs/1812.04244
PointRCNN: 3D Object Proposal Generation and Detection from Point Cloud
"""
def __init__(self, num_class, input_channels, model_cfg, predict_boxes_when_training=False, **kwargs):
super().__init__(model_cfg=model_cfg, num_class=num_class)
self.predict_boxes_when_training = predict_boxes_when_training
self.cls_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.CLS_FC,
input_channels=input_channels,
output_channels=num_class
)
target_cfg = self.model_cfg.TARGET_CONFIG
self.box_coder = getattr(box_coder_utils, target_cfg.BOX_CODER)(
**target_cfg.BOX_CODER_CONFIG
)
self.box_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.REG_FC,
input_channels=input_channels,
output_channels=self.box_coder.code_size
)
def assign_targets(self, input_dict):
"""
Args:
input_dict:
point_features: (N1 + N2 + N3 + ..., C)
batch_size:
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
gt_boxes (optional): (B, M, 8)
Returns:
point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored
point_part_labels: (N1 + N2 + N3 + ..., 3)
"""
point_coords = input_dict['point_coords']
gt_boxes = input_dict['gt_boxes']
assert gt_boxes.shape.__len__() == 3, 'gt_boxes.shape=%s' % str(gt_boxes.shape)
assert point_coords.shape.__len__() in [2], 'points.shape=%s' % str(point_coords.shape)
batch_size = gt_boxes.shape[0]
extend_gt_boxes = box_utils.enlarge_box3d(
gt_boxes.view(-1, gt_boxes.shape[-1]), extra_width=self.model_cfg.TARGET_CONFIG.GT_EXTRA_WIDTH
).view(batch_size, -1, gt_boxes.shape[-1])
targets_dict = self.assign_stack_targets(
points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes,
set_ignore_flag=True, use_ball_constraint=False,
ret_part_labels=False, ret_box_labels=True
)
return targets_dict
def get_loss(self, tb_dict=None):
tb_dict = {} if tb_dict is None else tb_dict
point_loss_cls, tb_dict_1 = self.get_cls_layer_loss()
point_loss_box, tb_dict_2 = self.get_box_layer_loss()
point_loss = point_loss_cls + point_loss_box
tb_dict.update(tb_dict_1)
tb_dict.update(tb_dict_2)
return point_loss, tb_dict
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
point_features: (N1 + N2 + N3 + ..., C) or (B, N, C)
point_features_before_fusion: (N1 + N2 + N3 + ..., C)
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
point_labels (optional): (N1 + N2 + N3 + ...)
gt_boxes (optional): (B, M, 8)
Returns:
batch_dict:
point_cls_scores: (N1 + N2 + N3 + ..., 1)
point_part_offset: (N1 + N2 + N3 + ..., 3)
"""
if self.model_cfg.get('USE_POINT_FEATURES_BEFORE_FUSION', False):
point_features = batch_dict['point_features_before_fusion']
else:
point_features = batch_dict['point_features']
point_cls_preds = self.cls_layers(point_features) # (total_points, num_class)
point_box_preds = self.box_layers(point_features) # (total_points, box_code_size)
point_cls_preds_max, _ = point_cls_preds.max(dim=-1)
batch_dict['point_cls_scores'] = torch.sigmoid(point_cls_preds_max)
ret_dict = {'point_cls_preds': point_cls_preds,
'point_box_preds': point_box_preds}
if self.training:
targets_dict = self.assign_targets(batch_dict)
ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']
ret_dict['point_box_labels'] = targets_dict['point_box_labels']
if not self.training or self.predict_boxes_when_training:
point_cls_preds, point_box_preds = self.generate_predicted_boxes(
points=batch_dict['point_coords'][:, 1:4],
point_cls_preds=point_cls_preds, point_box_preds=point_box_preds
)
batch_dict['batch_cls_preds'] = point_cls_preds
batch_dict['batch_box_preds'] = point_box_preds
batch_dict['batch_index'] = batch_dict['point_coords'][:, 0]
batch_dict['cls_preds_normalized'] = False
self.forward_ret_dict = ret_dict
return batch_dict
......@@ -19,7 +19,17 @@ class PointHeadTemplate(nn.Module):
'cls_loss_func',
loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
)
self.reg_loss_func = F.smooth_l1_loss if losses_cfg.get('LOSS_REG', None) == 'smooth-l1' else F.l1_loss
reg_loss_type = losses_cfg.get('LOSS_REG', None)
if reg_loss_type == 'smooth-l1':
self.reg_loss_func = F.smooth_l1_loss
elif reg_loss_type == 'l1':
self.reg_loss_func = F.l1_loss
elif reg_loss_type == 'WeightedSmoothL1Loss':
self.reg_loss_func = loss_utils.WeightedSmoothL1Loss(
code_weights=losses_cfg.LOSS_WEIGHTS.get('code_weights', None)
)
else:
self.reg_loss_func = F.smooth_l1_loss
@staticmethod
def make_fc_layers(fc_cfg, input_channels, output_channels):
......@@ -88,11 +98,15 @@ class PointHeadTemplate(nn.Module):
raise NotImplementedError
gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]]
point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, 7].long()
point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, -1].long()
point_cls_labels[bs_mask] = point_cls_labels_single
if ret_box_labels:
point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8))
fg_point_box_labels = self.box_coder.encode_torch(points_single[fg_flag], gt_box_of_fg_points)
fg_point_box_labels = self.box_coder.encode_torch(
gt_boxes=gt_box_of_fg_points[:, :-1], points=points_single[fg_flag],
gt_classes=gt_box_of_fg_points[:, -1].long()
)
point_box_labels_single[fg_flag] = fg_point_box_labels
point_box_labels[bs_mask] = point_box_labels_single
......@@ -113,7 +127,7 @@ class PointHeadTemplate(nn.Module):
}
return targets_dict
def get_cls_layer_loss(self):
def get_cls_layer_loss(self, tb_dict=None):
point_cls_labels = self.forward_ret_dict['point_cls_labels'].view(-1)
point_cls_preds = self.forward_ret_dict['point_cls_preds'].view(-1, self.num_class)
......@@ -131,13 +145,15 @@ class PointHeadTemplate(nn.Module):
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_cls = point_loss_cls * loss_weights_dict['point_cls_weight']
tb_dict = {
if tb_dict is None:
tb_dict = {}
tb_dict.update({
'point_loss_cls': point_loss_cls.item(),
'point_pos_num': pos_normalizer.item()
}
})
return point_loss_cls, tb_dict
def get_part_layer_loss(self):
def get_part_layer_loss(self, tb_dict=None):
pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
pos_normalizer = max(1, (pos_mask > 0).sum().item())
point_part_labels = self.forward_ret_dict['point_part_labels']
......@@ -147,7 +163,47 @@ class PointHeadTemplate(nn.Module):
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_part = point_loss_part * loss_weights_dict['point_part_weight']
return point_loss_part, {'point_loss_part': point_loss_part.item()}
if tb_dict is None:
tb_dict = {}
tb_dict.update({'point_loss_part': point_loss_part.item()})
return point_loss_part, tb_dict
def get_box_layer_loss(self, tb_dict=None):
pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
point_box_labels = self.forward_ret_dict['point_box_labels']
point_box_preds = self.forward_ret_dict['point_box_preds']
reg_weights = pos_mask.float()
pos_normalizer = pos_mask.sum().float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0)
point_loss_box_src = self.reg_loss_func(
point_box_preds[None, ...], point_box_labels[None, ...], weights=reg_weights[None, ...]
)
point_loss_box = point_loss_box_src.sum()
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_box = point_loss_box * loss_weights_dict['point_box_weight']
if tb_dict is None:
tb_dict = {}
tb_dict.update({'point_loss_box': point_loss_box.item()})
return point_loss_box, tb_dict
def generate_predicted_boxes(self, points, point_cls_preds, point_box_preds):
"""
Args:
points: (N, 3)
point_cls_preds: (N, num_class)
point_box_preds: (N, box_code_size)
Returns:
point_cls_preds: (N, num_class)
point_box_preds: (N, box_code_size)
"""
_, pred_classes = point_cls_preds.max(dim=-1)
point_box_preds = self.box_coder.decode_torch(point_box_preds, points, pred_classes + 1)
return point_cls_preds, point_box_preds
def forward(self, **kwargs):
raise NotImplementedError
import torch
from .point_head_template import PointHeadTemplate
from ...utils import box_utils
from ...utils import box_utils, box_coder_utils
class PointIntraPartOffsetHead(PointHeadTemplate):
......@@ -9,8 +9,9 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
Reference Paper: https://arxiv.org/abs/1907.03670
From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network
"""
def __init__(self, num_class, input_channels, model_cfg, **kwargs):
def __init__(self, num_class, input_channels, model_cfg, predict_boxes_when_training=False, **kwargs):
super().__init__(model_cfg=model_cfg, num_class=num_class)
self.predict_boxes_when_training = predict_boxes_when_training
self.cls_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.CLS_FC,
input_channels=input_channels,
......@@ -21,6 +22,18 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
input_channels=input_channels,
output_channels=3
)
target_cfg = self.model_cfg.TARGET_CONFIG
if target_cfg.get('BOX_CODER', None) is not None:
self.box_coder = getattr(box_coder_utils, target_cfg.BOX_CODER)(
**target_cfg.BOX_CODER_CONFIG
)
self.box_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.REG_FC,
input_channels=input_channels,
output_channels=self.box_coder.code_size
)
else:
self.box_layers = None
def assign_targets(self, input_dict):
"""
......@@ -46,19 +59,20 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
targets_dict = self.assign_stack_targets(
points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes,
set_ignore_flag=True, use_ball_constraint=False,
ret_part_labels=True
ret_part_labels=True, ret_box_labels=(self.box_layers is not None)
)
return targets_dict
def get_loss(self, tb_dict=None):
tb_dict = {} if tb_dict is None else tb_dict
point_loss_cls, tb_dict_1 = self.get_cls_layer_loss()
point_loss_part, tb_dict_2 = self.get_part_layer_loss()
point_loss_cls, tb_dict = self.get_cls_layer_loss(tb_dict)
point_loss_part, tb_dict = self.get_part_layer_loss(tb_dict)
point_loss = point_loss_cls + point_loss_part
tb_dict.update(tb_dict_1)
tb_dict.update(tb_dict_2)
if self.box_layers is not None:
point_loss_box, tb_dict = self.get_box_layer_loss(tb_dict)
point_loss += point_loss_box
return point_loss, tb_dict
def forward(self, batch_dict):
......@@ -83,6 +97,9 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
'point_cls_preds': point_cls_preds,
'point_part_preds': point_part_preds,
}
if self.box_layers is not None:
point_box_preds = self.box_layers(point_features)
ret_dict['point_box_preds'] = point_box_preds
point_cls_scores = torch.sigmoid(point_cls_preds)
point_part_offset = torch.sigmoid(point_part_preds)
......@@ -93,6 +110,17 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
targets_dict = self.assign_targets(batch_dict)
ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']
ret_dict['point_part_labels'] = targets_dict.get('point_part_labels')
self.forward_ret_dict = ret_dict
ret_dict['point_box_labels'] = targets_dict.get('point_box_labels')
if self.box_layers is not None and (not self.training or self.predict_boxes_when_training):
point_cls_preds, point_box_preds = self.generate_predicted_boxes(
points=batch_dict['point_coords'][:, 1:4],
point_cls_preds=point_cls_preds, point_box_preds=ret_dict['point_box_preds']
)
batch_dict['batch_cls_preds'] = point_cls_preds
batch_dict['batch_box_preds'] = point_box_preds
batch_dict['batch_index'] = batch_dict['point_coords'][:, 0]
batch_dict['cls_preds_normalized'] = False
self.forward_ret_dict = ret_dict
return batch_dict
......@@ -3,13 +3,15 @@ from .second_net import SECONDNet
from .PartA2_net import PartA2Net
from .pv_rcnn import PVRCNN
from .pointpillar import PointPillar
from .point_rcnn import PointRCNN
__all__ = {
'Detector3DTemplate': Detector3DTemplate,
'SECONDNet': SECONDNet,
'PartA2Net': PartA2Net,
'PVRCNN': PVRCNN,
'PointPillar': PointPillar
'PointPillar': PointPillar,
'PointRCNN': PointRCNN
}
......
from .detector3d_template import Detector3DTemplate
class PointRCNN(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self):
disp_dict = {}
loss_point, tb_dict = self.point_head.get_loss()
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)
loss = loss_point + loss_rcnn
return loss, tb_dict, disp_dict
from .roi_head_template import RoIHeadTemplate
from .partA2_head import PartA2FCHead
from .pvrcnn_head import PVRCNNHead
from .pointrcnn_head import PointRCNNHead
__all__ = {
'RoIHeadTemplate': RoIHeadTemplate,
'PartA2FCHead': PartA2FCHead,
'PVRCNNHead': PVRCNNHead
'PVRCNNHead': PVRCNNHead,
'PointRCNNHead': PointRCNNHead
}
......@@ -118,7 +118,8 @@ class PartA2FCHead(RoIHeadTemplate):
point_coords = batch_dict['point_coords'][:, 1:4]
point_features = batch_dict['point_features']
part_features = torch.cat((
batch_dict['point_part_offset'], batch_dict['point_cls_scores'].view(-1, 1).detach()
batch_dict['point_part_offset'] if not self.model_cfg.get('DISABLE_PART', False) else point_coords,
batch_dict['point_cls_scores'].view(-1, 1).detach()
), dim=1)
part_features[part_features[:, -1] < self.model_cfg.SEG_MASK_SCORE_THRESH, 0:3] = 0
......
import torch
import torch.nn as nn
from .roi_head_template import RoIHeadTemplate
from ...ops.pointnet2.pointnet2_batch import pointnet2_modules
from ...ops.roipoint_pool3d import roipoint_pool3d_utils
from ...utils import common_utils
class PointRCNNHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg
use_bn = self.model_cfg.USE_BN
self.SA_modules = nn.ModuleList()
channel_in = input_channels
self.num_prefix_channels = 3 + 2 # xyz + point_scores + point_depth
xyz_mlps = [self.num_prefix_channels] + self.model_cfg.XYZ_UP_LAYER
shared_mlps = []
for k in range(len(xyz_mlps) - 1):
shared_mlps.append(nn.Conv2d(xyz_mlps[k], xyz_mlps[k + 1], kernel_size=1, bias=not use_bn))
if use_bn:
shared_mlps.append(nn.BatchNorm2d(xyz_mlps[k + 1]))
shared_mlps.append(nn.ReLU())
self.xyz_up_layer = nn.Sequential(*shared_mlps)
c_out = self.model_cfg.XYZ_UP_LAYER[-1]
self.merge_down_layer = nn.Sequential(
nn.Conv2d(c_out * 2, c_out, kernel_size=1, bias=not use_bn),
*[nn.BatchNorm2d(c_out), nn.ReLU()] if use_bn else [nn.ReLU()]
)
for k in range(self.model_cfg.SA_CONFIG.NPOINTS.__len__()):
mlps = [channel_in] + self.model_cfg.SA_CONFIG.MLPS[k]
npoint = self.model_cfg.SA_CONFIG.NPOINTS[k] if self.model_cfg.SA_CONFIG.NPOINTS[k] != -1 else None
self.SA_modules.append(
pointnet2_modules.PointnetSAModule(
npoint=npoint,
radius=self.model_cfg.SA_CONFIG.RADIUS[k],
nsample=self.model_cfg.SA_CONFIG.NSAMPLE[k],
mlp=mlps,
use_xyz=True,
bn=use_bn
)
)
channel_in = mlps[-1]
self.cls_layers = self.make_fc_layers(
input_channels=channel_in, output_channels=self.num_class, fc_list=self.model_cfg.CLS_FC
)
self.reg_layers = self.make_fc_layers(
input_channels=channel_in,
output_channels=self.box_coder.code_size * self.num_class,
fc_list=self.model_cfg.REG_FC
)
self.roipoint_pool3d_layer = roipoint_pool3d_utils.RoIPointPool3d(
num_sampled_points=self.model_cfg.ROI_POINT_POOL.NUM_SAMPLED_POINTS,
pool_extra_width=self.model_cfg.ROI_POINT_POOL.POOL_EXTRA_WIDTH
)
self.init_weights(weight_init='xavier')
def init_weights(self, weight_init='xavier'):
if weight_init == 'kaiming':
init_func = nn.init.kaiming_normal_
elif weight_init == 'xavier':
init_func = nn.init.xavier_normal_
elif weight_init == 'normal':
init_func = nn.init.normal_
else:
raise NotImplementedError
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
if weight_init == 'normal':
init_func(m.weight, mean=0, std=0.001)
else:
init_func(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.reg_layers[-1].weight, mean=0, std=0.001)
def roipool3d_gpu(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
rois: (B, num_rois, 7 + C)
point_coords: (num_points, 4) [bs_idx, x, y, z]
point_features: (num_points, C)
point_cls_scores: (N1 + N2 + N3 + ..., 1)
point_part_offset: (N1 + N2 + N3 + ..., 3)
Returns:
"""
batch_size = batch_dict['batch_size']
batch_idx = batch_dict['point_coords'][:, 0]
point_coords = batch_dict['point_coords'][:, 1:4]
point_features = batch_dict['point_features']
rois = batch_dict['rois'] # (B, num_rois, 7 + C)
batch_cnt = point_coords.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()
assert batch_cnt.min() == batch_cnt.max()
point_scores = batch_dict['point_cls_scores'].detach()
point_depths = point_coords.norm(dim=1) / self.model_cfg.ROI_POINT_POOL.DEPTH_NORMALIZER - 0.5
point_features_list = [point_scores[:, None], point_depths[:, None], point_features]
point_features_all = torch.cat(point_features_list, dim=1)
batch_points = point_coords.view(batch_size, -1, 3)
batch_point_features = point_features_all.view(batch_size, -1, point_features_all.shape[-1])
with torch.no_grad():
pooled_features, pooled_empty_flag = self.roipoint_pool3d_layer(
batch_points, batch_point_features, rois
) # pooled_features: (B, num_rois, num_sampled_points, 3 + C), pooled_empty_flag: (B, num_rois)
# canonical transformation
roi_center = rois[:, :, 0:3]
pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
pooled_features = pooled_features.view(-1, pooled_features.shape[-2], pooled_features.shape[-1])
pooled_features[:, :, 0:3] = common_utils.rotate_points_along_z(
pooled_features[:, :, 0:3], -rois.view(-1, rois.shape[-1])[:, 6]
)
pooled_features[pooled_empty_flag.view(-1) > 0] = 0
return pooled_features
def forward(self, batch_dict):
"""
Args:
batch_dict:
Returns:
"""
targets_dict = self.proposal_layer(
batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
)
if self.training:
targets_dict = self.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']
pooled_features = self.roipool3d_gpu(batch_dict) # (total_rois, num_sampled_points, 3 + C)
xyz_input = pooled_features[..., 0:self.num_prefix_channels].transpose(1, 2).unsqueeze(dim=3)
xyz_features = self.xyz_up_layer(xyz_input)
point_features = pooled_features[..., self.num_prefix_channels:].transpose(1, 2).unsqueeze(dim=3)
merged_features = torch.cat((xyz_features, point_features), dim=1)
merged_features = self.merge_down_layer(merged_features)
l_xyz, l_features = [pooled_features[..., 0:3].contiguous()], [merged_features.squeeze(dim=3).contiguous()]
for i in range(len(self.SA_modules)):
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
l_xyz.append(li_xyz)
l_features.append(li_features)
shared_features = l_features[-1] # (total_rois, num_features, 1)
rcnn_cls = self.cls_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B, 1 or 2)
rcnn_reg = self.reg_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B, C)
if not self.training:
batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
batch_size=batch_dict['batch_size'], rois=batch_dict['rois'], cls_preds=rcnn_cls, box_preds=rcnn_reg
)
batch_dict['batch_cls_preds'] = batch_cls_preds
batch_dict['batch_box_preds'] = batch_box_preds
batch_dict['cls_preds_normalized'] = False
else:
targets_dict['rcnn_cls'] = rcnn_cls
targets_dict['rcnn_reg'] = rcnn_reg
self.forward_ret_dict = targets_dict
return batch_dict
......@@ -92,6 +92,7 @@ class RoIHeadTemplate(nn.Module):
batch_dict['roi_scores'] = roi_scores
batch_dict['roi_labels'] = roi_labels + 1
batch_dict['has_class_labels'] = True if batch_cls_preds.shape[-1] > 1 else False
batch_dict.pop('batch_index', None)
return batch_dict
def assign_targets(self, batch_dict):
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import pointnet2_utils
from typing import List
class _PointnetSAModuleBase(nn.Module):
def __init__(self):
super().__init__()
self.npoint = None
self.groupers = None
self.mlps = None
self.pool_method = 'max_pool'
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
"""
:param xyz: (B, N, 3) tensor of the xyz coordinates of the features
:param features: (B, N, C) tensor of the descriptors of the the features
:param new_xyz:
:return:
new_xyz: (B, npoint, 3) tensor of the new features' xyz
new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
"""
new_features_list = []
xyz_flipped = xyz.transpose(1, 2).contiguous()
if new_xyz is None:
new_xyz = pointnet2_utils.gather_operation(
xyz_flipped,
pointnet2_utils.furthest_point_sample(xyz, self.npoint)
).transpose(1, 2).contiguous() if self.npoint is not None else None
for i in range(len(self.groupers)):
new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
if self.pool_method == 'max_pool':
new_features = F.max_pool2d(
new_features, kernel_size=[1, new_features.size(3)]
) # (B, mlp[-1], npoint, 1)
elif self.pool_method == 'avg_pool':
new_features = F.avg_pool2d(
new_features, kernel_size=[1, new_features.size(3)]
) # (B, mlp[-1], npoint, 1)
else:
raise NotImplementedError
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
new_features_list.append(new_features)
return new_xyz, torch.cat(new_features_list, dim=1)
class PointnetSAModuleMSG(_PointnetSAModuleBase):
"""Pointnet set abstraction layer with multiscale grouping"""
def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
use_xyz: bool = True, pool_method='max_pool'):
"""
:param npoint: int
:param radii: list of float, list of radii to group with
:param nsamples: list of int, number of samples in each ball query
:param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
:param bn: whether to use batchnorm
:param use_xyz:
:param pool_method: max_pool / avg_pool
"""
super().__init__()
assert len(radii) == len(nsamples) == len(mlps)
self.npoint = npoint
self.groupers = nn.ModuleList()
self.mlps = nn.ModuleList()
for i in range(len(radii)):
radius = radii[i]
nsample = nsamples[i]
self.groupers.append(
pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
)
mlp_spec = mlps[i]
if use_xyz:
mlp_spec[0] += 3
shared_mlps = []
for k in range(len(mlp_spec) - 1):
shared_mlps.extend([
nn.Conv2d(mlp_spec[k], mlp_spec[k + 1], kernel_size=1, bias=False),
nn.BatchNorm2d(mlp_spec[k + 1]),
nn.ReLU()
])
self.mlps.append(nn.Sequential(*shared_mlps))
self.pool_method = pool_method
class PointnetSAModule(PointnetSAModuleMSG):
"""Pointnet set abstraction layer"""
def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
bn: bool = True, use_xyz: bool = True, pool_method='max_pool'):
"""
:param mlp: list of int, spec of the pointnet before the global max_pool
:param npoint: int, number of features
:param radius: float, radius of ball
:param nsample: int, number of samples in the ball query
:param bn: whether to use batchnorm
:param use_xyz:
:param pool_method: max_pool / avg_pool
"""
super().__init__(
mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
pool_method=pool_method
)
class PointnetFPModule(nn.Module):
r"""Propigates the features of one set to another"""
def __init__(self, *, mlp: List[int], bn: bool = True):
"""
:param mlp: list of int
:param bn: whether to use batchnorm
"""
super().__init__()
shared_mlps = []
for k in range(len(mlp) - 1):
shared_mlps.extend([
nn.Conv2d(mlp[k], mlp[k + 1], kernel_size=1, bias=False),
nn.BatchNorm2d(mlp[k + 1]),
nn.ReLU()
])
self.mlp = nn.Sequential(*shared_mlps)
def forward(
self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
) -> torch.Tensor:
"""
:param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
:param known: (B, m, 3) tensor of the xyz positions of the known features
:param unknow_feats: (B, C1, n) tensor of the features to be propigated to
:param known_feats: (B, C2, m) tensor of features to be propigated
:return:
new_features: (B, mlp[-1], n) tensor of the features of the unknown features
"""
if known is not None:
dist, idx = pointnet2_utils.three_nn(unknown, known)
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
else:
interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
if unknow_feats is not None:
new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)
else:
new_features = interpolated_feats
new_features = new_features.unsqueeze(-1)
new_features = self.mlp(new_features)
return new_features.squeeze(-1)
if __name__ == "__main__":
pass
import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn as nn
from typing import Tuple
from . import pointnet2_batch_cuda as pointnet2
class FurthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
"""
Uses iterative furthest point sampling to select a set of npoint features that have the largest
minimum distance
:param ctx:
:param xyz: (B, N, 3) where N > npoint
:param npoint: int, number of features in the sampled set
:return:
output: (B, npoint) tensor containing the set
"""
assert xyz.is_contiguous()
B, N, _ = xyz.size()
output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output
@staticmethod
def backward(xyz, a=None):
return None, None
furthest_point_sample = FurthestPointSampling.apply
class GatherOperation(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param features: (B, C, N)
:param idx: (B, npoint) index tensor of the features to gather
:return:
output: (B, C, npoint)
"""
assert features.is_contiguous()
assert idx.is_contiguous()
B, npoint = idx.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, npoint)
pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
ctx.for_backwards = (idx, C, N)
return output
@staticmethod
def backward(ctx, grad_out):
idx, C, N = ctx.for_backwards
B, npoint = idx.size()
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
return grad_features, None
gather_operation = GatherOperation.apply
class ThreeNN(Function):
@staticmethod
def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Find the three nearest neighbors of unknown in known
:param ctx:
:param unknown: (B, N, 3)
:param known: (B, M, 3)
:return:
dist: (B, N, 3) l2 distance to the three nearest neighbors
idx: (B, N, 3) index of 3 nearest neighbors
"""
assert unknown.is_contiguous()
assert known.is_contiguous()
B, N, _ = unknown.size()
m = known.size(1)
dist2 = torch.cuda.FloatTensor(B, N, 3)
idx = torch.cuda.IntTensor(B, N, 3)
pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
return torch.sqrt(dist2), idx
@staticmethod
def backward(ctx, a=None, b=None):
return None, None
three_nn = ThreeNN.apply
class ThreeInterpolate(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Performs weight linear interpolation on 3 features
:param ctx:
:param features: (B, C, M) Features descriptors to be interpolated from
:param idx: (B, n, 3) three nearest neighbors of the target features in features
:param weight: (B, n, 3) weights
:return:
output: (B, C, N) tensor of the interpolated features
"""
assert features.is_contiguous()
assert idx.is_contiguous()
assert weight.is_contiguous()
B, c, m = features.size()
n = idx.size(1)
ctx.three_interpolate_for_backward = (idx, weight, m)
output = torch.cuda.FloatTensor(B, c, n)
pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
:param ctx:
:param grad_out: (B, C, N) tensor with gradients of outputs
:return:
grad_features: (B, C, M) tensor with gradients of features
None:
None:
"""
idx, weight, m = ctx.three_interpolate_for_backward
B, c, n = grad_out.size()
grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
return grad_features, None, None
three_interpolate = ThreeInterpolate.apply
class GroupingOperation(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param features: (B, C, N) tensor of features to group
:param idx: (B, npoint, nsample) tensor containing the indicies of features to group with
:return:
output: (B, C, npoint, nsample) tensor
"""
assert features.is_contiguous()
assert idx.is_contiguous()
B, nfeatures, nsample = idx.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
ctx.for_backwards = (idx, N)
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param ctx:
:param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
:return:
grad_features: (B, C, N) gradient of the features
"""
idx, N = ctx.for_backwards
B, C, npoint, nsample = grad_out.size()
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
return grad_features, None
grouping_operation = GroupingOperation.apply
class BallQuery(Function):
@staticmethod
def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param radius: float, radius of the balls
:param nsample: int, maximum number of features in the balls
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centers of the ball query
:return:
idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls
"""
assert new_xyz.is_contiguous()
assert xyz.is_contiguous()
B, N, _ = xyz.size()
npoint = new_xyz.size(1)
idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
return idx
@staticmethod
def backward(ctx, a=None):
return None, None, None, None
ball_query = BallQuery.apply
class QueryAndGroup(nn.Module):
def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
"""
:param radius: float, radius of ball
:param nsample: int, maximum number of features to gather in the ball
:param use_xyz:
"""
super().__init__()
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
"""
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centroids
:param features: (B, C, N) descriptors of the features
:return:
new_features: (B, 3 + C, npoint, nsample)
"""
idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
xyz_trans = xyz.transpose(1, 2).contiguous()
grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
if features is not None:
grouped_features = grouping_operation(features, idx)
if self.use_xyz:
new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample)
else:
new_features = grouped_features
else:
assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
new_features = grouped_xyz
return new_features
class GroupAll(nn.Module):
def __init__(self, use_xyz: bool = True):
super().__init__()
self.use_xyz = use_xyz
def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
"""
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: ignored
:param features: (B, C, N) descriptors of the features
:return:
new_features: (B, C + 3, 1, N)
"""
grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
if features is not None:
grouped_features = features.unsqueeze(2)
if self.use_xyz:
new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N)
else:
new_features = grouped_features
else:
new_features = grouped_xyz
return new_features
/*
batch version of ball query, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "ball_query_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
CHECK_INPUT(new_xyz_tensor);
CHECK_INPUT(xyz_tensor);
const float *new_xyz = new_xyz_tensor.data<float>();
const float *xyz = xyz_tensor.data<float>();
int *idx = idx_tensor.data<int>();
cudaStream_t stream = THCState_getCurrentStream(state);
ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
return 1;
}
\ No newline at end of file
/*
batch version of ball query, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "ball_query_gpu.h"
#include "cuda_utils.h"
__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;
new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample;
float radius2 = radius * radius;
float new_x = new_xyz[0];
float new_y = new_xyz[1];
float new_z = new_xyz[2];
int cnt = 0;
for (int k = 0; k < n; ++k) {
float x = xyz[k * 3 + 0];
float y = xyz[k * 3 + 1];
float z = xyz[k * 3 + 2];
float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
if (d2 < radius2){
if (cnt == 0){
for (int l = 0; l < nsample; ++l) {
idx[l] = k;
}
}
idx[cnt] = k;
++cnt;
if (cnt >= nsample) break;
}
}
}
void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
cudaError_t err;
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
ball_query_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
\ No newline at end of file
#ifndef _BALL_QUERY_GPU_H
#define _BALL_QUERY_GPU_H
#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);
void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream);
#endif
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H
#include <cmath>
#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
inline int opt_n_threads(int work_size) {
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
return max(min(1 << pow_2, TOTAL_THREADS), 1);
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment