Commit c1d93158 authored by Shaoshuai Shi
Browse files

support velocity prediction with SECOND in NuScenes

parent 950c6712
...@@ -18,12 +18,14 @@ class AnchorHeadTemplate(nn.Module): ...@@ -18,12 +18,14 @@ class AnchorHeadTemplate(nn.Module):
anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)( self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)(
num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6) num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6),
**anchor_target_cfg.get('BOX_CODER_CONFIG', {})
) )
anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG
anchors, self.num_anchors_per_location = self.generate_anchors( anchors, self.num_anchors_per_location = self.generate_anchors(
anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range,
anchor_ndim=self.box_coder.code_size
) )
self.anchors = [x.cuda() for x in anchors] self.anchors = [x.cuda() for x in anchors]
self.target_assigner = self.get_target_assigner(anchor_target_cfg, anchor_generator_cfg) self.target_assigner = self.get_target_assigner(anchor_target_cfg, anchor_generator_cfg)
...@@ -32,13 +34,20 @@ class AnchorHeadTemplate(nn.Module): ...@@ -32,13 +34,20 @@ class AnchorHeadTemplate(nn.Module):
self.build_losses(self.model_cfg.LOSS_CONFIG) self.build_losses(self.model_cfg.LOSS_CONFIG)
@staticmethod @staticmethod
def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range): def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range, anchor_ndim=7):
anchor_generator = AnchorGenerator( anchor_generator = AnchorGenerator(
anchor_range=point_cloud_range, anchor_range=point_cloud_range,
anchor_generator_config=anchor_generator_cfg anchor_generator_config=anchor_generator_cfg
) )
feature_map_size = [grid_size[:2] // config['feature_map_stride'] for config in anchor_generator_cfg] feature_map_size = [grid_size[:2] // config['feature_map_stride'] for config in anchor_generator_cfg]
anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size) anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size)
if anchor_ndim != 7:
for idx, anchors in enumerate(anchors_list):
pad_zeros = anchors.new_zeros([*anchors.shape[0:-1], anchor_ndim - 7])
new_anchors = torch.cat((anchors, pad_zeros), dim=-1)
anchors_list[idx] = new_anchors
return anchors_list, num_anchors_per_location_list return anchors_list, num_anchors_per_location_list
def get_target_assigner(self, anchor_target_cfg, anchor_generator_cfg): def get_target_assigner(self, anchor_target_cfg, anchor_generator_cfg):
......
...@@ -28,8 +28,8 @@ class ATSSTargetAssigner(object): ...@@ -28,8 +28,8 @@ class ATSSTargetAssigner(object):
cls_labels_list, reg_targets_list, reg_weights_list = [], [], [] cls_labels_list, reg_targets_list, reg_weights_list = [], [], []
for anchors in anchors_list: for anchors in anchors_list:
batch_size = gt_boxes_with_classes.shape[0] batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, 7] gt_classes = gt_boxes_with_classes[:, :, -1]
gt_boxes = gt_boxes_with_classes[:, :, :7] gt_boxes = gt_boxes_with_classes[:, :, :-1]
if use_multihead: if use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1]) anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
else: else:
......
...@@ -29,13 +29,12 @@ class AxisAlignedTargetAssigner(object): ...@@ -29,13 +29,12 @@ class AxisAlignedTargetAssigner(object):
""" """
bbox_targets = [] bbox_targets = []
bbox_src_targets = []
cls_labels = [] cls_labels = []
reg_weights = [] reg_weights = []
batch_size = gt_boxes_with_classes.shape[0] batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, 7] gt_classes = gt_boxes_with_classes[:, :, -1]
gt_boxes = gt_boxes_with_classes[:, :, :7] gt_boxes = gt_boxes_with_classes[:, :, :-1]
for k in range(batch_size): for k in range(batch_size):
cur_gt = gt_boxes[k] cur_gt = gt_boxes[k]
cnt = cur_gt.__len__() - 1 cnt = cur_gt.__len__() - 1
...@@ -109,13 +108,13 @@ class AxisAlignedTargetAssigner(object): ...@@ -109,13 +108,13 @@ class AxisAlignedTargetAssigner(object):
num_anchors = anchors.shape[0] num_anchors = anchors.shape[0]
num_gt = gt_boxes.shape[0] num_gt = gt_boxes.shape[0]
box_ndim = anchors.shape[1]
labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1 labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1 gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
if len(gt_boxes) > 0 and anchors.shape[0] > 0: if len(gt_boxes) > 0 and anchors.shape[0] > 0:
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors, gt_boxes) if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors, gt_boxes) anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1) anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors), anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors),
anchor_to_gt_argmax] anchor_to_gt_argmax]
......
...@@ -12,7 +12,9 @@ class RoIHeadTemplate(nn.Module): ...@@ -12,7 +12,9 @@ class RoIHeadTemplate(nn.Module):
super().__init__() super().__init__()
self.model_cfg = model_cfg self.model_cfg = model_cfg
self.num_class = num_class self.num_class = num_class
self.box_coder = getattr(box_coder_utils, self.model_cfg.TARGET_CONFIG.BOX_CODER)() self.box_coder = getattr(box_coder_utils, self.model_cfg.TARGET_CONFIG.BOX_CODER)(
**self.model_cfg.TARGET_CONFIG.get('BOX_CODER_CONFIG', {})
)
self.proposal_target_layer = ProposalTargetLayer(roi_sampler_cfg=self.model_cfg.TARGET_CONFIG) self.proposal_target_layer = ProposalTargetLayer(roi_sampler_cfg=self.model_cfg.TARGET_CONFIG)
self.build_losses(self.model_cfg.LOSS_CONFIG) self.build_losses(self.model_cfg.LOSS_CONFIG)
self.forward_ret_dict = None self.forward_ret_dict = None
......
...@@ -5,6 +5,7 @@ class ResidualCoder(object): ...@@ -5,6 +5,7 @@ class ResidualCoder(object):
def __init__(self, code_size=7, **kwargs): def __init__(self, code_size=7, **kwargs):
super().__init__() super().__init__()
self.code_size = code_size self.code_size = code_size
assert code_size in [7, 9]
@staticmethod @staticmethod
def encode_torch(boxes, anchors): def encode_torch(boxes, anchors):
......
...@@ -118,6 +118,8 @@ class WeightedSmoothL1Loss(nn.Module): ...@@ -118,6 +118,8 @@ class WeightedSmoothL1Loss(nn.Module):
loss: (B, #anchors) float tensor. loss: (B, #anchors) float tensor.
Weighted smooth l1 loss without reduction. Weighted smooth l1 loss without reduction.
""" """
target = torch.where(torch.isnan(target), input, target) # ignore nan targets
diff = input - target diff = input - target
# code-wise weighting # code-wise weighting
if self.code_weights is not None: if self.code_weights is not None:
......
...@@ -3,7 +3,7 @@ DATA_PATH: '../data/nuscenes' ...@@ -3,7 +3,7 @@ DATA_PATH: '../data/nuscenes'
VERSION: 'v1.0-trainval' VERSION: 'v1.0-trainval'
MAX_SWEEPS: 10 MAX_SWEEPS: 10
PRED_VELOCITY: False PRED_VELOCITY: True
DATA_SPLIT: { DATA_SPLIT: {
'train': train, 'train': train,
......
...@@ -146,13 +146,16 @@ MODEL: ...@@ -146,13 +146,16 @@ MODEL:
NORM_BY_NUM_EXAMPLES: False NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False MATCH_HEIGHT: False
BOX_CODER: ResidualCoder BOX_CODER: ResidualCoder
BOX_CODER_CONFIG: {
'code_size': 9
}
LOSS_CONFIG: LOSS_CONFIG:
LOSS_WEIGHTS: { LOSS_WEIGHTS: {
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 2.0, 'loc_weight': 2.0,
'dir_weight': 0.2, 'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
} }
POST_PROCESSING: POST_PROCESSING:
......
...@@ -198,13 +198,16 @@ MODEL: ...@@ -198,13 +198,16 @@ MODEL:
NORM_BY_NUM_EXAMPLES: False NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False MATCH_HEIGHT: False
BOX_CODER: ResidualCoder BOX_CODER: ResidualCoder
BOX_CODER_CONFIG: {
'code_size': 9
}
LOSS_CONFIG: LOSS_CONFIG:
LOSS_WEIGHTS: { LOSS_WEIGHTS: {
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 2.0, 'loc_weight': 2.0,
'dir_weight': 0.2, 'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
} }
POST_PROCESSING: POST_PROCESSING:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment