Commit c1d93158 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support velocity prediction with SECOND in NuScenes

parent 950c6712
...@@ -18,12 +18,14 @@ class AnchorHeadTemplate(nn.Module): ...@@ -18,12 +18,14 @@ class AnchorHeadTemplate(nn.Module):
anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)( self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)(
num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6) num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6),
**anchor_target_cfg.get('BOX_CODER_CONFIG', {})
) )
anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG
anchors, self.num_anchors_per_location = self.generate_anchors( anchors, self.num_anchors_per_location = self.generate_anchors(
anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range,
anchor_ndim=self.box_coder.code_size
) )
self.anchors = [x.cuda() for x in anchors] self.anchors = [x.cuda() for x in anchors]
self.target_assigner = self.get_target_assigner(anchor_target_cfg, anchor_generator_cfg) self.target_assigner = self.get_target_assigner(anchor_target_cfg, anchor_generator_cfg)
...@@ -32,13 +34,20 @@ class AnchorHeadTemplate(nn.Module): ...@@ -32,13 +34,20 @@ class AnchorHeadTemplate(nn.Module):
self.build_losses(self.model_cfg.LOSS_CONFIG) self.build_losses(self.model_cfg.LOSS_CONFIG)
@staticmethod @staticmethod
def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range): def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range, anchor_ndim=7):
anchor_generator = AnchorGenerator( anchor_generator = AnchorGenerator(
anchor_range=point_cloud_range, anchor_range=point_cloud_range,
anchor_generator_config=anchor_generator_cfg anchor_generator_config=anchor_generator_cfg
) )
feature_map_size = [grid_size[:2] // config['feature_map_stride'] for config in anchor_generator_cfg] feature_map_size = [grid_size[:2] // config['feature_map_stride'] for config in anchor_generator_cfg]
anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size) anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size)
if anchor_ndim != 7:
for idx, anchors in enumerate(anchors_list):
pad_zeros = anchors.new_zeros([*anchors.shape[0:-1], anchor_ndim - 7])
new_anchors = torch.cat((anchors, pad_zeros), dim=-1)
anchors_list[idx] = new_anchors
return anchors_list, num_anchors_per_location_list return anchors_list, num_anchors_per_location_list
def get_target_assigner(self, anchor_target_cfg, anchor_generator_cfg): def get_target_assigner(self, anchor_target_cfg, anchor_generator_cfg):
......
...@@ -28,8 +28,8 @@ class ATSSTargetAssigner(object): ...@@ -28,8 +28,8 @@ class ATSSTargetAssigner(object):
cls_labels_list, reg_targets_list, reg_weights_list = [], [], [] cls_labels_list, reg_targets_list, reg_weights_list = [], [], []
for anchors in anchors_list: for anchors in anchors_list:
batch_size = gt_boxes_with_classes.shape[0] batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, 7] gt_classes = gt_boxes_with_classes[:, :, -1]
gt_boxes = gt_boxes_with_classes[:, :, :7] gt_boxes = gt_boxes_with_classes[:, :, :-1]
if use_multihead: if use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1]) anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
else: else:
......
...@@ -29,13 +29,12 @@ class AxisAlignedTargetAssigner(object): ...@@ -29,13 +29,12 @@ class AxisAlignedTargetAssigner(object):
""" """
bbox_targets = [] bbox_targets = []
bbox_src_targets = [] cls_labels = []
cls_labels = []
reg_weights = [] reg_weights = []
batch_size = gt_boxes_with_classes.shape[0] batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, 7] gt_classes = gt_boxes_with_classes[:, :, -1]
gt_boxes = gt_boxes_with_classes[:, :, :7] gt_boxes = gt_boxes_with_classes[:, :, :-1]
for k in range(batch_size): for k in range(batch_size):
cur_gt = gt_boxes[k] cur_gt = gt_boxes[k]
cnt = cur_gt.__len__() - 1 cnt = cur_gt.__len__() - 1
...@@ -53,7 +52,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -53,7 +52,7 @@ class AxisAlignedTargetAssigner(object):
else: else:
feature_map_size = anchors.shape[:3] feature_map_size = anchors.shape[:3]
anchors = anchors.view(-1, anchors.shape[-1]) anchors = anchors.view(-1, anchors.shape[-1])
single_target = self.assign_targets_single( single_target = self.assign_targets_single(
anchors, anchors,
cur_gt[mask], cur_gt[mask],
...@@ -68,7 +67,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -68,7 +67,7 @@ class AxisAlignedTargetAssigner(object):
'box_reg_targets': [t['box_reg_targets'].view(-1, self.box_coder.code_size) for t in target_list], 'box_reg_targets': [t['box_reg_targets'].view(-1, self.box_coder.code_size) for t in target_list],
'reg_weights': [t['reg_weights'].view(-1) for t in target_list] 'reg_weights': [t['reg_weights'].view(-1) for t in target_list]
} }
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=0) target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=0)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=0).view(-1) target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=0).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=0).view(-1) target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=0).view(-1)
...@@ -78,7 +77,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -78,7 +77,7 @@ class AxisAlignedTargetAssigner(object):
'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size) for t in target_list], 'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size) for t in target_list],
'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list] 'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list]
} }
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=-2).view(-1, self.box_coder.code_size) target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=-2).view(-1, self.box_coder.code_size)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1) target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1) target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)
...@@ -87,9 +86,9 @@ class AxisAlignedTargetAssigner(object): ...@@ -87,9 +86,9 @@ class AxisAlignedTargetAssigner(object):
bbox_targets.append(target_dict['box_reg_targets']) bbox_targets.append(target_dict['box_reg_targets'])
cls_labels.append(target_dict['box_cls_labels']) cls_labels.append(target_dict['box_cls_labels'])
reg_weights.append(target_dict['reg_weights']) reg_weights.append(target_dict['reg_weights'])
bbox_targets = torch.stack(bbox_targets, dim=0) bbox_targets = torch.stack(bbox_targets, dim=0)
cls_labels = torch.stack(cls_labels, dim=0) cls_labels = torch.stack(cls_labels, dim=0)
reg_weights = torch.stack(reg_weights, dim=0) reg_weights = torch.stack(reg_weights, dim=0)
all_targets_dict = { all_targets_dict = {
...@@ -109,16 +108,16 @@ class AxisAlignedTargetAssigner(object): ...@@ -109,16 +108,16 @@ class AxisAlignedTargetAssigner(object):
num_anchors = anchors.shape[0] num_anchors = anchors.shape[0]
num_gt = gt_boxes.shape[0] num_gt = gt_boxes.shape[0]
box_ndim = anchors.shape[1]
labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1 labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1 gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
if len(gt_boxes) > 0 and anchors.shape[0] > 0: if len(gt_boxes) > 0 and anchors.shape[0] > 0:
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors, gt_boxes) if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors, gt_boxes) anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1) anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors), anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors),
anchor_to_gt_argmax] anchor_to_gt_argmax]
gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0) gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)
gt_to_anchor_max = anchor_by_gt_overlap[ gt_to_anchor_max = anchor_by_gt_overlap[
...@@ -128,11 +127,11 @@ class AxisAlignedTargetAssigner(object): ...@@ -128,11 +127,11 @@ class AxisAlignedTargetAssigner(object):
gt_to_anchor_max[empty_gt_mask] = -1 gt_to_anchor_max[empty_gt_mask] = -1
anchors_with_max_overlap = torch.nonzero( anchors_with_max_overlap = torch.nonzero(
anchor_by_gt_overlap == gt_to_anchor_max)[:, 0] anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]
gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap] gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force] labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
gt_ids[anchors_with_max_overlap] = gt_inds_force.int() gt_ids[anchors_with_max_overlap] = gt_inds_force.int()
pos_inds = anchor_to_gt_max >= matched_threshold pos_inds = anchor_to_gt_max >= matched_threshold
gt_inds_over_thresh = anchor_to_gt_argmax[pos_inds] gt_inds_over_thresh = anchor_to_gt_argmax[pos_inds]
labels[pos_inds] = gt_classes[gt_inds_over_thresh] labels[pos_inds] = gt_classes[gt_inds_over_thresh]
...@@ -142,7 +141,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -142,7 +141,7 @@ class AxisAlignedTargetAssigner(object):
bg_inds = torch.arange(num_anchors) bg_inds = torch.arange(num_anchors)
fg_inds = torch.nonzero(labels > 0)[:, 0] fg_inds = torch.nonzero(labels > 0)[:, 0]
if self.pos_fraction is not None: if self.pos_fraction is not None:
num_fg = int(self.pos_fraction * self.sample_size) num_fg = int(self.pos_fraction * self.sample_size)
if len(fg_inds) > num_fg: if len(fg_inds) > num_fg:
...@@ -170,7 +169,7 @@ class AxisAlignedTargetAssigner(object): ...@@ -170,7 +169,7 @@ class AxisAlignedTargetAssigner(object):
bbox_targets[fg_inds, :] = self.box_coder.encode_torch(fg_gt_boxes, fg_anchors) bbox_targets[fg_inds, :] = self.box_coder.encode_torch(fg_gt_boxes, fg_anchors)
reg_weights = anchors.new_zeros((num_anchors,)) reg_weights = anchors.new_zeros((num_anchors,))
if self.norm_by_num_examples: if self.norm_by_num_examples:
num_examples = (labels >= 0).sum() num_examples = (labels >= 0).sum()
num_examples = num_examples if num_examples > 1.0 else 1.0 num_examples = num_examples if num_examples > 1.0 else 1.0
......
...@@ -12,7 +12,9 @@ class RoIHeadTemplate(nn.Module): ...@@ -12,7 +12,9 @@ class RoIHeadTemplate(nn.Module):
super().__init__() super().__init__()
self.model_cfg = model_cfg self.model_cfg = model_cfg
self.num_class = num_class self.num_class = num_class
self.box_coder = getattr(box_coder_utils, self.model_cfg.TARGET_CONFIG.BOX_CODER)() self.box_coder = getattr(box_coder_utils, self.model_cfg.TARGET_CONFIG.BOX_CODER)(
**self.model_cfg.TARGET_CONFIG.get('BOX_CODER_CONFIG', {})
)
self.proposal_target_layer = ProposalTargetLayer(roi_sampler_cfg=self.model_cfg.TARGET_CONFIG) self.proposal_target_layer = ProposalTargetLayer(roi_sampler_cfg=self.model_cfg.TARGET_CONFIG)
self.build_losses(self.model_cfg.LOSS_CONFIG) self.build_losses(self.model_cfg.LOSS_CONFIG)
self.forward_ret_dict = None self.forward_ret_dict = None
......
...@@ -5,6 +5,7 @@ class ResidualCoder(object): ...@@ -5,6 +5,7 @@ class ResidualCoder(object):
def __init__(self, code_size=7, **kwargs): def __init__(self, code_size=7, **kwargs):
super().__init__() super().__init__()
self.code_size = code_size self.code_size = code_size
assert code_size in [7, 9]
@staticmethod @staticmethod
def encode_torch(boxes, anchors): def encode_torch(boxes, anchors):
......
...@@ -118,6 +118,8 @@ class WeightedSmoothL1Loss(nn.Module): ...@@ -118,6 +118,8 @@ class WeightedSmoothL1Loss(nn.Module):
loss: (B, #anchors) float tensor. loss: (B, #anchors) float tensor.
Weighted smooth l1 loss without reduction. Weighted smooth l1 loss without reduction.
""" """
target = torch.where(torch.isnan(target), input, target) # ignore nan targets
diff = input - target diff = input - target
# code-wise weighting # code-wise weighting
if self.code_weights is not None: if self.code_weights is not None:
...@@ -184,4 +186,4 @@ def get_corner_loss_lidar(pred_bbox3d: torch.Tensor, gt_bbox3d: torch.Tensor): ...@@ -184,4 +186,4 @@ def get_corner_loss_lidar(pred_bbox3d: torch.Tensor, gt_bbox3d: torch.Tensor):
# (N, 8) # (N, 8)
corner_loss = WeightedSmoothL1Loss.smooth_l1_loss(corner_dist, beta=1.0) corner_loss = WeightedSmoothL1Loss.smooth_l1_loss(corner_dist, beta=1.0)
return corner_loss.mean(dim=1) return corner_loss.mean(dim=1)
\ No newline at end of file
...@@ -3,7 +3,7 @@ DATA_PATH: '../data/nuscenes' ...@@ -3,7 +3,7 @@ DATA_PATH: '../data/nuscenes'
VERSION: 'v1.0-trainval' VERSION: 'v1.0-trainval'
MAX_SWEEPS: 10 MAX_SWEEPS: 10
PRED_VELOCITY: False PRED_VELOCITY: True
DATA_SPLIT: { DATA_SPLIT: {
'train': train, 'train': train,
......
...@@ -146,13 +146,16 @@ MODEL: ...@@ -146,13 +146,16 @@ MODEL:
NORM_BY_NUM_EXAMPLES: False NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False MATCH_HEIGHT: False
BOX_CODER: ResidualCoder BOX_CODER: ResidualCoder
BOX_CODER_CONFIG: {
'code_size': 9
}
LOSS_CONFIG: LOSS_CONFIG:
LOSS_WEIGHTS: { LOSS_WEIGHTS: {
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 2.0, 'loc_weight': 2.0,
'dir_weight': 0.2, 'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
} }
POST_PROCESSING: POST_PROCESSING:
......
...@@ -198,13 +198,16 @@ MODEL: ...@@ -198,13 +198,16 @@ MODEL:
NORM_BY_NUM_EXAMPLES: False NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False MATCH_HEIGHT: False
BOX_CODER: ResidualCoder BOX_CODER: ResidualCoder
BOX_CODER_CONFIG: {
'code_size': 9
}
LOSS_CONFIG: LOSS_CONFIG:
LOSS_WEIGHTS: { LOSS_WEIGHTS: {
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 2.0, 'loc_weight': 2.0,
'dir_weight': 0.2, 'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
} }
POST_PROCESSING: POST_PROCESSING:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment