Commit c1d93158 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support velocity prediction with SECOND in NuScenes

parent 950c6712
......@@ -18,12 +18,14 @@ class AnchorHeadTemplate(nn.Module):
anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)(
num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6)
num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6),
**anchor_target_cfg.get('BOX_CODER_CONFIG', {})
)
anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG
anchors, self.num_anchors_per_location = self.generate_anchors(
anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range
anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range,
anchor_ndim=self.box_coder.code_size
)
self.anchors = [x.cuda() for x in anchors]
self.target_assigner = self.get_target_assigner(anchor_target_cfg, anchor_generator_cfg)
......@@ -32,13 +34,20 @@ class AnchorHeadTemplate(nn.Module):
self.build_losses(self.model_cfg.LOSS_CONFIG)
@staticmethod
def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range):
def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range, anchor_ndim=7):
anchor_generator = AnchorGenerator(
anchor_range=point_cloud_range,
anchor_generator_config=anchor_generator_cfg
)
feature_map_size = [grid_size[:2] // config['feature_map_stride'] for config in anchor_generator_cfg]
anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size)
if anchor_ndim != 7:
for idx, anchors in enumerate(anchors_list):
pad_zeros = anchors.new_zeros([*anchors.shape[0:-1], anchor_ndim - 7])
new_anchors = torch.cat((anchors, pad_zeros), dim=-1)
anchors_list[idx] = new_anchors
return anchors_list, num_anchors_per_location_list
def get_target_assigner(self, anchor_target_cfg, anchor_generator_cfg):
......
......@@ -28,8 +28,8 @@ class ATSSTargetAssigner(object):
cls_labels_list, reg_targets_list, reg_weights_list = [], [], []
for anchors in anchors_list:
batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, 7]
gt_boxes = gt_boxes_with_classes[:, :, :7]
gt_classes = gt_boxes_with_classes[:, :, -1]
gt_boxes = gt_boxes_with_classes[:, :, :-1]
if use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
else:
......
......@@ -29,13 +29,12 @@ class AxisAlignedTargetAssigner(object):
"""
bbox_targets = []
bbox_src_targets = []
cls_labels = []
cls_labels = []
reg_weights = []
batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, 7]
gt_boxes = gt_boxes_with_classes[:, :, :7]
gt_classes = gt_boxes_with_classes[:, :, -1]
gt_boxes = gt_boxes_with_classes[:, :, :-1]
for k in range(batch_size):
cur_gt = gt_boxes[k]
cnt = cur_gt.__len__() - 1
......@@ -53,7 +52,7 @@ class AxisAlignedTargetAssigner(object):
else:
feature_map_size = anchors.shape[:3]
anchors = anchors.view(-1, anchors.shape[-1])
single_target = self.assign_targets_single(
anchors,
cur_gt[mask],
......@@ -68,7 +67,7 @@ class AxisAlignedTargetAssigner(object):
'box_reg_targets': [t['box_reg_targets'].view(-1, self.box_coder.code_size) for t in target_list],
'reg_weights': [t['reg_weights'].view(-1) for t in target_list]
}
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=0)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=0).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=0).view(-1)
......@@ -78,7 +77,7 @@ class AxisAlignedTargetAssigner(object):
'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size) for t in target_list],
'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list]
}
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=-2).view(-1, self.box_coder.code_size)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)
......@@ -87,9 +86,9 @@ class AxisAlignedTargetAssigner(object):
bbox_targets.append(target_dict['box_reg_targets'])
cls_labels.append(target_dict['box_cls_labels'])
reg_weights.append(target_dict['reg_weights'])
bbox_targets = torch.stack(bbox_targets, dim=0)
cls_labels = torch.stack(cls_labels, dim=0)
reg_weights = torch.stack(reg_weights, dim=0)
all_targets_dict = {
......@@ -109,16 +108,16 @@ class AxisAlignedTargetAssigner(object):
num_anchors = anchors.shape[0]
num_gt = gt_boxes.shape[0]
box_ndim = anchors.shape[1]
labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
if len(gt_boxes) > 0 and anchors.shape[0] > 0:
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors, gt_boxes) if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors, gt_boxes)
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors),
anchor_to_gt_argmax]
anchor_to_gt_argmax]
gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)
gt_to_anchor_max = anchor_by_gt_overlap[
......@@ -128,11 +127,11 @@ class AxisAlignedTargetAssigner(object):
gt_to_anchor_max[empty_gt_mask] = -1
anchors_with_max_overlap = torch.nonzero(
anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]
gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
gt_ids[anchors_with_max_overlap] = gt_inds_force.int()
pos_inds = anchor_to_gt_max >= matched_threshold
gt_inds_over_thresh = anchor_to_gt_argmax[pos_inds]
labels[pos_inds] = gt_classes[gt_inds_over_thresh]
......@@ -142,7 +141,7 @@ class AxisAlignedTargetAssigner(object):
bg_inds = torch.arange(num_anchors)
fg_inds = torch.nonzero(labels > 0)[:, 0]
if self.pos_fraction is not None:
num_fg = int(self.pos_fraction * self.sample_size)
if len(fg_inds) > num_fg:
......@@ -170,7 +169,7 @@ class AxisAlignedTargetAssigner(object):
bbox_targets[fg_inds, :] = self.box_coder.encode_torch(fg_gt_boxes, fg_anchors)
reg_weights = anchors.new_zeros((num_anchors,))
if self.norm_by_num_examples:
num_examples = (labels >= 0).sum()
num_examples = num_examples if num_examples > 1.0 else 1.0
......
......@@ -12,7 +12,9 @@ class RoIHeadTemplate(nn.Module):
super().__init__()
self.model_cfg = model_cfg
self.num_class = num_class
self.box_coder = getattr(box_coder_utils, self.model_cfg.TARGET_CONFIG.BOX_CODER)()
self.box_coder = getattr(box_coder_utils, self.model_cfg.TARGET_CONFIG.BOX_CODER)(
**self.model_cfg.TARGET_CONFIG.get('BOX_CODER_CONFIG', {})
)
self.proposal_target_layer = ProposalTargetLayer(roi_sampler_cfg=self.model_cfg.TARGET_CONFIG)
self.build_losses(self.model_cfg.LOSS_CONFIG)
self.forward_ret_dict = None
......
......@@ -5,6 +5,7 @@ class ResidualCoder(object):
def __init__(self, code_size=7, **kwargs):
super().__init__()
self.code_size = code_size
assert code_size in [7, 9]
@staticmethod
def encode_torch(boxes, anchors):
......
......@@ -118,6 +118,8 @@ class WeightedSmoothL1Loss(nn.Module):
loss: (B, #anchors) float tensor.
Weighted smooth l1 loss without reduction.
"""
target = torch.where(torch.isnan(target), input, target) # ignore nan targets
diff = input - target
# code-wise weighting
if self.code_weights is not None:
......@@ -184,4 +186,4 @@ def get_corner_loss_lidar(pred_bbox3d: torch.Tensor, gt_bbox3d: torch.Tensor):
# (N, 8)
corner_loss = WeightedSmoothL1Loss.smooth_l1_loss(corner_dist, beta=1.0)
return corner_loss.mean(dim=1)
\ No newline at end of file
return corner_loss.mean(dim=1)
......@@ -3,7 +3,7 @@ DATA_PATH: '../data/nuscenes'
VERSION: 'v1.0-trainval'
MAX_SWEEPS: 10
PRED_VELOCITY: False
PRED_VELOCITY: True
DATA_SPLIT: {
'train': train,
......
......@@ -146,13 +146,16 @@ MODEL:
NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False
BOX_CODER: ResidualCoder
BOX_CODER_CONFIG: {
'code_size': 9
}
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 2.0,
'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
POST_PROCESSING:
......
......@@ -198,13 +198,16 @@ MODEL:
NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False
BOX_CODER: ResidualCoder
BOX_CODER_CONFIG: {
'code_size': 9
}
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 2.0,
'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
POST_PROCESSING:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment