Commit 4337fc4e authored by Gus-Guo's avatar Gus-Guo Committed by Shaoshuai Shi
Browse files

simplify codes of axis aligned target assigner and update corresponding config

parent 3cd5c47b
...@@ -4,10 +4,10 @@ from .anchor_head_template import AnchorHeadTemplate ...@@ -4,10 +4,10 @@ from .anchor_head_template import AnchorHeadTemplate
class AnchorHeadSingle(AnchorHeadTemplate): class AnchorHeadSingle(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, grid_size, point_cloud_range, def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
predict_boxes_when_training=True): predict_boxes_when_training=True):
super().__init__( super().__init__(
model_cfg=model_cfg, num_class=num_class, grid_size=grid_size, point_cloud_range=point_cloud_range, model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size, point_cloud_range=point_cloud_range,
predict_boxes_when_training=predict_boxes_when_training predict_boxes_when_training=predict_boxes_when_training
) )
......
...@@ -8,10 +8,11 @@ from ...utils import box_coder_utils, loss_utils, common_utils ...@@ -8,10 +8,11 @@ from ...utils import box_coder_utils, loss_utils, common_utils
class AnchorHeadTemplate(nn.Module): class AnchorHeadTemplate(nn.Module):
def __init__(self, model_cfg, num_class, grid_size, point_cloud_range, predict_boxes_when_training): def __init__(self, model_cfg, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training):
super().__init__() super().__init__()
self.model_cfg = model_cfg self.model_cfg = model_cfg
self.num_class = num_class self.num_class = num_class
self.class_names = class_names
self.predict_boxes_when_training = predict_boxes_when_training self.predict_boxes_when_training = predict_boxes_when_training
self.use_multihead = self.model_cfg.get('USE_MULTI_HEAD', False) self.use_multihead = self.model_cfg.get('USE_MULTI_HEAD', False)
...@@ -20,11 +21,12 @@ class AnchorHeadTemplate(nn.Module): ...@@ -20,11 +21,12 @@ class AnchorHeadTemplate(nn.Module):
num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6) num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6)
) )
anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG
anchors, self.num_anchors_per_location = self.generate_anchors( anchors, self.num_anchors_per_location = self.generate_anchors(
self.model_cfg.ANCHOR_GENERATOR_CONFIG, grid_size=grid_size, point_cloud_range=point_cloud_range anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range
) )
self.anchors = [x.cuda() for x in anchors] self.anchors = [x.cuda() for x in anchors]
self.target_assigner = self.get_target_assigner(anchor_target_cfg) self.target_assigner = self.get_target_assigner(anchor_target_cfg, anchor_generator_cfg)
self.forward_ret_dict = {} self.forward_ret_dict = {}
self.build_losses(self.model_cfg.LOSS_CONFIG) self.build_losses(self.model_cfg.LOSS_CONFIG)
...@@ -39,7 +41,7 @@ class AnchorHeadTemplate(nn.Module): ...@@ -39,7 +41,7 @@ class AnchorHeadTemplate(nn.Module):
anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size) anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size)
return anchors_list, num_anchors_per_location_list return anchors_list, num_anchors_per_location_list
def get_target_assigner(self, anchor_target_cfg): def get_target_assigner(self, anchor_target_cfg, anchor_generator_cfg):
if anchor_target_cfg.NAME == 'ATSS': if anchor_target_cfg.NAME == 'ATSS':
target_assigner = ATSSTargetAssigner( target_assigner = ATSSTargetAssigner(
topk=anchor_target_cfg.TOPK, topk=anchor_target_cfg.TOPK,
...@@ -49,6 +51,8 @@ class AnchorHeadTemplate(nn.Module): ...@@ -49,6 +51,8 @@ class AnchorHeadTemplate(nn.Module):
elif anchor_target_cfg.NAME == 'AxisAlignedTargetAssigner': elif anchor_target_cfg.NAME == 'AxisAlignedTargetAssigner':
target_assigner = AxisAlignedTargetAssigner( target_assigner = AxisAlignedTargetAssigner(
anchor_target_cfg=anchor_target_cfg, anchor_target_cfg=anchor_target_cfg,
anchor_generator_cfg=anchor_generator_cfg,
class_names=self.class_names,
box_coder=self.box_coder, box_coder=self.box_coder,
match_height=anchor_target_cfg.MATCH_HEIGHT match_height=anchor_target_cfg.MATCH_HEIGHT
) )
......
...@@ -4,16 +4,20 @@ from ....ops.iou3d_nms import iou3d_nms_utils ...@@ -4,16 +4,20 @@ from ....ops.iou3d_nms import iou3d_nms_utils
class AxisAlignedTargetAssigner(object): class AxisAlignedTargetAssigner(object):
def __init__(self, anchor_target_cfg, box_coder, match_height=False): def __init__(self, anchor_target_cfg, anchor_generator_cfg, class_names, box_coder, match_height=False):
super().__init__() super().__init__()
self.box_coder = box_coder self.box_coder = box_coder
self.match_height = match_height self.match_height = match_height
self.class_names = class_names
self.anchor_class_names = [config['class_name'] for config in anchor_generator_cfg]
self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None
self.sample_size = anchor_target_cfg.SAMPLE_SIZE self.sample_size = anchor_target_cfg.SAMPLE_SIZE
self.matched_thresholds = anchor_target_cfg.MATCHED_THRESHOLDS
self.unmatched_thresholds = anchor_target_cfg.UNMATCHED_THRESHOLDS
self.norm_by_num_examples = anchor_target_cfg.NORM_BY_NUM_EXAMPLES self.norm_by_num_examples = anchor_target_cfg.NORM_BY_NUM_EXAMPLES
self.matched_thresholds = {}
self.unmatched_thresholds = {}
for config in anchor_generator_cfg:
self.matched_thresholds[config['class_name']] = config['matched_threshold']
self.unmatched_thresholds[config['class_name']] = config['unmatched_threshold']
def assign_targets(self, all_anchors, gt_boxes_with_classes, use_multihead=False): def assign_targets(self, all_anchors, gt_boxes_with_classes, use_multihead=False):
""" """
...@@ -41,8 +45,8 @@ class AxisAlignedTargetAssigner(object): ...@@ -41,8 +45,8 @@ class AxisAlignedTargetAssigner(object):
cur_gt_classes = gt_classes[k][:cnt + 1].int() cur_gt_classes = gt_classes[k][:cnt + 1].int()
target_list = [] target_list = []
for class_index, anchors in enumerate(all_anchors): for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors):
mask = torch.tensor([c == class_index + 1 for c in cur_gt_classes], dtype=torch.bool) mask = torch.tensor([self.class_names[c-1] == anchor_class_name for c in cur_gt_classes], dtype=torch.bool)
if use_multihead: if use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1]) anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
...@@ -54,8 +58,8 @@ class AxisAlignedTargetAssigner(object): ...@@ -54,8 +58,8 @@ class AxisAlignedTargetAssigner(object):
anchors, anchors,
cur_gt[mask], cur_gt[mask],
gt_classes=cur_gt_classes[mask], gt_classes=cur_gt_classes[mask],
matched_threshold=self.matched_thresholds[class_index], matched_threshold=self.matched_thresholds[anchor_class_name],
unmatched_threshold=self.unmatched_thresholds[class_index] unmatched_threshold=self.unmatched_thresholds[anchor_class_name]
) )
target_list.append(single_target) target_list.append(single_target)
if use_multihead: if use_multihead:
......
...@@ -14,6 +14,7 @@ class Detector3DTemplate(nn.Module): ...@@ -14,6 +14,7 @@ class Detector3DTemplate(nn.Module):
self.model_cfg = model_cfg self.model_cfg = model_cfg
self.num_class = num_class self.num_class = num_class
self.dataset = dataset self.dataset = dataset
self.class_names = dataset.class_names
self.register_buffer('global_step', torch.LongTensor(1).zero_()) self.register_buffer('global_step', torch.LongTensor(1).zero_())
self.module_topology = [ self.module_topology = [
...@@ -119,6 +120,7 @@ class Detector3DTemplate(nn.Module): ...@@ -119,6 +120,7 @@ class Detector3DTemplate(nn.Module):
model_cfg=self.model_cfg.DENSE_HEAD, model_cfg=self.model_cfg.DENSE_HEAD,
input_channels=model_info_dict['num_bev_features'], input_channels=model_info_dict['num_bev_features'],
num_class=self.num_class if not self.model_cfg.DENSE_HEAD.CLASS_AGNOSTIC else 1, num_class=self.num_class if not self.model_cfg.DENSE_HEAD.CLASS_AGNOSTIC else 1,
class_names=self.class_names,
grid_size=model_info_dict['grid_size'], grid_size=model_info_dict['grid_size'],
point_cloud_range=model_info_dict['point_cloud_range'], point_cloud_range=model_info_dict['point_cloud_range'],
predict_boxes_when_training=self.model_cfg.get('ROI_HEAD', False) predict_boxes_when_training=self.model_cfg.get('ROI_HEAD', False)
......
...@@ -37,25 +37,34 @@ MODEL: ...@@ -37,25 +37,34 @@ MODEL:
ANCHOR_GENERATOR_CONFIG: [ ANCHOR_GENERATOR_CONFIG: [
{ {
'class_name': 'Car',
'anchor_sizes': [[3.9, 1.6, 1.56]], 'anchor_sizes': [[3.9, 1.6, 1.56]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.78], 'anchor_bottom_heights': [-1.78],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 8,
'matched_threshold': 0.6,
'unmatched_threshold': 0.45
}, },
{ {
'class_name': 'Pedestrian',
'anchor_sizes': [[0.8, 0.6, 1.73]], 'anchor_sizes': [[0.8, 0.6, 1.73]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.78], 'anchor_bottom_heights': [-1.78],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 8,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
}, },
{ {
'class_name': 'Cyclist',
'anchor_sizes': [[1.76, 0.6, 1.73]], 'anchor_sizes': [[1.76, 0.6, 1.73]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.78], 'anchor_bottom_heights': [-1.78],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 8,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
} }
] ]
...@@ -63,8 +72,6 @@ MODEL: ...@@ -63,8 +72,6 @@ MODEL:
NAME: AxisAlignedTargetAssigner NAME: AxisAlignedTargetAssigner
POS_FRACTION: -1.0 POS_FRACTION: -1.0
SAMPLE_SIZE: 512 SAMPLE_SIZE: 512
MATCHED_THRESHOLDS: [0.6, 0.5, 0.5]
UNMATCHED_THRESHOLDS: [0.45, 0.35, 0.35]
NORM_BY_NUM_EXAMPLES: False NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False MATCH_HEIGHT: False
BOX_CODER: ResidualCoder BOX_CODER: ResidualCoder
......
...@@ -78,25 +78,34 @@ MODEL: ...@@ -78,25 +78,34 @@ MODEL:
ANCHOR_GENERATOR_CONFIG: [ ANCHOR_GENERATOR_CONFIG: [
{ {
'class_name': 'Car',
'anchor_sizes': [[3.9, 1.6, 1.56]], 'anchor_sizes': [[3.9, 1.6, 1.56]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.78], 'anchor_bottom_heights': [-1.78],
'align_center': False, 'align_center': False,
'feature_map_stride': 2 'feature_map_stride': 2,
'matched_threshold': 0.6,
'unmatched_threshold': 0.45
}, },
{ {
'class_name': 'Pedestrian',
'anchor_sizes': [[0.8, 0.6, 1.73]], 'anchor_sizes': [[0.8, 0.6, 1.73]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.6], 'anchor_bottom_heights': [-0.6],
'align_center': False, 'align_center': False,
'feature_map_stride': 2 'feature_map_stride': 2,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
}, },
{ {
'class_name': 'Cyclist',
'anchor_sizes': [[1.76, 0.6, 1.73]], 'anchor_sizes': [[1.76, 0.6, 1.73]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.6], 'anchor_bottom_heights': [-0.6],
'align_center': False, 'align_center': False,
'feature_map_stride': 2 'feature_map_stride': 2,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
} }
] ]
...@@ -104,8 +113,6 @@ MODEL: ...@@ -104,8 +113,6 @@ MODEL:
NAME: AxisAlignedTargetAssigner NAME: AxisAlignedTargetAssigner
POS_FRACTION: -1.0 POS_FRACTION: -1.0
SAMPLE_SIZE: 512 SAMPLE_SIZE: 512
MATCHED_THRESHOLDS: [0.6, 0.5, 0.5]
UNMATCHED_THRESHOLDS: [0.45, 0.35, 0.35]
NORM_BY_NUM_EXAMPLES: False NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False MATCH_HEIGHT: False
BOX_CODER: ResidualCoder BOX_CODER: ResidualCoder
......
...@@ -12,7 +12,7 @@ DATA_CONFIG: ...@@ -12,7 +12,7 @@ DATA_CONFIG:
filter_by_difficulty: [-1], filter_by_difficulty: [-1],
} }
SAMPLE_GROUPS: ['Car:20','Pedestrian:15', 'Cyclist:15'] SAMPLE_GROUPS: ['Car:15','Pedestrian:10', 'Cyclist:10']
NUM_POINT_FEATURES: 4 NUM_POINT_FEATURES: 4
DATABASE_WITH_FAKELIDAR: False DATABASE_WITH_FAKELIDAR: False
REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0] REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
...@@ -60,25 +60,34 @@ MODEL: ...@@ -60,25 +60,34 @@ MODEL:
ANCHOR_GENERATOR_CONFIG: [ ANCHOR_GENERATOR_CONFIG: [
{ {
'class_name': 'Car',
'anchor_sizes': [[3.9, 1.6, 1.56]], 'anchor_sizes': [[3.9, 1.6, 1.56]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.78], 'anchor_bottom_heights': [-1.78],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 8,
'matched_threshold': 0.6,
'unmatched_threshold': 0.45
}, },
{ {
'class_name': 'Pedestrian',
'anchor_sizes': [[0.8, 0.6, 1.73]], 'anchor_sizes': [[0.8, 0.6, 1.73]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.6], 'anchor_bottom_heights': [-0.6],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 8,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
}, },
{ {
'class_name': 'Cyclist',
'anchor_sizes': [[1.76, 0.6, 1.73]], 'anchor_sizes': [[1.76, 0.6, 1.73]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.6], 'anchor_bottom_heights': [-0.6],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 8,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
} }
] ]
...@@ -86,8 +95,6 @@ MODEL: ...@@ -86,8 +95,6 @@ MODEL:
NAME: AxisAlignedTargetAssigner NAME: AxisAlignedTargetAssigner
POS_FRACTION: -1.0 POS_FRACTION: -1.0
SAMPLE_SIZE: 512 SAMPLE_SIZE: 512
MATCHED_THRESHOLDS: [0.6, 0.5, 0.5]
UNMATCHED_THRESHOLDS: [0.45, 0.35, 0.35]
NORM_BY_NUM_EXAMPLES: False NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False MATCH_HEIGHT: False
BOX_CODER: ResidualCoder BOX_CODER: ResidualCoder
......
...@@ -37,25 +37,34 @@ MODEL: ...@@ -37,25 +37,34 @@ MODEL:
ANCHOR_GENERATOR_CONFIG: [ ANCHOR_GENERATOR_CONFIG: [
{ {
'class_name': 'Car',
'anchor_sizes': [[3.9, 1.6, 1.56]], 'anchor_sizes': [[3.9, 1.6, 1.56]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.78], 'anchor_bottom_heights': [-1.78],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 8,
'matched_threshold': 0.6,
'unmatched_threshold': 0.45
}, },
{ {
'class_name': 'Pedestrian',
'anchor_sizes': [[0.8, 0.6, 1.73]], 'anchor_sizes': [[0.8, 0.6, 1.73]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.6], 'anchor_bottom_heights': [-0.6],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 8,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
}, },
{ {
'class_name': 'Cyclist',
'anchor_sizes': [[1.76, 0.6, 1.73]], 'anchor_sizes': [[1.76, 0.6, 1.73]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.6], 'anchor_bottom_heights': [-0.6],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 8,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
} }
] ]
...@@ -63,8 +72,6 @@ MODEL: ...@@ -63,8 +72,6 @@ MODEL:
NAME: AxisAlignedTargetAssigner NAME: AxisAlignedTargetAssigner
POS_FRACTION: -1.0 POS_FRACTION: -1.0
SAMPLE_SIZE: 512 SAMPLE_SIZE: 512
MATCHED_THRESHOLDS: [0.6, 0.5, 0.5]
UNMATCHED_THRESHOLDS: [0.45, 0.35, 0.35]
NORM_BY_NUM_EXAMPLES: False NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False MATCH_HEIGHT: False
BOX_CODER: ResidualCoder BOX_CODER: ResidualCoder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment