Commit ba28c08f authored by Shaoshuai Shi
Browse files

Bug fix: separate regression heads training

parent 72088ee3
...@@ -17,28 +17,66 @@ class SingleHead(BaseBEVBackbone): ...@@ -17,28 +17,66 @@ class SingleHead(BaseBEVBackbone):
self.separate_reg_config = separate_reg_config self.separate_reg_config = separate_reg_config
self.register_buffer('head_label_indices', head_label_indices) self.register_buffer('head_label_indices', head_label_indices)
self.conv_cls = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.num_class,
kernel_size=1
)
if self.separate_reg_config is not None: if self.separate_reg_config is not None:
code_size_cnt = 0 code_size_cnt = 0
self.conv_box = nn.ModuleDict() self.conv_box = nn.ModuleDict()
self.conv_box_names = [] self.conv_box_names = []
for reg_config in self.separate_reg_config: num_middle_conv = self.separate_reg_config.NUM_MIDDLE_CONV
num_middle_filter = self.separate_reg_config.NUM_MIDDLE_FILTER
conv_cls_list = []
c_in = input_channels
for k in range(num_middle_conv):
conv_cls_list.extend([
nn.Conv2d(
c_in, num_middle_filter,
kernel_size=3, stride=1, padding=1, bias=False
),
nn.BatchNorm2d(num_middle_filter),
nn.ReLU()
])
c_in = num_middle_filter
conv_cls_list.append(nn.Conv2d(
c_in, self.num_anchors_per_location * self.num_class,
kernel_size=3, stride=1, padding=1
))
self.conv_cls = nn.Sequential(*conv_cls_list)
for reg_config in self.separate_reg_config.REG_LIST:
reg_name, reg_channel = reg_config.split(':') reg_name, reg_channel = reg_config.split(':')
cur_conv = nn.Conv2d( reg_channel = int(reg_channel)
input_channels, self.num_anchors_per_location * reg_channel, cur_conv_list = []
c_in = input_channels
for k in range(num_middle_conv):
cur_conv_list.extend([
nn.Conv2d(
c_in, num_middle_filter,
kernel_size=3, stride=1, padding=1, bias=False
),
nn.BatchNorm2d(num_middle_filter),
nn.ReLU()
])
c_in = num_middle_filter
cur_conv_list.append(nn.Conv2d(
c_in, self.num_anchors_per_location * int(reg_channel),
kernel_size=3, stride=1, padding=1, bias=True kernel_size=3, stride=1, padding=1, bias=True
) ))
nn.init.kaiming_normal_(cur_conv.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(cur_conv.bias, 0)
code_size_cnt += reg_channel code_size_cnt += reg_channel
self.conv_box[f'conv_{reg_name}'] = cur_conv self.conv_box[f'conv_{reg_name}'] = nn.Sequential(*cur_conv_list)
self.conv_box_names.append(f'conv_{reg_name}') self.conv_box_names.append(f'conv_{reg_name}')
for m in self.conv_box.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
assert code_size_cnt == code_size, f'Code size does not match: {code_size_cnt}:{code_size}' assert code_size_cnt == code_size, f'Code size does not match: {code_size_cnt}:{code_size}'
else: else:
self.conv_cls = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.num_class,
kernel_size=1
)
self.conv_box = nn.Conv2d( self.conv_box = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.code_size, input_channels, self.num_anchors_per_location * self.code_size,
kernel_size=1 kernel_size=1
...@@ -57,7 +95,10 @@ class SingleHead(BaseBEVBackbone): ...@@ -57,7 +95,10 @@ class SingleHead(BaseBEVBackbone):
def init_weights(self): def init_weights(self):
pi = 0.01 pi = 0.01
nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi)) if isinstance(self.conv_cls, nn.Conv2d):
nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi))
else:
nn.init.constant_(self.conv_cls[-1].bias, -np.log((1 - pi) / pi))
def forward(self, spatial_features_2d): def forward(self, spatial_features_2d):
ret_dict = {} ret_dict = {}
...@@ -65,12 +106,12 @@ class SingleHead(BaseBEVBackbone): ...@@ -65,12 +106,12 @@ class SingleHead(BaseBEVBackbone):
cls_preds = self.conv_cls(spatial_features_2d) cls_preds = self.conv_cls(spatial_features_2d)
if self.separate_reg_config is not None: if self.separate_reg_config is None:
box_preds = self.conv_box(spatial_features_2d) box_preds = self.conv_box(spatial_features_2d)
else: else:
box_preds_list = [] box_preds_list = []
for reg_name in self.conv_box_names: for reg_name in self.conv_box_names:
box_preds_list.append(self.conv_box[f'conv_{reg_name}'](spatial_features_2d)) box_preds_list.append(self.conv_box[reg_name](spatial_features_2d))
box_preds = torch.cat(box_preds_list, dim=1) box_preds = torch.cat(box_preds_list, dim=1)
if not self.use_multihead: if not self.use_multihead:
......
...@@ -31,7 +31,6 @@ MODEL: ...@@ -31,7 +31,6 @@ MODEL:
NAME: AnchorHeadMulti NAME: AnchorHeadMulti
CLASS_AGNOSTIC: False CLASS_AGNOSTIC: False
USE_DIRECTION_CLASSIFIER: True
DIR_OFFSET: 0.78539 DIR_OFFSET: 0.78539
DIR_LIMIT_OFFSET: 0.0 DIR_LIMIT_OFFSET: 0.0
NUM_DIR_BINS: 2 NUM_DIR_BINS: 2
...@@ -145,43 +144,29 @@ MODEL: ...@@ -145,43 +144,29 @@ MODEL:
RPN_HEAD_CFGS: [ RPN_HEAD_CFGS: [
{ {
'HEAD_CLS_NAME': ['car'], 'HEAD_CLS_NAME': ['car'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
{ {
'HEAD_CLS_NAME': ['truck', 'construction_vehicle'], 'HEAD_CLS_NAME': ['truck', 'construction_vehicle'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
{ {
'HEAD_CLS_NAME': ['bus', 'trailer'], 'HEAD_CLS_NAME': ['bus', 'trailer'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
{ {
'HEAD_CLS_NAME': ['barrier'], 'HEAD_CLS_NAME': ['barrier'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
{ {
'HEAD_CLS_NAME': ['motorcycle', 'bicycle'], 'HEAD_CLS_NAME': ['motorcycle', 'bicycle'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
{ {
'HEAD_CLS_NAME': ['pedestrian', 'traffic_cone'], 'HEAD_CLS_NAME': ['pedestrian', 'traffic_cone'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
] ]
SEPARATE_REG_CONFIG: ['reg:2', 'height:1', 'size:3', 'angle:2', 'velo:2'] SEPARATE_REG_CONFIG:
NUM_MIDDLE_CONV: 1
NUM_MIDDLE_FILTER: 64
REG_LIST: ['reg:2', 'height:1', 'size:3', 'angle:2', 'velo:2']
TARGET_ASSIGNER_CONFIG: TARGET_ASSIGNER_CONFIG:
NAME: AxisAlignedTargetAssigner NAME: AxisAlignedTargetAssigner
POS_FRACTION: -1.0 POS_FRACTION: -1.0
...@@ -203,7 +188,7 @@ MODEL: ...@@ -203,7 +188,7 @@ MODEL:
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 0.25, 'loc_weight': 0.25,
'dir_weight': 0.2, 'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2] 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
} }
POST_PROCESSING: POST_PROCESSING:
......
...@@ -49,7 +49,6 @@ MODEL: ...@@ -49,7 +49,6 @@ MODEL:
NAME: AnchorHeadMulti NAME: AnchorHeadMulti
CLASS_AGNOSTIC: False CLASS_AGNOSTIC: False
USE_DIRECTION_CLASSIFIER: True
DIR_OFFSET: 0.78539 DIR_OFFSET: 0.78539
DIR_LIMIT_OFFSET: 0.0 DIR_LIMIT_OFFSET: 0.0
NUM_DIR_BINS: 2 NUM_DIR_BINS: 2
...@@ -164,42 +163,27 @@ MODEL: ...@@ -164,42 +163,27 @@ MODEL:
RPN_HEAD_CFGS: [ RPN_HEAD_CFGS: [
{ {
'HEAD_CLS_NAME': ['car'], 'HEAD_CLS_NAME': ['car'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
{ {
'HEAD_CLS_NAME': ['truck', 'construction_vehicle'], 'HEAD_CLS_NAME': ['truck', 'construction_vehicle'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
{ {
'HEAD_CLS_NAME': ['bus', 'trailer'], 'HEAD_CLS_NAME': ['bus', 'trailer'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
{ {
'HEAD_CLS_NAME': ['barrier'], 'HEAD_CLS_NAME': ['barrier'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
{ {
'HEAD_CLS_NAME': ['motorcycle', 'bicycle'], 'HEAD_CLS_NAME': ['motorcycle', 'bicycle'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
{ {
'HEAD_CLS_NAME': ['pedestrian', 'traffic_cone'], 'HEAD_CLS_NAME': ['pedestrian', 'traffic_cone'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
}, },
] ]
SEPARATE_REG_CONFIG: ['reg:2', 'height:1', 'size:3', 'angle:2', 'velo:2'] SEPARATE_REG_CONFIG:
NUM_MIDDLE_CONV: 1
NUM_MIDDLE_FILTER: 64
REG_LIST: ['reg:2', 'height:1', 'size:3', 'angle:2', 'velo:2']
TARGET_ASSIGNER_CONFIG: TARGET_ASSIGNER_CONFIG:
NAME: AxisAlignedTargetAssigner NAME: AxisAlignedTargetAssigner
...@@ -222,7 +206,7 @@ MODEL: ...@@ -222,7 +206,7 @@ MODEL:
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 0.25, 'loc_weight': 0.25,
'dir_weight': 0.2, 'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2] 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
} }
POST_PROCESSING: POST_PROCESSING:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment