"dynemo.code-workspace" did not exist on "695754aeb1e541839c38da7658fb8d71dbc27ae9"
Unverified Commit 32567b04 authored by Shaoshuai Shi, committed by GitHub

Merge pull request #192 from sshaoshuai/master

Release OpenPCDet v0.3.0
parents 853b759b 04e0d4f0
import numpy as np
import torch

from ....ops.iou3d_nms import iou3d_nms_utils
from ....utils import box_utils
class AxisAlignedTargetAssigner(object):
    def __init__(self, model_cfg, class_names, box_coder, match_height=False):
        super().__init__()

        anchor_generator_cfg = model_cfg.ANCHOR_GENERATOR_CONFIG
        anchor_target_cfg = model_cfg.TARGET_ASSIGNER_CONFIG
        self.box_coder = box_coder
        self.match_height = match_height
        self.class_names = np.array(class_names)
@@ -19,8 +23,17 @@ class AxisAlignedTargetAssigner(object):
        for config in anchor_generator_cfg:
            self.matched_thresholds[config['class_name']] = config['matched_threshold']
            self.unmatched_thresholds[config['class_name']] = config['unmatched_threshold']

        self.use_multihead = model_cfg.get('USE_MULTIHEAD', False)
        self.seperate_multihead = model_cfg.get('SEPERATE_MULTIHEAD', False)
        if self.seperate_multihead:
            rpn_head_cfgs = model_cfg.RPN_HEAD_CFGS
            self.gt_remapping = {}
            for rpn_head_cfg in rpn_head_cfgs:
                for idx, name in enumerate(rpn_head_cfg['HEAD_CLS_NAME']):
                    self.gt_remapping[name] = idx + 1

    def assign_targets(self, all_anchors, gt_boxes_with_classes):
        """
        Args:
            all_anchors: [(N, 7), ...]
@@ -30,13 +43,12 @@ class AxisAlignedTargetAssigner(object):
""" """
bbox_targets = [] bbox_targets = []
bbox_src_targets = [] cls_labels = []
cls_labels = []
reg_weights = [] reg_weights = []
batch_size = gt_boxes_with_classes.shape[0] batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, 7] gt_classes = gt_boxes_with_classes[:, :, -1]
gt_boxes = gt_boxes_with_classes[:, :, :7] gt_boxes = gt_boxes_with_classes[:, :, :-1]
for k in range(batch_size): for k in range(batch_size):
cur_gt = gt_boxes[k] cur_gt = gt_boxes[k]
cnt = cur_gt.__len__() - 1 cnt = cur_gt.__len__() - 1
@@ -53,27 +65,36 @@ class AxisAlignedTargetAssigner(object):
                mask = torch.tensor([self.class_names[c - 1] == anchor_class_name
                                     for c in cur_gt_classes], dtype=torch.bool)

                if self.use_multihead:
                    anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
                    if self.seperate_multihead:
                        selected_classes = cur_gt_classes[mask].clone()
                        if len(selected_classes) > 0:
                            new_cls_id = self.gt_remapping[anchor_class_name]
                            selected_classes[:] = new_cls_id
                    else:
                        selected_classes = cur_gt_classes[mask]
                else:
                    feature_map_size = anchors.shape[:3]
                    anchors = anchors.view(-1, anchors.shape[-1])
                    selected_classes = cur_gt_classes[mask]
                single_target = self.assign_targets_single(
                    anchors,
                    cur_gt[mask],
                    gt_classes=selected_classes,
                    matched_threshold=self.matched_thresholds[anchor_class_name],
                    unmatched_threshold=self.unmatched_thresholds[anchor_class_name]
                )
                target_list.append(single_target)

            if self.use_multihead:
                target_dict = {
                    'box_cls_labels': [t['box_cls_labels'].view(-1) for t in target_list],
                    'box_reg_targets': [t['box_reg_targets'].view(-1, self.box_coder.code_size) for t in target_list],
                    'reg_weights': [t['reg_weights'].view(-1) for t in target_list]
                }

                target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=0)
                target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=0).view(-1)
                target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=0).view(-1)
@@ -84,18 +105,19 @@ class AxisAlignedTargetAssigner(object):
                                           for t in target_list],
                    'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list]
                }

                target_dict['box_reg_targets'] = torch.cat(
                    target_dict['box_reg_targets'], dim=-2
                ).view(-1, self.box_coder.code_size)

                target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
                target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)

            bbox_targets.append(target_dict['box_reg_targets'])
            cls_labels.append(target_dict['box_cls_labels'])
            reg_weights.append(target_dict['reg_weights'])

        bbox_targets = torch.stack(bbox_targets, dim=0)
        cls_labels = torch.stack(cls_labels, dim=0)
        reg_weights = torch.stack(reg_weights, dim=0)
        all_targets_dict = {
@@ -115,11 +137,10 @@ class AxisAlignedTargetAssigner(object):
        num_anchors = anchors.shape[0]
        num_gt = gt_boxes.shape[0]

        labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
        gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1

        if len(gt_boxes) > 0 and anchors.shape[0] > 0:
            anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
                if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])
@@ -133,29 +154,29 @@ class AxisAlignedTargetAssigner(object):
            gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, torch.arange(num_gt, device=anchors.device)]
            empty_gt_mask = gt_to_anchor_max == 0
            gt_to_anchor_max[empty_gt_mask] = -1

            anchors_with_max_overlap = (anchor_by_gt_overlap == gt_to_anchor_max).nonzero()[:, 0]
            gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
            labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
            gt_ids[anchors_with_max_overlap] = gt_inds_force.int()

            pos_inds = anchor_to_gt_max >= matched_threshold
            gt_inds_over_thresh = anchor_to_gt_argmax[pos_inds]
            labels[pos_inds] = gt_classes[gt_inds_over_thresh]
            gt_ids[pos_inds] = gt_inds_over_thresh.int()
            bg_inds = (anchor_to_gt_max < unmatched_threshold).nonzero()[:, 0]
        else:
            bg_inds = torch.arange(num_anchors, device=anchors.device)

        fg_inds = (labels > 0).nonzero()[:, 0]

        if self.pos_fraction is not None:
            num_fg = int(self.pos_fraction * self.sample_size)
            if len(fg_inds) > num_fg:
                num_disabled = len(fg_inds) - num_fg
                disable_inds = torch.randperm(len(fg_inds))[:num_disabled]
                labels[disable_inds] = -1
                fg_inds = (labels > 0).nonzero()[:, 0]

            num_bg = self.sample_size - (labels > 0).sum()
            if len(bg_inds) > num_bg:
@@ -176,7 +197,7 @@ class AxisAlignedTargetAssigner(object):
            bbox_targets[fg_inds, :] = self.box_coder.encode_torch(fg_gt_boxes, fg_anchors)

        reg_weights = anchors.new_zeros((num_anchors,))

        if self.norm_by_num_examples:
            num_examples = (labels >= 0).sum()
            num_examples = num_examples if num_examples > 1.0 else 1.0
@@ -190,5 +211,3 @@ class AxisAlignedTargetAssigner(object):
            'reg_weights': reg_weights,
        }
        return ret_dict
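The reworked __init__ above now pulls everything from a single model_cfg. As a reference only (not part of this commit), a minimal, hypothetical config carrying the fields it reads might look like the sketch below; real configs are the YAML files shipped with the repo, and easydict is assumed here only for attribute-style access.

from easydict import EasyDict

# Hypothetical, minimal model_cfg for AxisAlignedTargetAssigner (illustrative only).
model_cfg = EasyDict({
    'ANCHOR_GENERATOR_CONFIG': [
        {'class_name': 'Car', 'matched_threshold': 0.6, 'unmatched_threshold': 0.45},
    ],
    'TARGET_ASSIGNER_CONFIG': {'NAME': 'AxisAlignedTargetAssigner', 'MATCH_HEIGHT': False},
    'USE_MULTIHEAD': False,
    'SEPERATE_MULTIHEAD': False,   # spelled as in the code above
})
# assigner = AxisAlignedTargetAssigner(model_cfg, class_names=['Car'], box_coder=box_coder)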
from .detector3d_template import Detector3DTemplate
from .PartA2_net import PartA2Net
from .point_rcnn import PointRCNN
from .pointpillar import PointPillar
from .pv_rcnn import PVRCNN
from .second_net import SECONDNet

__all__ = {
    'Detector3DTemplate': Detector3DTemplate,
    'SECONDNet': SECONDNet,
    'PartA2Net': PartA2Net,
    'PVRCNN': PVRCNN,
    'PointPillar': PointPillar,
    'PointRCNN': PointRCNN
}
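A rough sketch of how this registry is typically consumed when a detector is built by the NAME field of its config; the actual model_cfg, num_class and dataset objects are assumed to exist elsewhere and are only hinted at in the commented line.

# Illustrative lookup only; the constructor call is commented out because it needs a full config.
detector_class = __all__['PointRCNN']
print(detector_class.__name__)   # PointRCNN
# detector = detector_class(model_cfg=model_cfg, num_class=num_class, dataset=dataset)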
......
import os

import torch
import torch.nn as nn

from ...ops.iou3d_nms import iou3d_nms_utils
from .. import backbones_2d, backbones_3d, dense_heads, roi_heads
from ..backbones_2d import map_to_bev
from ..backbones_3d import pfe, vfe
from ..model_utils import model_nms_utils


class Detector3DTemplate(nn.Module):
@@ -169,10 +171,14 @@ class Detector3DTemplate(nn.Module):
            batch_dict:
                batch_size:
                batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1)
                                or [(B, num_boxes, num_class1), (B, num_boxes, num_class2) ...]
                multihead_label_mapping: [(num_class1), (num_class2), ...]
                batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
                cls_preds_normalized: indicate whether batch_cls_preds is normalized
                batch_index: optional (N1+N2+...)
                has_class_labels: True/False
                roi_labels: (B, num_rois)  1 .. num_classes
                batch_pred_labels: (B, num_boxes, 1)
        Returns:

        """
@@ -182,29 +188,63 @@ class Detector3DTemplate(nn.Module):
        pred_dicts = []
        for index in range(batch_size):
            if batch_dict.get('batch_index', None) is not None:
                assert batch_dict['batch_box_preds'].shape.__len__() == 2
                batch_mask = (batch_dict['batch_index'] == index)
            else:
                assert batch_dict['batch_box_preds'].shape.__len__() == 3
                batch_mask = index

            box_preds = batch_dict['batch_box_preds'][batch_mask]
            src_box_preds = box_preds

            if not isinstance(batch_dict['batch_cls_preds'], list):
                cls_preds = batch_dict['batch_cls_preds'][batch_mask]

                src_cls_preds = cls_preds
                assert cls_preds.shape[1] in [1, self.num_class]

                if not batch_dict['cls_preds_normalized']:
                    cls_preds = torch.sigmoid(cls_preds)
            else:
                cls_preds = [x[batch_mask] for x in batch_dict['batch_cls_preds']]
                src_cls_preds = cls_preds
                if not batch_dict['cls_preds_normalized']:
                    cls_preds = [torch.sigmoid(x) for x in cls_preds]

            if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS:
                if not isinstance(cls_preds, list):
                    cls_preds = [cls_preds]
                    multihead_label_mapping = [torch.arange(1, self.num_class, device=cls_preds[0].device)]
                else:
                    multihead_label_mapping = batch_dict['multihead_label_mapping']

                cur_start_idx = 0
                pred_scores, pred_labels, pred_boxes = [], [], []
                for cur_cls_preds, cur_label_mapping in zip(cls_preds, multihead_label_mapping):
                    assert cur_cls_preds.shape[1] == len(cur_label_mapping)
                    cur_box_preds = box_preds[cur_start_idx: cur_start_idx + cur_cls_preds.shape[0]]
                    cur_pred_scores, cur_pred_labels, cur_pred_boxes = model_nms_utils.multi_classes_nms(
                        cls_scores=cur_cls_preds, box_preds=cur_box_preds,
                        nms_config=post_process_cfg.NMS_CONFIG,
                        score_thresh=post_process_cfg.SCORE_THRESH
                    )
                    cur_pred_labels = cur_label_mapping[cur_pred_labels]
                    pred_scores.append(cur_pred_scores)
                    pred_labels.append(cur_pred_labels)
                    pred_boxes.append(cur_pred_boxes)
                    cur_start_idx += cur_cls_preds.shape[0]

                final_scores = torch.cat(pred_scores, dim=0)
                final_labels = torch.cat(pred_labels, dim=0)
                final_boxes = torch.cat(pred_boxes, dim=0)
            else:
                cls_preds, label_preds = torch.max(cls_preds, dim=-1)
                if batch_dict.get('has_class_labels', False):
                    label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels'
                    label_preds = batch_dict[label_key][index]
                else:
                    label_preds = label_preds + 1
                selected, selected_scores = model_nms_utils.class_agnostic_nms(
                    box_scores=cls_preds, box_preds=box_preds,
                    nms_config=post_process_cfg.NMS_CONFIG,
                    score_thresh=post_process_cfg.SCORE_THRESH
@@ -253,14 +293,14 @@ class Detector3DTemplate(nn.Module):
                k -= 1
            cur_gt = cur_gt[:k + 1]

            if cur_gt.shape[0] > 0:
                if box_preds.shape[0] > 0:
                    iou3d_rcnn = iou3d_nms_utils.boxes_iou3d_gpu(box_preds[:, 0:7], cur_gt[:, 0:7])
                else:
                    iou3d_rcnn = torch.zeros((0, cur_gt.shape[0]))

                if rois is not None:
                    iou3d_roi = iou3d_nms_utils.boxes_iou3d_gpu(rois[:, 0:7], cur_gt[:, 0:7])

                for cur_thresh in thresh_list:
                    if iou3d_rcnn.shape[0] == 0:
......
from .detector3d_template import Detector3DTemplate
class PointRCNN(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self):
disp_dict = {}
loss_point, tb_dict = self.point_head.get_loss()
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)
loss = loss_point + loss_rcnn
return loss, tb_dict, disp_dict
import torch

from ...ops.iou3d_nms import iou3d_nms_utils
@@ -14,7 +15,7 @@ def class_agnostic_nms(box_scores, box_preds, nms_config, score_thresh=None):
        box_scores_nms, indices = torch.topk(box_scores, k=min(nms_config.NMS_PRE_MAXSIZE, box_scores.shape[0]))
        boxes_for_nms = box_preds[indices]
        keep_idx, selected_scores = getattr(iou3d_nms_utils, nms_config.NMS_TYPE)(
            boxes_for_nms[:, 0:7], box_scores_nms, nms_config.NMS_THRESH, **nms_config
        )
        selected = indices[keep_idx[:nms_config.NMS_POST_MAXSIZE]]
@@ -22,3 +23,43 @@ def class_agnostic_nms(box_scores, box_preds, nms_config, score_thresh=None):
        original_idxs = scores_mask.nonzero().view(-1)
        selected = original_idxs[selected]

    return selected, src_box_scores[selected]
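The two closing lines above implement an index-recovery pattern: NMS runs on the score-thresholded subset, and the kept indices are then mapped back to the original box ordering. A made-up, self-contained illustration (all values arbitrary):

import torch

src_box_scores = torch.tensor([0.9, 0.2, 0.8, 0.05, 0.7])
scores_mask = src_box_scores >= 0.5             # hypothetical score_thresh = 0.5
original_idxs = scores_mask.nonzero().view(-1)  # tensor([0, 2, 4])
selected = torch.tensor([0, 2])                 # indices kept by NMS within the subset
selected = original_idxs[selected]              # back to original indices: tensor([0, 4])
print(selected, src_box_scores[selected])       # tensor([0, 4]) tensor([0.9000, 0.7000])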
def multi_classes_nms(cls_scores, box_preds, nms_config, score_thresh=None):
"""
Args:
cls_scores: (N, num_class)
box_preds: (N, 7 + C)
nms_config:
score_thresh:
Returns:
"""
pred_scores, pred_labels, pred_boxes = [], [], []
for k in range(cls_scores.shape[1]):
if score_thresh is not None:
scores_mask = (cls_scores[:, k] >= score_thresh)
box_scores = cls_scores[scores_mask, k]
cur_box_preds = box_preds[scores_mask]
else:
box_scores = cls_scores[:, k]
selected = []
if box_scores.shape[0] > 0:
box_scores_nms, indices = torch.topk(box_scores, k=min(nms_config.NMS_PRE_MAXSIZE, box_scores.shape[0]))
boxes_for_nms = cur_box_preds[indices]
keep_idx, selected_scores = getattr(iou3d_nms_utils, nms_config.NMS_TYPE)(
boxes_for_nms[:, 0:7], box_scores_nms, nms_config.NMS_THRESH, **nms_config
)
selected = indices[keep_idx[:nms_config.NMS_POST_MAXSIZE]]
pred_scores.append(box_scores[selected])
pred_labels.append(box_scores.new_ones(len(selected)).long() * k)
pred_boxes.append(cur_box_preds[selected])
pred_scores = torch.cat(pred_scores, dim=0)
pred_labels = torch.cat(pred_labels, dim=0)
pred_boxes = torch.cat(pred_boxes, dim=0)
return pred_scores, pred_labels, pred_boxes
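multi_classes_nms returns per-head labels as 0-based column indices (the `* k` above); post_processing then remaps them to global class ids through multihead_label_mapping. A toy example of that remapping with made-up values:

import torch

cur_label_mapping = torch.tensor([1, 2, 3])   # hypothetical head covering global classes 1-3
cur_pred_labels = torch.tensor([0, 2, 2, 1])  # column indices returned by multi_classes_nms
print(cur_label_mapping[cur_pred_labels])     # tensor([1, 3, 3, 2])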
from .partA2_head import PartA2FCHead
from .pointrcnn_head import PointRCNNHead
from .pvrcnn_head import PVRCNNHead
from .roi_head_template import RoIHeadTemplate

__all__ = {
    'RoIHeadTemplate': RoIHeadTemplate,
    'PartA2FCHead': PartA2FCHead,
    'PVRCNNHead': PVRCNNHead,
    'PointRCNNHead': PointRCNNHead
}
import numpy as np
import spconv
import torch
import torch.nn as nn

from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from .roi_head_template import RoIHeadTemplate
class PartA2FCHead(RoIHeadTemplate):
@@ -118,7 +119,8 @@ class PartA2FCHead(RoIHeadTemplate):
        point_coords = batch_dict['point_coords'][:, 1:4]
        point_features = batch_dict['point_features']

        part_features = torch.cat((
            batch_dict['point_part_offset'] if not self.model_cfg.get('DISABLE_PART', False) else point_coords,
            batch_dict['point_cls_scores'].view(-1, 1).detach()
        ), dim=1)
        part_features[part_features[:, -1] < self.model_cfg.SEG_MASK_SCORE_THRESH, 0:3] = 0
......
import torch
import torch.nn as nn
from ...ops.pointnet2.pointnet2_batch import pointnet2_modules
from ...ops.roipoint_pool3d import roipoint_pool3d_utils
from ...utils import common_utils
from .roi_head_template import RoIHeadTemplate
class PointRCNNHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg
use_bn = self.model_cfg.USE_BN
self.SA_modules = nn.ModuleList()
channel_in = input_channels
self.num_prefix_channels = 3 + 2 # xyz + point_scores + point_depth
xyz_mlps = [self.num_prefix_channels] + self.model_cfg.XYZ_UP_LAYER
shared_mlps = []
for k in range(len(xyz_mlps) - 1):
shared_mlps.append(nn.Conv2d(xyz_mlps[k], xyz_mlps[k + 1], kernel_size=1, bias=not use_bn))
if use_bn:
shared_mlps.append(nn.BatchNorm2d(xyz_mlps[k + 1]))
shared_mlps.append(nn.ReLU())
self.xyz_up_layer = nn.Sequential(*shared_mlps)
c_out = self.model_cfg.XYZ_UP_LAYER[-1]
self.merge_down_layer = nn.Sequential(
nn.Conv2d(c_out * 2, c_out, kernel_size=1, bias=not use_bn),
*[nn.BatchNorm2d(c_out), nn.ReLU()] if use_bn else [nn.ReLU()]
)
for k in range(self.model_cfg.SA_CONFIG.NPOINTS.__len__()):
mlps = [channel_in] + self.model_cfg.SA_CONFIG.MLPS[k]
npoint = self.model_cfg.SA_CONFIG.NPOINTS[k] if self.model_cfg.SA_CONFIG.NPOINTS[k] != -1 else None
self.SA_modules.append(
pointnet2_modules.PointnetSAModule(
npoint=npoint,
radius=self.model_cfg.SA_CONFIG.RADIUS[k],
nsample=self.model_cfg.SA_CONFIG.NSAMPLE[k],
mlp=mlps,
use_xyz=True,
bn=use_bn
)
)
channel_in = mlps[-1]
self.cls_layers = self.make_fc_layers(
input_channels=channel_in, output_channels=self.num_class, fc_list=self.model_cfg.CLS_FC
)
self.reg_layers = self.make_fc_layers(
input_channels=channel_in,
output_channels=self.box_coder.code_size * self.num_class,
fc_list=self.model_cfg.REG_FC
)
self.roipoint_pool3d_layer = roipoint_pool3d_utils.RoIPointPool3d(
num_sampled_points=self.model_cfg.ROI_POINT_POOL.NUM_SAMPLED_POINTS,
pool_extra_width=self.model_cfg.ROI_POINT_POOL.POOL_EXTRA_WIDTH
)
self.init_weights(weight_init='xavier')
def init_weights(self, weight_init='xavier'):
if weight_init == 'kaiming':
init_func = nn.init.kaiming_normal_
elif weight_init == 'xavier':
init_func = nn.init.xavier_normal_
elif weight_init == 'normal':
init_func = nn.init.normal_
else:
raise NotImplementedError
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
if weight_init == 'normal':
init_func(m.weight, mean=0, std=0.001)
else:
init_func(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.reg_layers[-1].weight, mean=0, std=0.001)
def roipool3d_gpu(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
rois: (B, num_rois, 7 + C)
point_coords: (num_points, 4) [bs_idx, x, y, z]
point_features: (num_points, C)
point_cls_scores: (N1 + N2 + N3 + ..., 1)
point_part_offset: (N1 + N2 + N3 + ..., 3)
Returns:
"""
batch_size = batch_dict['batch_size']
batch_idx = batch_dict['point_coords'][:, 0]
point_coords = batch_dict['point_coords'][:, 1:4]
point_features = batch_dict['point_features']
rois = batch_dict['rois'] # (B, num_rois, 7 + C)
batch_cnt = point_coords.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()
assert batch_cnt.min() == batch_cnt.max()
point_scores = batch_dict['point_cls_scores'].detach()
point_depths = point_coords.norm(dim=1) / self.model_cfg.ROI_POINT_POOL.DEPTH_NORMALIZER - 0.5
point_features_list = [point_scores[:, None], point_depths[:, None], point_features]
point_features_all = torch.cat(point_features_list, dim=1)
batch_points = point_coords.view(batch_size, -1, 3)
batch_point_features = point_features_all.view(batch_size, -1, point_features_all.shape[-1])
with torch.no_grad():
pooled_features, pooled_empty_flag = self.roipoint_pool3d_layer(
batch_points, batch_point_features, rois
) # pooled_features: (B, num_rois, num_sampled_points, 3 + C), pooled_empty_flag: (B, num_rois)
# canonical transformation
roi_center = rois[:, :, 0:3]
pooled_features[:, :, :, 0:3] -= roi_center.unsqueeze(dim=2)
pooled_features = pooled_features.view(-1, pooled_features.shape[-2], pooled_features.shape[-1])
pooled_features[:, :, 0:3] = common_utils.rotate_points_along_z(
pooled_features[:, :, 0:3], -rois.view(-1, rois.shape[-1])[:, 6]
)
pooled_features[pooled_empty_flag.view(-1) > 0] = 0
return pooled_features
def forward(self, batch_dict):
"""
Args:
batch_dict:
Returns:
"""
targets_dict = self.proposal_layer(
batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
)
if self.training:
targets_dict = self.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']
pooled_features = self.roipool3d_gpu(batch_dict) # (total_rois, num_sampled_points, 3 + C)
xyz_input = pooled_features[..., 0:self.num_prefix_channels].transpose(1, 2).unsqueeze(dim=3)
xyz_features = self.xyz_up_layer(xyz_input)
point_features = pooled_features[..., self.num_prefix_channels:].transpose(1, 2).unsqueeze(dim=3)
merged_features = torch.cat((xyz_features, point_features), dim=1)
merged_features = self.merge_down_layer(merged_features)
l_xyz, l_features = [pooled_features[..., 0:3].contiguous()], [merged_features.squeeze(dim=3).contiguous()]
for i in range(len(self.SA_modules)):
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
l_xyz.append(li_xyz)
l_features.append(li_features)
shared_features = l_features[-1] # (total_rois, num_features, 1)
rcnn_cls = self.cls_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B, 1 or 2)
rcnn_reg = self.reg_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B, C)
if not self.training:
batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
batch_size=batch_dict['batch_size'], rois=batch_dict['rois'], cls_preds=rcnn_cls, box_preds=rcnn_reg
)
batch_dict['batch_cls_preds'] = batch_cls_preds
batch_dict['batch_box_preds'] = batch_box_preds
batch_dict['cls_preds_normalized'] = False
else:
targets_dict['rcnn_cls'] = rcnn_cls
targets_dict['rcnn_reg'] = rcnn_reg
self.forward_ret_dict = targets_dict
return batch_dict
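The canonical transformation in roipool3d_gpu above shifts pooled points into each RoI's local frame and rotates them by the negative heading about z. A toy, single-RoI reference of the same idea (not the repo's implementation, which applies common_utils.rotate_points_along_z to batched tensors):

import math
import torch

def canonical_transform_reference(points, roi):
    # points: (P, 3); roi: (7,) as [x, y, z, dx, dy, dz, heading]; names illustrative.
    local = points - roi[0:3]                      # translate into the RoI frame
    angle = -float(roi[6])                         # rotate by the negative heading
    cosa, sina = math.cos(angle), math.sin(angle)
    rot = torch.tensor([[cosa, sina, 0.0],
                        [-sina, cosa, 0.0],
                        [0.0, 0.0, 1.0]])
    return local @ rot

print(canonical_transform_reference(torch.rand(4, 3), torch.tensor([1., 2., 0., 4., 2., 1.5, 0.3])))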
import torch.nn as nn

from ...ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_stack_modules
from ...utils import common_utils
from .roi_head_template import RoIHeadTemplate


class PVRCNNHead(RoIHeadTemplate):
......
@@ -2,9 +2,10 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...utils import box_coder_utils, common_utils, loss_utils
from ..model_utils.model_nms_utils import class_agnostic_nms
from .target_assigner.proposal_target_layer import ProposalTargetLayer


class RoIHeadTemplate(nn.Module):
@@ -12,7 +13,9 @@ class RoIHeadTemplate(nn.Module):
        super().__init__()
        self.model_cfg = model_cfg
        self.num_class = num_class
        self.box_coder = getattr(box_coder_utils, self.model_cfg.TARGET_CONFIG.BOX_CODER)(
            **self.model_cfg.TARGET_CONFIG.get('BOX_CODER_CONFIG', {})
        )
        self.proposal_target_layer = ProposalTargetLayer(roi_sampler_cfg=self.model_cfg.TARGET_CONFIG)
        self.build_losses(self.model_cfg.LOSS_CONFIG)
        self.forward_ret_dict = None
@@ -92,6 +95,7 @@ class RoIHeadTemplate(nn.Module):
        batch_dict['roi_scores'] = roi_scores
        batch_dict['roi_labels'] = roi_labels + 1
        batch_dict['has_class_labels'] = True if batch_cls_preds.shape[-1] > 1 else False
        batch_dict.pop('batch_index', None)
        return batch_dict

    def assign_targets(self, batch_dict):
@@ -252,4 +256,3 @@ class RoIHeadTemplate(nn.Module):
        batch_box_preds[:, 0:3] += roi_xyz
        batch_box_preds = batch_box_preds.view(batch_size, -1, code_size)
        return batch_cls_preds, batch_box_preds
import numpy as np
import torch
import torch.nn as nn

from ....ops.iou3d_nms import iou3d_nms_utils
@@ -118,10 +119,10 @@ class ProposalTargetLayer(nn.Module):
        fg_rois_per_image = int(np.round(self.roi_sampler_cfg.FG_RATIO * self.roi_sampler_cfg.ROI_PER_IMAGE))
        fg_thresh = min(self.roi_sampler_cfg.REG_FG_THRESH, self.roi_sampler_cfg.CLS_FG_THRESH)

        fg_inds = ((max_overlaps >= fg_thresh)).nonzero().view(-1)
        easy_bg_inds = ((max_overlaps < self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1)
        hard_bg_inds = ((max_overlaps < self.roi_sampler_cfg.REG_FG_THRESH) &
                        (max_overlaps >= self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1)

        fg_num_rois = fg_inds.numel()
        bg_num_rois = hard_bg_inds.numel() + easy_bg_inds.numel()
......
@@ -4,8 +4,9 @@ Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
"""
import torch

from ...utils import common_utils
from . import iou3d_nms_cuda


def boxes_bev_iou_cpu(boxes_a, boxes_b):
......
@@ -13,7 +13,19 @@ All Rights Reserved 2020.
#include <cuda_runtime_api.h>
#include "iou3d_cpu.h"

#define CHECK_CUDA(x) do { \
  if (!x.type().is_cuda()) { \
    fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
    exit(-1); \
  } \
} while (0)
#define CHECK_CONTIGUOUS(x) do { \
  if (!x.is_contiguous()) { \
    fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
    exit(-1); \
  } \
} while (0)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
inline float min(float a, float b){
    return a > b ? b : a;
......
@@ -11,8 +11,18 @@ All Rights Reserved 2019-2020.
#include <cuda_runtime_api.h>
#include "iou3d_nms.h"

#define CHECK_CUDA(x) do { \
  if (!x.type().is_cuda()) { \
    fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
    exit(-1); \
  } \
} while (0)
#define CHECK_CONTIGUOUS(x) do { \
  if (!x.is_contiguous()) { \
    fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
    exit(-1); \
  } \
} while (0)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
@@ -80,7 +90,6 @@ int boxes_iou_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_iou
int nms_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){
    // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
    // params keep: (N)
    CHECK_INPUT(boxes);
    CHECK_CONTIGUOUS(keep);
......
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import pointnet2_utils
class _PointnetSAModuleBase(nn.Module):
def __init__(self):
super().__init__()
self.npoint = None
self.groupers = None
self.mlps = None
self.pool_method = 'max_pool'
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
"""
:param xyz: (B, N, 3) tensor of the xyz coordinates of the features
:param features: (B, N, C) tensor of the descriptors of the features
:param new_xyz:
:return:
new_xyz: (B, npoint, 3) tensor of the new features' xyz
new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
"""
new_features_list = []
xyz_flipped = xyz.transpose(1, 2).contiguous()
if new_xyz is None:
new_xyz = pointnet2_utils.gather_operation(
xyz_flipped,
pointnet2_utils.furthest_point_sample(xyz, self.npoint)
).transpose(1, 2).contiguous() if self.npoint is not None else None
for i in range(len(self.groupers)):
new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
if self.pool_method == 'max_pool':
new_features = F.max_pool2d(
new_features, kernel_size=[1, new_features.size(3)]
) # (B, mlp[-1], npoint, 1)
elif self.pool_method == 'avg_pool':
new_features = F.avg_pool2d(
new_features, kernel_size=[1, new_features.size(3)]
) # (B, mlp[-1], npoint, 1)
else:
raise NotImplementedError
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
new_features_list.append(new_features)
return new_xyz, torch.cat(new_features_list, dim=1)
class PointnetSAModuleMSG(_PointnetSAModuleBase):
"""Pointnet set abstraction layer with multiscale grouping"""
def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
use_xyz: bool = True, pool_method='max_pool'):
"""
:param npoint: int
:param radii: list of float, list of radii to group with
:param nsamples: list of int, number of samples in each ball query
:param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
:param bn: whether to use batchnorm
:param use_xyz:
:param pool_method: max_pool / avg_pool
"""
super().__init__()
assert len(radii) == len(nsamples) == len(mlps)
self.npoint = npoint
self.groupers = nn.ModuleList()
self.mlps = nn.ModuleList()
for i in range(len(radii)):
radius = radii[i]
nsample = nsamples[i]
self.groupers.append(
pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
)
mlp_spec = mlps[i]
if use_xyz:
mlp_spec[0] += 3
shared_mlps = []
for k in range(len(mlp_spec) - 1):
shared_mlps.extend([
nn.Conv2d(mlp_spec[k], mlp_spec[k + 1], kernel_size=1, bias=False),
nn.BatchNorm2d(mlp_spec[k + 1]),
nn.ReLU()
])
self.mlps.append(nn.Sequential(*shared_mlps))
self.pool_method = pool_method
class PointnetSAModule(PointnetSAModuleMSG):
"""Pointnet set abstraction layer"""
def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
bn: bool = True, use_xyz: bool = True, pool_method='max_pool'):
"""
:param mlp: list of int, spec of the pointnet before the global max_pool
:param npoint: int, number of features
:param radius: float, radius of ball
:param nsample: int, number of samples in the ball query
:param bn: whether to use batchnorm
:param use_xyz:
:param pool_method: max_pool / avg_pool
"""
super().__init__(
mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
pool_method=pool_method
)
class PointnetFPModule(nn.Module):
r"""Propigates the features of one set to another"""
def __init__(self, *, mlp: List[int], bn: bool = True):
"""
:param mlp: list of int
:param bn: whether to use batchnorm
"""
super().__init__()
shared_mlps = []
for k in range(len(mlp) - 1):
shared_mlps.extend([
nn.Conv2d(mlp[k], mlp[k + 1], kernel_size=1, bias=False),
nn.BatchNorm2d(mlp[k + 1]),
nn.ReLU()
])
self.mlp = nn.Sequential(*shared_mlps)
def forward(
self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
) -> torch.Tensor:
"""
:param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
:param known: (B, m, 3) tensor of the xyz positions of the known features
:param unknow_feats: (B, C1, n) tensor of the features to be propagated to
:param known_feats: (B, C2, m) tensor of features to be propagated
:return:
new_features: (B, mlp[-1], n) tensor of the features of the unknown features
"""
if known is not None:
dist, idx = pointnet2_utils.three_nn(unknown, known)
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
else:
interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
if unknow_feats is not None:
new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)
else:
new_features = interpolated_feats
new_features = new_features.unsqueeze(-1)
new_features = self.mlp(new_features)
return new_features.squeeze(-1)
if __name__ == "__main__":
pass
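PointnetFPModule.forward weights the three nearest known features by inverse distance and normalizes the weights to sum to one. A made-up, unbatched illustration of those two lines (the module itself operates on (B, n, 3) tensors, hence dim=2 there):

import torch

dist = torch.tensor([[0.5, 1.0, 2.0],
                     [0.1, 0.1, 5.0]])              # fake distances to the 3 nearest neighbours
dist_recip = 1.0 / (dist + 1e-8)
weight = dist_recip / dist_recip.sum(dim=1, keepdim=True)
print(weight)                                       # rows sum to 1; closer neighbours dominate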
from typing import Tuple
import torch
import torch.nn as nn
from torch.autograd import Function, Variable
from . import pointnet2_batch_cuda as pointnet2
class FurthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
"""
Uses iterative furthest point sampling to select a set of npoint features that have the largest
minimum distance
:param ctx:
:param xyz: (B, N, 3) where N > npoint
:param npoint: int, number of features in the sampled set
:return:
output: (B, npoint) tensor containing the set
"""
assert xyz.is_contiguous()
B, N, _ = xyz.size()
output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output
@staticmethod
def backward(xyz, a=None):
return None, None
furthest_point_sample = FurthestPointSampling.apply
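The CUDA op above implements greedy farthest point sampling. A plain-PyTorch sketch of the same greedy recursion for a single (N, 3) cloud, assuming (as is conventional) that sampling starts from point 0; this is illustrative only and not a drop-in replacement for the op:

import torch

def furthest_point_sample_reference(xyz, npoint):
    # xyz: (N, 3) -> (npoint,) long indices; reference sketch only.
    n = xyz.shape[0]
    idx = torch.zeros(npoint, dtype=torch.long)
    min_dist = torch.full((n,), 1e10)
    farthest = torch.zeros((), dtype=torch.long)    # start from the first point
    for i in range(npoint):
        idx[i] = farthest
        d = ((xyz - xyz[farthest]) ** 2).sum(dim=-1)
        min_dist = torch.minimum(min_dist, d)       # keep distance to the nearest selected point
        farthest = torch.argmax(min_dist)           # pick the point farthest from all selected ones
    return idx

print(furthest_point_sample_reference(torch.rand(100, 3), 8))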
class GatherOperation(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param features: (B, C, N)
:param idx: (B, npoint) index tensor of the features to gather
:return:
output: (B, C, npoint)
"""
assert features.is_contiguous()
assert idx.is_contiguous()
B, npoint = idx.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, npoint)
pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
ctx.for_backwards = (idx, C, N)
return output
@staticmethod
def backward(ctx, grad_out):
idx, C, N = ctx.for_backwards
B, npoint = idx.size()
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
return grad_features, None
gather_operation = GatherOperation.apply
class ThreeNN(Function):
@staticmethod
def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Find the three nearest neighbors of unknown in known
:param ctx:
:param unknown: (B, N, 3)
:param known: (B, M, 3)
:return:
dist: (B, N, 3) l2 distance to the three nearest neighbors
idx: (B, N, 3) index of 3 nearest neighbors
"""
assert unknown.is_contiguous()
assert known.is_contiguous()
B, N, _ = unknown.size()
m = known.size(1)
dist2 = torch.cuda.FloatTensor(B, N, 3)
idx = torch.cuda.IntTensor(B, N, 3)
pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
return torch.sqrt(dist2), idx
@staticmethod
def backward(ctx, a=None, b=None):
return None, None
three_nn = ThreeNN.apply
class ThreeInterpolate(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Performs weighted linear interpolation on 3 features
:param ctx:
:param features: (B, C, M) Features descriptors to be interpolated from
:param idx: (B, n, 3) three nearest neighbors of the target features in features
:param weight: (B, n, 3) weights
:return:
output: (B, C, N) tensor of the interpolated features
"""
assert features.is_contiguous()
assert idx.is_contiguous()
assert weight.is_contiguous()
B, c, m = features.size()
n = idx.size(1)
ctx.three_interpolate_for_backward = (idx, weight, m)
output = torch.cuda.FloatTensor(B, c, n)
pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
:param ctx:
:param grad_out: (B, C, N) tensor with gradients of outputs
:return:
grad_features: (B, C, M) tensor with gradients of features
None:
None:
"""
idx, weight, m = ctx.three_interpolate_for_backward
B, c, n = grad_out.size()
grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
return grad_features, None, None
three_interpolate = ThreeInterpolate.apply
class GroupingOperation(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param features: (B, C, N) tensor of features to group
:param idx: (B, npoint, nsample) tensor containing the indices of features to group with
:return:
output: (B, C, npoint, nsample) tensor
"""
assert features.is_contiguous()
assert idx.is_contiguous()
B, nfeatures, nsample = idx.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
ctx.for_backwards = (idx, N)
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param ctx:
:param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
:return:
grad_features: (B, C, N) gradient of the features
"""
idx, N = ctx.for_backwards
B, C, npoint, nsample = grad_out.size()
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
return grad_features, None
grouping_operation = GroupingOperation.apply
class BallQuery(Function):
@staticmethod
def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param radius: float, radius of the balls
:param nsample: int, maximum number of features in the balls
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centers of the ball query
:return:
idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
"""
assert new_xyz.is_contiguous()
assert xyz.is_contiguous()
B, N, _ = xyz.size()
npoint = new_xyz.size(1)
idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
return idx
@staticmethod
def backward(ctx, a=None):
return None, None, None, None
ball_query = BallQuery.apply
class QueryAndGroup(nn.Module):
def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
"""
:param radius: float, radius of ball
:param nsample: int, maximum number of features to gather in the ball
:param use_xyz:
"""
super().__init__()
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
"""
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centroids
:param features: (B, C, N) descriptors of the features
:return:
new_features: (B, 3 + C, npoint, nsample)
"""
idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
xyz_trans = xyz.transpose(1, 2).contiguous()
grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
if features is not None:
grouped_features = grouping_operation(features, idx)
if self.use_xyz:
new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample)
else:
new_features = grouped_features
else:
assert self.use_xyz, "use_xyz must be True when no features are provided!"
new_features = grouped_xyz
return new_features
class GroupAll(nn.Module):
def __init__(self, use_xyz: bool = True):
super().__init__()
self.use_xyz = use_xyz
def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
"""
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: ignored
:param features: (B, C, N) descriptors of the features
:return:
new_features: (B, C + 3, 1, N)
"""
grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
if features is not None:
grouped_features = features.unsqueeze(2)
if self.use_xyz:
new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N)
else:
new_features = grouped_features
else:
new_features = grouped_xyz
return new_features
/*
batch version of ball query, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "ball_query_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
exit(-1); \
} \
} while (0)
#define CHECK_CONTIGUOUS(x) do { \
if (!x.is_contiguous()) { \
fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
exit(-1); \
} \
} while (0)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
CHECK_INPUT(new_xyz_tensor);
CHECK_INPUT(xyz_tensor);
const float *new_xyz = new_xyz_tensor.data<float>();
const float *xyz = xyz_tensor.data<float>();
int *idx = idx_tensor.data<int>();
ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx);
return 1;
}
/*
batch version of ball query, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "ball_query_gpu.h"
#include "cuda_utils.h"
__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;
new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample;
float radius2 = radius * radius;
float new_x = new_xyz[0];
float new_y = new_xyz[1];
float new_z = new_xyz[2];
int cnt = 0;
for (int k = 0; k < n; ++k) {
float x = xyz[k * 3 + 0];
float y = xyz[k * 3 + 1];
float z = xyz[k * 3 + 2];
float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
if (d2 < radius2){
if (cnt == 0){
for (int l = 0; l < nsample; ++l) {
idx[l] = k;
}
}
idx[cnt] = k;
++cnt;
if (cnt >= nsample) break;
}
}
}
void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
const float *new_xyz, const float *xyz, int *idx) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
cudaError_t err;
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
ball_query_kernel_fast<<<blocks, threads>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
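For reference, a plain-PyTorch sketch of what ball_query_kernel_fast computes for one batch: up to nsample neighbour indices within radius of each query point, with unused slots pre-filled by the first in-radius neighbour, exactly as in the cnt == 0 branch above. This is illustrative only and not the repo's API:

import torch

def ball_query_reference(radius, nsample, xyz, new_xyz):
    # xyz: (N, 3), new_xyz: (M, 3) -> idx: (M, nsample) long indices.
    dist2 = ((new_xyz[:, None, :] - xyz[None, :, :]) ** 2).sum(dim=-1)   # (M, N) squared distances
    idx = torch.zeros(new_xyz.shape[0], nsample, dtype=torch.long)
    for i in range(new_xyz.shape[0]):
        inside = torch.nonzero(dist2[i] < radius ** 2).view(-1)
        if inside.numel() > 0:
            idx[i] = inside[0]                       # pre-fill every slot with the first neighbour
            k = min(nsample, inside.numel())
            idx[i, :k] = inside[:k]                  # then overwrite with the real neighbours
    return idx

print(ball_query_reference(0.3, 16, torch.rand(200, 3), torch.rand(5, 3)))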
#ifndef _BALL_QUERY_GPU_H
#define _BALL_QUERY_GPU_H
#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);
void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
const float *xyz, const float *new_xyz, int *idx);
#endif
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H
#include <cmath>
#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
inline int opt_n_threads(int work_size) {
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
return max(min(1 << pow_2, TOTAL_THREADS), 1);
}
#endif
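Two helpers above deserve a worked check: DIVUP is integer ceil-division, and opt_n_threads picks the largest power of two not exceeding work_size, clamped to [1, TOTAL_THREADS]. A quick Python sanity check of both formulas:

import math

def divup(m, n):                                     # mirrors #define DIVUP(m,n)
    return m // n + (m % n > 0)

def opt_n_threads(work_size, total_threads=1024):    # mirrors inline opt_n_threads
    pow_2 = int(math.log(work_size) / math.log(2.0))
    return max(min(1 << pow_2, total_threads), 1)

assert divup(1000, 256) == math.ceil(1000 / 256) == 4
assert opt_n_threads(300) == 256                     # largest power of two <= 300, capped at 1024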