"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "76279ff212dfbf16fd06a6e4a51e9bca02304b02"
Commit 3cd5c47b authored by Gus-Guo's avatar Gus-Guo Committed by Shaoshuai Shi
Browse files

simplify codes of multihead and update multihead config

parent 57e37335
import numpy as np
import torch.nn as nn
from .anchor_head_template import AnchorHeadTemplate
from ..backbones_2d import BaseBEVBackbone
import torch
class SingleHead(nn.Module):
class SingleHead(BaseBEVBackbone):
def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None):
super(SingleHead, self).__init__()
if encode_conv_cfg is not None:
stride = encode_conv_cfg['stride']
layer_num = encode_conv_cfg['layer_num']
num_filters = input_channels
encode_conv = []
encode_conv.append(nn.Conv2d(num_filters, num_filters, kernel_size=1, stride=stride, bias=False))
for i in range(layer_num-1):
encode_conv.append(nn.Conv2d(num_filters, num_filters, 1, bias=False))
encode_conv.append(nn.BatchNorm2d(num_filters))
encode_conv.append(nn.ReLU(inplace=True))
self.encode_conv = nn.Sequential(*encode_conv)
else:
self.encode_conv = None
super().__init__(encode_conv_cfg, input_channels)
self.num_anchors_per_location = num_anchors_per_location
self.num_class = num_class
......@@ -51,9 +39,7 @@ class SingleHead(nn.Module):
def forward(self, spatial_features_2d):
ret_dict = {}
if self.encode_conv is not None:
spatial_features_2d = self.encode_conv(spatial_features_2d)
spatial_features_2d = super().forward({'spatial_features': spatial_features_2d})['spatial_features_2d']
cls_preds = self.conv_cls(spatial_features_2d)
box_preds = self.conv_box(spatial_features_2d)
......@@ -79,7 +65,7 @@ class SingleHead(nn.Module):
dir_cls_preds = dir_cls_preds.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
else:
dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
else:
dir_cls_preds = None
......@@ -90,9 +76,9 @@ class SingleHead(nn.Module):
return ret_dict
class AnchorHeadMulti(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, grid_size, point_cloud_range, predict_boxes_when_training=True):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training=True):
super().__init__(
model_cfg=model_cfg, num_class=num_class, grid_size=grid_size, point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training
model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size, point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training
)
self.model_cfg = model_cfg
self.make_multihead(input_channels)
......@@ -103,9 +89,9 @@ class AnchorHeadMulti(AnchorHeadTemplate):
rpn_heads = []
class_names = []
for rpn_head_cfg in rpn_head_cfgs:
class_names.extend(rpn_head_cfg['head_cls_name'])
class_names.extend(rpn_head_cfg['HEAD_CLS_NAME'])
for rpn_head_cfg in rpn_head_cfgs:
num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)] for head_cls in rpn_head_cfg['head_cls_name']])
num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)] for head_cls in rpn_head_cfg['HEAD_CLS_NAME']])
rpn_head = SingleHead(self.model_cfg, input_channels, self.num_class, num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg)
rpn_heads.append(rpn_head)
self.rpn_heads = nn.ModuleList(rpn_heads)
......
......@@ -37,50 +37,68 @@ MODEL:
USE_MULTI_HEAD: True
ANCHOR_GENERATOR_CONFIG: [
{
'class_name': 'Car',
'anchor_sizes': [[3.9, 1.6, 1.56]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.6],
'align_center': False,
'feature_map_stride': 16
'feature_map_stride': 8,
'matched_threshold': 0.6,
'unmatched_threshold': 0.45
},
{
'class_name': 'Pedestrian',
'anchor_sizes': [[0.8, 0.6, 1.73]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.6],
'align_center': False,
'feature_map_stride': 8
'feature_map_stride': 4,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
},
{
'class_name': 'Cyclist',
'anchor_sizes': [[1.76, 0.6, 1.73]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.6],
'align_center': False,
'feature_map_stride': 8
'feature_map_stride': 4,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
}
]
RPN_HEAD_CFGS: [
{
'head_cls_name': ['Car'],
'stride': 2,
'layer_num': 2
'HEAD_CLS_NAME': ['Car'],
'LAYER_NUMS': [1],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [512],
'UPSAMPLE_STRIDES': [1],
'NUM_UPSAMPLE_FILTERS': [512]
},
{
'head_cls_name': ['Pedestrian', 'Cyclist'],
'stride': 1,
'layer_num': 2
},
'HEAD_CLS_NAME': ['Pedestrian'],
'LAYER_NUMS': [1],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [512],
'UPSAMPLE_STRIDES': [2],
'NUM_UPSAMPLE_FILTERS': [512]
},
{
'HEAD_CLS_NAME': ['Cyclist'],
'LAYER_NUMS': [1],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [512],
'UPSAMPLE_STRIDES': [2],
'NUM_UPSAMPLE_FILTERS': [512]
}
]
TARGET_ASSIGNER_CONFIG:
NAME: AxisAlignedTargetAssigner
POS_FRACTION: -1.0
SAMPLE_SIZE: 512
MATCHED_THRESHOLDS: [0.6, 0.5, 0.5]
UNMATCHED_THRESHOLDS: [0.45, 0.35, 0.35]
NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False
BOX_CODER: ResidualCoder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment