Commit 8075b170 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support PartA2_free models, update PointHead

parent 48400f13
...@@ -127,7 +127,7 @@ class PointHeadTemplate(nn.Module): ...@@ -127,7 +127,7 @@ class PointHeadTemplate(nn.Module):
} }
return targets_dict return targets_dict
def get_cls_layer_loss(self): def get_cls_layer_loss(self, tb_dict=None):
point_cls_labels = self.forward_ret_dict['point_cls_labels'].view(-1) point_cls_labels = self.forward_ret_dict['point_cls_labels'].view(-1)
point_cls_preds = self.forward_ret_dict['point_cls_preds'].view(-1, self.num_class) point_cls_preds = self.forward_ret_dict['point_cls_preds'].view(-1, self.num_class)
...@@ -145,13 +145,15 @@ class PointHeadTemplate(nn.Module): ...@@ -145,13 +145,15 @@ class PointHeadTemplate(nn.Module):
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_cls = point_loss_cls * loss_weights_dict['point_cls_weight'] point_loss_cls = point_loss_cls * loss_weights_dict['point_cls_weight']
tb_dict = { if tb_dict is None:
tb_dict = {}
tb_dict.update({
'point_loss_cls': point_loss_cls.item(), 'point_loss_cls': point_loss_cls.item(),
'point_pos_num': pos_normalizer.item() 'point_pos_num': pos_normalizer.item()
} })
return point_loss_cls, tb_dict return point_loss_cls, tb_dict
def get_part_layer_loss(self): def get_part_layer_loss(self, tb_dict=None):
pos_mask = self.forward_ret_dict['point_cls_labels'] > 0 pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
pos_normalizer = max(1, (pos_mask > 0).sum().item()) pos_normalizer = max(1, (pos_mask > 0).sum().item())
point_part_labels = self.forward_ret_dict['point_part_labels'] point_part_labels = self.forward_ret_dict['point_part_labels']
...@@ -161,9 +163,12 @@ class PointHeadTemplate(nn.Module): ...@@ -161,9 +163,12 @@ class PointHeadTemplate(nn.Module):
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_part = point_loss_part * loss_weights_dict['point_part_weight'] point_loss_part = point_loss_part * loss_weights_dict['point_part_weight']
return point_loss_part, {'point_loss_part': point_loss_part.item()} if tb_dict is None:
tb_dict = {}
tb_dict.update({'point_loss_part': point_loss_part.item()})
return point_loss_part, tb_dict
def get_box_layer_loss(self): def get_box_layer_loss(self, tb_dict=None):
pos_mask = self.forward_ret_dict['point_cls_labels'] > 0 pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
point_box_labels = self.forward_ret_dict['point_box_labels'] point_box_labels = self.forward_ret_dict['point_box_labels']
point_box_preds = self.forward_ret_dict['point_box_preds'] point_box_preds = self.forward_ret_dict['point_box_preds']
...@@ -179,7 +184,10 @@ class PointHeadTemplate(nn.Module): ...@@ -179,7 +184,10 @@ class PointHeadTemplate(nn.Module):
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_box = point_loss_box * loss_weights_dict['point_box_weight'] point_loss_box = point_loss_box * loss_weights_dict['point_box_weight']
return point_loss_box, {'point_loss_box': point_loss_box.item()} if tb_dict is None:
tb_dict = {}
tb_dict.update({'point_loss_box': point_loss_box.item()})
return point_loss_box, tb_dict
def generate_predicted_boxes(self, points, point_cls_preds, point_box_preds): def generate_predicted_boxes(self, points, point_cls_preds, point_box_preds):
""" """
......
import torch import torch
from .point_head_template import PointHeadTemplate from .point_head_template import PointHeadTemplate
from ...utils import box_utils from ...utils import box_utils, box_coder_utils
class PointIntraPartOffsetHead(PointHeadTemplate): class PointIntraPartOffsetHead(PointHeadTemplate):
...@@ -9,8 +9,9 @@ class PointIntraPartOffsetHead(PointHeadTemplate): ...@@ -9,8 +9,9 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
Reference Paper: https://arxiv.org/abs/1907.03670 Reference Paper: https://arxiv.org/abs/1907.03670
From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network
""" """
def __init__(self, num_class, input_channels, model_cfg, **kwargs): def __init__(self, num_class, input_channels, model_cfg, predict_boxes_when_training=False, **kwargs):
super().__init__(model_cfg=model_cfg, num_class=num_class) super().__init__(model_cfg=model_cfg, num_class=num_class)
self.predict_boxes_when_training = predict_boxes_when_training
self.cls_layers = self.make_fc_layers( self.cls_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.CLS_FC, fc_cfg=self.model_cfg.CLS_FC,
input_channels=input_channels, input_channels=input_channels,
...@@ -21,6 +22,18 @@ class PointIntraPartOffsetHead(PointHeadTemplate): ...@@ -21,6 +22,18 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
input_channels=input_channels, input_channels=input_channels,
output_channels=3 output_channels=3
) )
target_cfg = self.model_cfg.TARGET_CONFIG
if target_cfg.get('BOX_CODER', None) is not None:
self.box_coder = getattr(box_coder_utils, target_cfg.BOX_CODER)(
**target_cfg.BOX_CODER_CONFIG
)
self.box_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.REG_FC,
input_channels=input_channels,
output_channels=self.box_coder.code_size
)
else:
self.box_layers = None
def assign_targets(self, input_dict): def assign_targets(self, input_dict):
""" """
...@@ -46,19 +59,20 @@ class PointIntraPartOffsetHead(PointHeadTemplate): ...@@ -46,19 +59,20 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
targets_dict = self.assign_stack_targets( targets_dict = self.assign_stack_targets(
points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes, points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes,
set_ignore_flag=True, use_ball_constraint=False, set_ignore_flag=True, use_ball_constraint=False,
ret_part_labels=True ret_part_labels=True, ret_box_labels=(self.box_layers is not None)
) )
return targets_dict return targets_dict
def get_loss(self, tb_dict=None): def get_loss(self, tb_dict=None):
tb_dict = {} if tb_dict is None else tb_dict tb_dict = {} if tb_dict is None else tb_dict
point_loss_cls, tb_dict_1 = self.get_cls_layer_loss() point_loss_cls, tb_dict = self.get_cls_layer_loss(tb_dict)
point_loss_part, tb_dict_2 = self.get_part_layer_loss() point_loss_part, tb_dict = self.get_part_layer_loss(tb_dict)
point_loss = point_loss_cls + point_loss_part point_loss = point_loss_cls + point_loss_part
tb_dict.update(tb_dict_1)
tb_dict.update(tb_dict_2) if self.box_layers is not None:
point_loss_box, tb_dict = self.get_box_layer_loss(tb_dict)
point_loss += point_loss_box
return point_loss, tb_dict return point_loss, tb_dict
def forward(self, batch_dict): def forward(self, batch_dict):
...@@ -83,6 +97,9 @@ class PointIntraPartOffsetHead(PointHeadTemplate): ...@@ -83,6 +97,9 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
'point_cls_preds': point_cls_preds, 'point_cls_preds': point_cls_preds,
'point_part_preds': point_part_preds, 'point_part_preds': point_part_preds,
} }
if self.box_layers is not None:
point_box_preds = self.box_layers(point_features)
ret_dict['point_box_preds'] = point_box_preds
point_cls_scores = torch.sigmoid(point_cls_preds) point_cls_scores = torch.sigmoid(point_cls_preds)
point_part_offset = torch.sigmoid(point_part_preds) point_part_offset = torch.sigmoid(point_part_preds)
...@@ -93,6 +110,17 @@ class PointIntraPartOffsetHead(PointHeadTemplate): ...@@ -93,6 +110,17 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
targets_dict = self.assign_targets(batch_dict) targets_dict = self.assign_targets(batch_dict)
ret_dict['point_cls_labels'] = targets_dict['point_cls_labels'] ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']
ret_dict['point_part_labels'] = targets_dict.get('point_part_labels') ret_dict['point_part_labels'] = targets_dict.get('point_part_labels')
self.forward_ret_dict = ret_dict ret_dict['point_box_labels'] = targets_dict.get('point_box_labels')
if self.box_layers is not None and (not self.training or self.predict_boxes_when_training):
point_cls_preds, point_box_preds = self.generate_predicted_boxes(
points=batch_dict['point_coords'][:, 1:4],
point_cls_preds=point_cls_preds, point_box_preds=ret_dict['point_box_preds']
)
batch_dict['batch_cls_preds'] = point_cls_preds
batch_dict['batch_box_preds'] = point_box_preds
batch_dict['batch_index'] = batch_dict['point_coords'][:, 0]
batch_dict['cls_preds_normalized'] = False
self.forward_ret_dict = ret_dict
return batch_dict return batch_dict
...@@ -15,9 +15,10 @@ MODEL: ...@@ -15,9 +15,10 @@ MODEL:
RETURN_ENCODED_TENSOR: False RETURN_ENCODED_TENSOR: False
POINT_HEAD: POINT_HEAD:
NAME: PointHeadBox NAME: PointIntraPartOffsetHead
CLS_FC: [256, 256] CLS_FC: [128, 128]
REG_FC: [256, 256] PART_FC: [128, 128]
REG_FC: [128, 128]
CLASS_AGNOSTIC: False CLASS_AGNOSTIC: False
USE_POINT_FEATURES_BEFORE_FUSION: False USE_POINT_FEATURES_BEFORE_FUSION: False
TARGET_CONFIG: TARGET_CONFIG:
...@@ -37,6 +38,7 @@ MODEL: ...@@ -37,6 +38,7 @@ MODEL:
LOSS_WEIGHTS: { LOSS_WEIGHTS: {
'point_cls_weight': 1.0, 'point_cls_weight': 1.0,
'point_box_weight': 1.0, 'point_box_weight': 1.0,
'point_part_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
} }
......
CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/kitti_dataset.yaml
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: sample_points
NUM_POINTS: {
'train': 16384,
'test': 16384
}
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': False
}
MODEL:
NAME: PointRCNN
BACKBONE_3D:
NAME: PointNet2MSG
SA_CONFIG:
NPOINTS: [4096, 1024, 256, 64]
RADIUS: [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]]
NSAMPLE: [[16, 32], [16, 32], [16, 32], [16, 32]]
MLPS: [[[16, 16, 32], [32, 32, 64]],
[[64, 64, 128], [64, 96, 128]],
[[128, 196, 256], [128, 196, 256]],
[[256, 256, 512], [256, 384, 512]]]
FP_MLPS: [[128, 128], [256, 256], [512, 512], [512, 512]]
POINT_HEAD:
NAME: PointHeadBox
CLS_FC: [256, 256]
REG_FC: [256, 256]
CLASS_AGNOSTIC: False
USE_POINT_FEATURES_BEFORE_FUSION: False
TARGET_CONFIG:
GT_EXTRA_WIDTH: [0.2, 0.2, 0.2]
BOX_CODER: PointResidualCoder
BOX_CODER_CONFIG: {
'use_mean_size': True,
'mean_size': [
[3.9, 1.6, 1.56],
[0.8, 0.6, 1.73],
[1.76, 0.6, 1.73]
]
}
LOSS_CONFIG:
LOSS_REG: WeightedSmoothL1Loss
LOSS_WEIGHTS: {
'point_cls_weight': 1.0,
'point_box_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}
ROI_HEAD:
NAME: PartA2FCHead
CLASS_AGNOSTIC: True
SHARED_FC: [256, 256, 256]
CLS_FC: [256, 256]
REG_FC: [256, 256]
DP_RATIO: 0.3
DISABLE_PART: True
SEG_MASK_SCORE_THRESH: 0.0
NMS_CONFIG:
TRAIN:
NMS_TYPE: nms_gpu
MULTI_CLASSES_NMS: False
NMS_PRE_MAXSIZE: 9000
NMS_POST_MAXSIZE: 512
NMS_THRESH: 0.8
TEST:
NMS_TYPE: nms_gpu
MULTI_CLASSES_NMS: False
NMS_PRE_MAXSIZE: 9000
NMS_POST_MAXSIZE: 100
NMS_THRESH: 0.85
ROI_AWARE_POOL:
POOL_SIZE: 12
NUM_FEATURES: 128
MAX_POINTS_PER_VOXEL: 128
TARGET_CONFIG:
BOX_CODER: ResidualCoder
ROI_PER_IMAGE: 128
FG_RATIO: 0.5
SAMPLE_ROI_BY_EACH_CLASS: True
CLS_SCORE_TYPE: roi_iou
CLS_FG_THRESH: 0.75
CLS_BG_THRESH: 0.25
CLS_BG_THRESH_LO: 0.1
HARD_BG_RATIO: 0.8
REG_FG_THRESH: 0.65
LOSS_CONFIG:
CLS_LOSS: BinaryCrossEntropy
REG_LOSS: smooth-l1
CORNER_LOSS_REGULARIZATION: True
LOSS_WEIGHTS: {
'rcnn_cls_weight': 1.0,
'rcnn_reg_weight': 1.0,
'rcnn_corner_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
SCORE_THRESH: 0.1
OUTPUT_RAW_SCORE: False
EVAL_METRIC: kitti
NMS_CONFIG:
MULTI_CLASSES_NMS: False
NMS_TYPE: nms_gpu
NMS_THRESH: 0.1
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 500
OPTIMIZATION:
OPTIMIZER: adam_onecycle
LR: 0.01
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 10
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment