Commit ba28c08f authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: separte regression heads training

parent 72088ee3
......@@ -17,28 +17,66 @@ class SingleHead(BaseBEVBackbone):
self.separate_reg_config = separate_reg_config
self.register_buffer('head_label_indices', head_label_indices)
self.conv_cls = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.num_class,
kernel_size=1
)
if self.separate_reg_config is not None:
code_size_cnt = 0
self.conv_box = nn.ModuleDict()
self.conv_box_names = []
for reg_config in self.separate_reg_config:
num_middle_conv = self.separate_reg_config.NUM_MIDDLE_CONV
num_middle_filter = self.separate_reg_config.NUM_MIDDLE_FILTER
conv_cls_list = []
c_in = input_channels
for k in range(num_middle_conv):
conv_cls_list.extend([
nn.Conv2d(
c_in, num_middle_filter,
kernel_size=3, stride=1, padding=1, bias=False
),
nn.BatchNorm2d(num_middle_filter),
nn.ReLU()
])
c_in = num_middle_filter
conv_cls_list.append(nn.Conv2d(
c_in, self.num_anchors_per_location * self.num_class,
kernel_size=3, stride=1, padding=1
))
self.conv_cls = nn.Sequential(*conv_cls_list)
for reg_config in self.separate_reg_config.REG_LIST:
reg_name, reg_channel = reg_config.split(':')
cur_conv = nn.Conv2d(
input_channels, self.num_anchors_per_location * reg_channel,
reg_channel = int(reg_channel)
cur_conv_list = []
c_in = input_channels
for k in range(num_middle_conv):
cur_conv_list.extend([
nn.Conv2d(
c_in, num_middle_filter,
kernel_size=3, stride=1, padding=1, bias=False
),
nn.BatchNorm2d(num_middle_filter),
nn.ReLU()
])
c_in = num_middle_filter
cur_conv_list.append(nn.Conv2d(
c_in, self.num_anchors_per_location * int(reg_channel),
kernel_size=3, stride=1, padding=1, bias=True
)
nn.init.kaiming_normal_(cur_conv.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(cur_conv.bias, 0)
))
code_size_cnt += reg_channel
self.conv_box[f'conv_{reg_name}'] = cur_conv
self.conv_box[f'conv_{reg_name}'] = nn.Sequential(*cur_conv_list)
self.conv_box_names.append(f'conv_{reg_name}')
for m in self.conv_box.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
assert code_size_cnt == code_size, f'Code size does not match: {code_size_cnt}:{code_size}'
else:
self.conv_cls = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.num_class,
kernel_size=1
)
self.conv_box = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.code_size,
kernel_size=1
......@@ -57,7 +95,10 @@ class SingleHead(BaseBEVBackbone):
def init_weights(self):
pi = 0.01
if isinstance(self.conv_cls, nn.Conv2d):
nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi))
else:
nn.init.constant_(self.conv_cls[-1].bias, -np.log((1 - pi) / pi))
def forward(self, spatial_features_2d):
ret_dict = {}
......@@ -65,12 +106,12 @@ class SingleHead(BaseBEVBackbone):
cls_preds = self.conv_cls(spatial_features_2d)
if self.separate_reg_config is not None:
if self.separate_reg_config is None:
box_preds = self.conv_box(spatial_features_2d)
else:
box_preds_list = []
for reg_name in self.conv_box_names:
box_preds_list.append(self.conv_box[f'conv_{reg_name}'](spatial_features_2d))
box_preds_list.append(self.conv_box[reg_name](spatial_features_2d))
box_preds = torch.cat(box_preds_list, dim=1)
if not self.use_multihead:
......
......@@ -31,7 +31,6 @@ MODEL:
NAME: AnchorHeadMulti
CLASS_AGNOSTIC: False
USE_DIRECTION_CLASSIFIER: True
DIR_OFFSET: 0.78539
DIR_LIMIT_OFFSET: 0.0
NUM_DIR_BINS: 2
......@@ -145,43 +144,29 @@ MODEL:
RPN_HEAD_CFGS: [
{
'HEAD_CLS_NAME': ['car'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['truck', 'construction_vehicle'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['bus', 'trailer'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['barrier'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['motorcycle', 'bicycle'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['pedestrian', 'traffic_cone'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
]
SEPARATE_REG_CONFIG: ['reg:2', 'height:1', 'size:3', 'angle:2', 'velo:2']
SEPARATE_REG_CONFIG:
NUM_MIDDLE_CONV: 1
NUM_MIDDLE_FILTER: 64
REG_LIST: ['reg:2', 'height:1', 'size:3', 'angle:2', 'velo:2']
TARGET_ASSIGNER_CONFIG:
NAME: AxisAlignedTargetAssigner
POS_FRACTION: -1.0
......@@ -203,7 +188,7 @@ MODEL:
'cls_weight': 1.0,
'loc_weight': 0.25,
'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
POST_PROCESSING:
......
......@@ -49,7 +49,6 @@ MODEL:
NAME: AnchorHeadMulti
CLASS_AGNOSTIC: False
USE_DIRECTION_CLASSIFIER: True
DIR_OFFSET: 0.78539
DIR_LIMIT_OFFSET: 0.0
NUM_DIR_BINS: 2
......@@ -164,42 +163,27 @@ MODEL:
RPN_HEAD_CFGS: [
{
'HEAD_CLS_NAME': ['car'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['truck', 'construction_vehicle'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['bus', 'trailer'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['barrier'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['motorcycle', 'bicycle'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['pedestrian', 'traffic_cone'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
]
SEPARATE_REG_CONFIG: ['reg:2', 'height:1', 'size:3', 'angle:2', 'velo:2']
SEPARATE_REG_CONFIG:
NUM_MIDDLE_CONV: 1
NUM_MIDDLE_FILTER: 64
REG_LIST: ['reg:2', 'height:1', 'size:3', 'angle:2', 'velo:2']
TARGET_ASSIGNER_CONFIG:
NAME: AxisAlignedTargetAssigner
......@@ -222,7 +206,7 @@ MODEL:
'cls_weight': 1.0,
'loc_weight': 0.25,
'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
POST_PROCESSING:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment