"...dynamo-run/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "ec2e730720daaf7405da456f90fc800c413a27d8"
Commit 3cd5c47b authored by Gus-Guo's avatar Gus-Guo Committed by Shaoshuai Shi
Browse files

simplify codes of multihead and update multihead config

parent 57e37335
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from .anchor_head_template import AnchorHeadTemplate from .anchor_head_template import AnchorHeadTemplate
from ..backbones_2d import BaseBEVBackbone
import torch import torch
class SingleHead(nn.Module): class SingleHead(BaseBEVBackbone):
def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None): def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None):
super(SingleHead, self).__init__() super().__init__(encode_conv_cfg, input_channels)
if encode_conv_cfg is not None:
stride = encode_conv_cfg['stride']
layer_num = encode_conv_cfg['layer_num']
num_filters = input_channels
encode_conv = []
encode_conv.append(nn.Conv2d(num_filters, num_filters, kernel_size=1, stride=stride, bias=False))
for i in range(layer_num-1):
encode_conv.append(nn.Conv2d(num_filters, num_filters, 1, bias=False))
encode_conv.append(nn.BatchNorm2d(num_filters))
encode_conv.append(nn.ReLU(inplace=True))
self.encode_conv = nn.Sequential(*encode_conv)
else:
self.encode_conv = None
self.num_anchors_per_location = num_anchors_per_location self.num_anchors_per_location = num_anchors_per_location
self.num_class = num_class self.num_class = num_class
...@@ -51,9 +39,7 @@ class SingleHead(nn.Module): ...@@ -51,9 +39,7 @@ class SingleHead(nn.Module):
def forward(self, spatial_features_2d): def forward(self, spatial_features_2d):
ret_dict = {} ret_dict = {}
spatial_features_2d = super().forward({'spatial_features': spatial_features_2d})['spatial_features_2d']
if self.encode_conv is not None:
spatial_features_2d = self.encode_conv(spatial_features_2d)
cls_preds = self.conv_cls(spatial_features_2d) cls_preds = self.conv_cls(spatial_features_2d)
box_preds = self.conv_box(spatial_features_2d) box_preds = self.conv_box(spatial_features_2d)
...@@ -79,7 +65,7 @@ class SingleHead(nn.Module): ...@@ -79,7 +65,7 @@ class SingleHead(nn.Module):
dir_cls_preds = dir_cls_preds.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS) dir_cls_preds = dir_cls_preds.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
else: else:
dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous() dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
else: else:
dir_cls_preds = None dir_cls_preds = None
...@@ -90,9 +76,9 @@ class SingleHead(nn.Module): ...@@ -90,9 +76,9 @@ class SingleHead(nn.Module):
return ret_dict return ret_dict
class AnchorHeadMulti(AnchorHeadTemplate): class AnchorHeadMulti(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, grid_size, point_cloud_range, predict_boxes_when_training=True): def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training=True):
super().__init__( super().__init__(
model_cfg=model_cfg, num_class=num_class, grid_size=grid_size, point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size, point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training
) )
self.model_cfg = model_cfg self.model_cfg = model_cfg
self.make_multihead(input_channels) self.make_multihead(input_channels)
...@@ -103,9 +89,9 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -103,9 +89,9 @@ class AnchorHeadMulti(AnchorHeadTemplate):
rpn_heads = [] rpn_heads = []
class_names = [] class_names = []
for rpn_head_cfg in rpn_head_cfgs: for rpn_head_cfg in rpn_head_cfgs:
class_names.extend(rpn_head_cfg['head_cls_name']) class_names.extend(rpn_head_cfg['HEAD_CLS_NAME'])
for rpn_head_cfg in rpn_head_cfgs: for rpn_head_cfg in rpn_head_cfgs:
num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)] for head_cls in rpn_head_cfg['head_cls_name']]) num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)] for head_cls in rpn_head_cfg['HEAD_CLS_NAME']])
rpn_head = SingleHead(self.model_cfg, input_channels, self.num_class, num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg) rpn_head = SingleHead(self.model_cfg, input_channels, self.num_class, num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg)
rpn_heads.append(rpn_head) rpn_heads.append(rpn_head)
self.rpn_heads = nn.ModuleList(rpn_heads) self.rpn_heads = nn.ModuleList(rpn_heads)
......
...@@ -37,50 +37,68 @@ MODEL: ...@@ -37,50 +37,68 @@ MODEL:
USE_MULTI_HEAD: True USE_MULTI_HEAD: True
ANCHOR_GENERATOR_CONFIG: [ ANCHOR_GENERATOR_CONFIG: [
{ {
'class_name': 'Car',
'anchor_sizes': [[3.9, 1.6, 1.56]], 'anchor_sizes': [[3.9, 1.6, 1.56]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.6], 'anchor_bottom_heights': [-1.6],
'align_center': False, 'align_center': False,
'feature_map_stride': 16 'feature_map_stride': 8,
'matched_threshold': 0.6,
'unmatched_threshold': 0.45
}, },
{ {
'class_name': 'Pedestrian',
'anchor_sizes': [[0.8, 0.6, 1.73]], 'anchor_sizes': [[0.8, 0.6, 1.73]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.6], 'anchor_bottom_heights': [-1.6],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 4,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
}, },
{ {
'class_name': 'Cyclist',
'anchor_sizes': [[1.76, 0.6, 1.73]], 'anchor_sizes': [[1.76, 0.6, 1.73]],
'anchor_rotations': [0, 1.57], 'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.6], 'anchor_bottom_heights': [-1.6],
'align_center': False, 'align_center': False,
'feature_map_stride': 8 'feature_map_stride': 4,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
} }
] ]
RPN_HEAD_CFGS: [ RPN_HEAD_CFGS: [
{ {
'head_cls_name': ['Car'], 'HEAD_CLS_NAME': ['Car'],
'stride': 2, 'LAYER_NUMS': [1],
'layer_num': 2 'LAYER_STRIDES': [1],
'NUM_FILTERS': [512],
'UPSAMPLE_STRIDES': [1],
'NUM_UPSAMPLE_FILTERS': [512]
}, },
{ {
'head_cls_name': ['Pedestrian', 'Cyclist'], 'HEAD_CLS_NAME': ['Pedestrian'],
'stride': 1, 'LAYER_NUMS': [1],
'layer_num': 2 'LAYER_STRIDES': [1],
}, 'NUM_FILTERS': [512],
'UPSAMPLE_STRIDES': [2],
'NUM_UPSAMPLE_FILTERS': [512]
},
{
'HEAD_CLS_NAME': ['Cyclist'],
'LAYER_NUMS': [1],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [512],
'UPSAMPLE_STRIDES': [2],
'NUM_UPSAMPLE_FILTERS': [512]
}
] ]
TARGET_ASSIGNER_CONFIG: TARGET_ASSIGNER_CONFIG:
NAME: AxisAlignedTargetAssigner NAME: AxisAlignedTargetAssigner
POS_FRACTION: -1.0 POS_FRACTION: -1.0
SAMPLE_SIZE: 512 SAMPLE_SIZE: 512
MATCHED_THRESHOLDS: [0.6, 0.5, 0.5]
UNMATCHED_THRESHOLDS: [0.45, 0.35, 0.35]
NORM_BY_NUM_EXAMPLES: False NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False MATCH_HEIGHT: False
BOX_CODER: ResidualCoder BOX_CODER: ResidualCoder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment