Commit cf29d887 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support dense_head PointHeadBox

parent 0b21d8f9
...@@ -2,6 +2,7 @@ from .anchor_head_template import AnchorHeadTemplate ...@@ -2,6 +2,7 @@ from .anchor_head_template import AnchorHeadTemplate
from .anchor_head_single import AnchorHeadSingle from .anchor_head_single import AnchorHeadSingle
from .point_intra_part_head import PointIntraPartOffsetHead from .point_intra_part_head import PointIntraPartOffsetHead
from .point_head_simple import PointHeadSimple from .point_head_simple import PointHeadSimple
from .point_head_box import PointHeadBox
from .anchor_head_multi import AnchorHeadMulti from .anchor_head_multi import AnchorHeadMulti
__all__ = { __all__ = {
...@@ -9,5 +10,6 @@ __all__ = { ...@@ -9,5 +10,6 @@ __all__ = {
'AnchorHeadSingle': AnchorHeadSingle, 'AnchorHeadSingle': AnchorHeadSingle,
'PointIntraPartOffsetHead': PointIntraPartOffsetHead, 'PointIntraPartOffsetHead': PointIntraPartOffsetHead,
'PointHeadSimple': PointHeadSimple, 'PointHeadSimple': PointHeadSimple,
'PointHeadBox': PointHeadBox,
'AnchorHeadMulti': AnchorHeadMulti, 'AnchorHeadMulti': AnchorHeadMulti,
} }
import torch
from .point_head_template import PointHeadTemplate
from ...utils import box_coder_utils, box_utils
class PointHeadBox(PointHeadTemplate):
"""
A simple point-based segmentation head, which are used for PointRCNN.
Reference Paper: https://arxiv.org/abs/1812.04244
PointRCNN: 3D Object Proposal Generation and Detection from Point Cloud
"""
def __init__(self, num_class, input_channels, model_cfg, predict_boxes_when_training=False, **kwargs):
super().__init__(model_cfg=model_cfg, num_class=num_class)
self.predict_boxes_when_training = predict_boxes_when_training
self.cls_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.CLS_FC,
input_channels=input_channels,
output_channels=num_class
)
target_cfg = self.model_cfg.TARGET_CONFIG
self.box_coder = getattr(box_coder_utils, target_cfg.BOX_CODER)(
**target_cfg.BOX_CODER_CONFIG
)
self.box_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.REG_FC,
input_channels=input_channels,
output_channels=self.box_coder.code_size
)
def assign_targets(self, input_dict):
"""
Args:
input_dict:
point_features: (N1 + N2 + N3 + ..., C)
batch_size:
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
gt_boxes (optional): (B, M, 8)
Returns:
point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored
point_part_labels: (N1 + N2 + N3 + ..., 3)
"""
point_coords = input_dict['point_coords']
gt_boxes = input_dict['gt_boxes']
assert gt_boxes.shape.__len__() == 3, 'gt_boxes.shape=%s' % str(gt_boxes.shape)
assert point_coords.shape.__len__() in [2], 'points.shape=%s' % str(point_coords.shape)
batch_size = gt_boxes.shape[0]
extend_gt_boxes = box_utils.enlarge_box3d(
gt_boxes.view(-1, gt_boxes.shape[-1]), extra_width=self.model_cfg.TARGET_CONFIG.GT_EXTRA_WIDTH
).view(batch_size, -1, gt_boxes.shape[-1])
targets_dict = self.assign_stack_targets(
points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes,
set_ignore_flag=True, use_ball_constraint=False,
ret_part_labels=False, ret_box_labels=True
)
return targets_dict
def get_loss(self, tb_dict=None):
tb_dict = {} if tb_dict is None else tb_dict
point_loss_cls, tb_dict_1 = self.get_cls_layer_loss()
point_loss_box, tb_dict_2 = self.get_box_layer_loss()
point_loss = point_loss_cls + point_loss_box
tb_dict.update(tb_dict_1)
tb_dict.update(tb_dict_2)
return point_loss, tb_dict
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
point_features: (N1 + N2 + N3 + ..., C) or (B, N, C)
point_features_before_fusion: (N1 + N2 + N3 + ..., C)
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
point_labels (optional): (N1 + N2 + N3 + ...)
gt_boxes (optional): (B, M, 8)
Returns:
batch_dict:
point_cls_scores: (N1 + N2 + N3 + ..., 1)
point_part_offset: (N1 + N2 + N3 + ..., 3)
"""
if self.model_cfg.get('USE_POINT_FEATURES_BEFORE_FUSION', False):
point_features = batch_dict['point_features_before_fusion']
else:
point_features = batch_dict['point_features']
point_cls_preds = self.cls_layers(point_features) # (total_points, num_class)
point_box_preds = self.box_layers(point_features) # (total_points, box_code_size)
point_cls_preds_max, _ = point_cls_preds.max(dim=-1)
batch_dict['point_cls_scores'] = torch.sigmoid(point_cls_preds_max)
ret_dict = {'point_cls_preds': point_cls_preds,
'point_box_preds': point_box_preds}
if self.training:
targets_dict = self.assign_targets(batch_dict)
ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']
ret_dict['point_box_labels'] = targets_dict['point_box_labels']
if not self.training or self.predict_boxes_when_training:
point_cls_preds, point_box_preds = self.generate_predicted_boxes(
points=batch_dict['point_coords'][:, 1:4],
point_cls_preds=point_cls_preds, point_box_preds=point_box_preds
)
batch_dict['batch_cls_preds'] = point_cls_preds
batch_dict['batch_box_preds'] = point_box_preds
batch_dict['batch_index'] = batch_dict['point_coords'][:, 0]
batch_dict['cls_preds_normalized'] = False
self.forward_ret_dict = ret_dict
return batch_dict
...@@ -19,7 +19,17 @@ class PointHeadTemplate(nn.Module): ...@@ -19,7 +19,17 @@ class PointHeadTemplate(nn.Module):
'cls_loss_func', 'cls_loss_func',
loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0) loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
) )
self.reg_loss_func = F.smooth_l1_loss if losses_cfg.get('LOSS_REG', None) == 'smooth-l1' else F.l1_loss reg_loss_type = losses_cfg.get('LOSS_REG', None)
if reg_loss_type == 'smooth-l1':
self.reg_loss_func = F.smooth_l1_loss
elif reg_loss_type == 'l1':
self.reg_loss_func = F.l1_loss
elif reg_loss_type == 'WeightedSmoothL1Loss':
self.reg_loss_func = loss_utils.WeightedSmoothL1Loss(
code_weights=losses_cfg.LOSS_WEIGHTS.get('code_weights', None)
)
else:
self.reg_loss_func = F.smooth_l1_loss
@staticmethod @staticmethod
def make_fc_layers(fc_cfg, input_channels, output_channels): def make_fc_layers(fc_cfg, input_channels, output_channels):
...@@ -88,11 +98,15 @@ class PointHeadTemplate(nn.Module): ...@@ -88,11 +98,15 @@ class PointHeadTemplate(nn.Module):
raise NotImplementedError raise NotImplementedError
gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]] gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]]
point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, 7].long() point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, -1].long()
point_cls_labels[bs_mask] = point_cls_labels_single point_cls_labels[bs_mask] = point_cls_labels_single
if ret_box_labels: if ret_box_labels:
point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8)) point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8))
fg_point_box_labels = self.box_coder.encode_torch(points_single[fg_flag], gt_box_of_fg_points) fg_point_box_labels = self.box_coder.encode_torch(
gt_boxes=gt_box_of_fg_points[:, :-1], points=points_single[fg_flag],
gt_classes=gt_box_of_fg_points[:, -1].long()
)
point_box_labels_single[fg_flag] = fg_point_box_labels point_box_labels_single[fg_flag] = fg_point_box_labels
point_box_labels[bs_mask] = point_box_labels_single point_box_labels[bs_mask] = point_box_labels_single
...@@ -149,5 +163,39 @@ class PointHeadTemplate(nn.Module): ...@@ -149,5 +163,39 @@ class PointHeadTemplate(nn.Module):
point_loss_part = point_loss_part * loss_weights_dict['point_part_weight'] point_loss_part = point_loss_part * loss_weights_dict['point_part_weight']
return point_loss_part, {'point_loss_part': point_loss_part.item()} return point_loss_part, {'point_loss_part': point_loss_part.item()}
def get_box_layer_loss(self):
pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
point_box_labels = self.forward_ret_dict['point_box_labels']
point_box_preds = self.forward_ret_dict['point_box_preds']
reg_weights = pos_mask.float()
pos_normalizer = pos_mask.sum().float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0)
point_loss_box_src = self.reg_loss_func(
point_box_preds[None, ...], point_box_labels[None, ...], weights=reg_weights[None, ...]
)
point_loss_box = point_loss_box_src.sum()
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_box = point_loss_box * loss_weights_dict['point_box_weight']
return point_loss_box, {'point_loss_box': point_loss_box.item()}
def generate_predicted_boxes(self, points, point_cls_preds, point_box_preds):
"""
Args:
points: (N, 3)
point_cls_preds: (N, num_class)
point_box_preds: (N, box_code_size)
Returns:
point_cls_preds: (N, num_class)
point_box_preds: (N, box_code_size)
"""
_, pred_classes = point_cls_preds.max(dim=-1)
point_box_preds = self.box_coder.decode_torch(point_box_preds, points, pred_classes + 1)
return point_cls_preds, point_box_preds
def forward(self, **kwargs): def forward(self, **kwargs):
raise NotImplementedError raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment