"launch/dynamo-run/src/flags.rs" did not exist on "a657ec612d558b3e4af1f141c7789122da9ebd5c"
Commit eb074cf2 authored by Shaoshuai Shi
Browse files

refactoring some configs and codes of multi-head

parent 0ac2901c
......@@ -7,13 +7,20 @@ class BaseBEVBackbone(nn.Module):
super().__init__()
self.model_cfg = model_cfg
if self.model_cfg.get('LAYER_NUMS', None) is not None:
assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
layer_nums = self.model_cfg.LAYER_NUMS
layer_strides = self.model_cfg.LAYER_STRIDES
num_filters = self.model_cfg.NUM_FILTERS
else:
layer_nums = layer_strides = num_filters = []
if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:
assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
else:
upsample_strides = num_upsample_filters = []
num_levels = len(layer_nums)
c_in_list = [input_channels, *num_filters[:-1]]
......
......@@ -6,13 +6,15 @@ import torch
class SingleHead(BaseBEVBackbone):
def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None):
def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None,
head_label_indices=None):
super().__init__(encode_conv_cfg, input_channels)
self.num_anchors_per_location = num_anchors_per_location
self.num_class = num_class
self.code_size = code_size
self.model_cfg = model_cfg
self.register_buffer('head_label_indices', head_label_indices)
self.conv_cls = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.num_class,
......@@ -62,7 +64,8 @@ class SingleHead(BaseBEVBackbone):
dir_cls_preds = self.conv_dir_cls(spatial_features_2d)
if self.use_multihead:
dir_cls_preds = dir_cls_preds.view(
-1, self.num_anchors_per_location, self.model_cfg.NUM_DIR_BINS, H, W).permute(0, 1, 3, 4, 2).contiguous()
-1, self.num_anchors_per_location, self.model_cfg.NUM_DIR_BINS, H, W).permute(0, 1, 3, 4,
2).contiguous()
dir_cls_preds = dir_cls_preds.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
else:
dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
......@@ -110,10 +113,15 @@ class AnchorHeadMulti(AnchorHeadTemplate):
for rpn_head_cfg in rpn_head_cfgs:
num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)]
for head_cls in rpn_head_cfg['HEAD_CLS_NAME']])
head_label_indices = torch.from_numpy(np.array([
self.class_names.index(cur_name) + 1 for cur_name in rpn_head_cfg['HEAD_CLS_NAME']
]))
rpn_head = SingleHead(
self.model_cfg, input_channels,
len(rpn_head_cfg['HEAD_CLS_NAME']) if self.separate_multihead else self.num_class,
num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg
num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg,
head_label_indices=head_label_indices
)
rpn_heads.append(rpn_head)
self.rpn_heads = nn.ModuleList(rpn_heads)
......@@ -137,8 +145,6 @@ class AnchorHeadMulti(AnchorHeadTemplate):
if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False):
dir_cls_preds = [ret_dict['dir_cls_preds'] for ret_dict in ret_dicts]
ret['dir_cls_preds'] = dir_cls_preds if self.separate_multihead else torch.cat(dir_cls_preds, dim=1)
else:
dir_cls_preds = None
self.forward_ret_dict.update(ret)
......@@ -150,8 +156,24 @@ class AnchorHeadMulti(AnchorHeadTemplate):
else:
batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
batch_size=data_dict['batch_size'],
cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
# 'dir_cls_preds' is only stored in `ret` when USE_DIRECTION_CLASSIFIER is enabled, so use .get() to pass None instead of raising KeyError
cls_preds=ret['cls_preds'], box_preds=ret['box_preds'], dir_cls_preds=ret.get('dir_cls_preds', None)
)
if isinstance(batch_cls_preds, list):
all_pred_labels = []
all_cls_preds = []
for idx, cls_pred in enumerate(batch_cls_preds):
pred_score, pred_head_label = torch.max(cls_pred, dim=-1)
pred_label = self.rpn_heads[idx].head_label_indices[pred_head_label]
all_pred_labels.append(pred_label)
all_cls_preds.append(pred_score[:, :, None])
batch_cls_preds = torch.cat(all_cls_preds, dim=1)
batch_pred_labels = torch.cat(all_pred_labels, dim=1)
data_dict['batch_pred_labels'] = batch_pred_labels
data_dict['has_class_labels'] = True
data_dict['batch_cls_preds'] = batch_cls_preds
data_dict['batch_box_preds'] = batch_box_preds
data_dict['cls_preds_normalized'] = False
......@@ -190,11 +212,12 @@ class AnchorHeadMulti(AnchorHeadTemplate):
cur_num_class = self.rpn_heads[idx].num_class
cls_pred = cls_pred.view(batch_size, -1, cur_num_class)
if self.separate_multihead:
one_hot_target = one_hot_targets[:, start_idx:start_idx + cls_pred.shape[1], c_idx:c_idx+cur_num_class]
one_hot_target = one_hot_targets[:, start_idx:start_idx + cls_pred.shape[1],
c_idx:c_idx + cur_num_class]
c_idx += cur_num_class
else:
one_hot_target = one_hot_targets[:, start_idx:start_idx+cls_pred.shape[1]]
cls_weight = cls_weights[:, start_idx:start_idx+cls_pred.shape[1]]
one_hot_target = one_hot_targets[:, start_idx:start_idx + cls_pred.shape[1]]
cls_weight = cls_weights[:, start_idx:start_idx + cls_pred.shape[1]]
cls_loss_src = self.cls_loss_func(cls_pred, one_hot_target, weights=cls_weight) # [N, M]
cls_loss = cls_loss_src.sum() / batch_size
cls_loss = cls_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight']
......@@ -239,8 +262,8 @@ class AnchorHeadMulti(AnchorHeadTemplate):
box_pred = box_pred.view(batch_size, -1,
box_pred.shape[-1] // self.num_anchors_per_location if not self.use_multihead else
box_pred.shape[-1])
box_reg_target = box_reg_targets[:, start_idx:start_idx+box_pred.shape[1]]
reg_weight = reg_weights[:, start_idx:start_idx+box_pred.shape[1]]
box_reg_target = box_reg_targets[:, start_idx:start_idx + box_pred.shape[1]]
reg_weight = reg_weights[:, start_idx:start_idx + box_pred.shape[1]]
# sin(a - b) = sinacosb-cosasinb
box_pred_sin, reg_target_sin = self.add_sin_difference(box_pred, box_reg_target)
loc_loss_src = self.reg_loss_func(box_pred_sin, reg_target_sin, weights=reg_weight) # [N, M]
......@@ -263,8 +286,8 @@ class AnchorHeadMulti(AnchorHeadTemplate):
weights = positives.type_as(dir_logit)
weights /= torch.clamp(weights.sum(-1, keepdim=True), min=1.0)
weight = weights[:, start_idx:start_idx+box_pred.shape[1]]
dir_target = dir_targets[:, start_idx:start_idx+box_pred.shape[1]]
weight = weights[:, start_idx:start_idx + box_pred.shape[1]]
dir_target = dir_targets[:, start_idx:start_idx + box_pred.shape[1]]
dir_loss = self.dir_loss_func(dir_logit, dir_target, weights=weight)
dir_loss = dir_loss.sum() / batch_size
dir_loss = dir_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['dir_weight']
......
......@@ -18,12 +18,14 @@ class AnchorHeadTemplate(nn.Module):
anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)(
num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6)
num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6),
**anchor_target_cfg.get('BOX_CODER_CONFIG', {})
)
anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG
anchors, self.num_anchors_per_location = self.generate_anchors(
anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range
anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range,
anchor_ndim=self.box_coder.code_size
)
self.anchors = [x.cuda() for x in anchors]
self.target_assigner = self.get_target_assigner(anchor_target_cfg)
......@@ -32,13 +34,20 @@ class AnchorHeadTemplate(nn.Module):
self.build_losses(self.model_cfg.LOSS_CONFIG)
@staticmethod
def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range):
def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range, anchor_ndim=7):
    """Generate anchors for every config entry and zero-pad them to ``anchor_ndim``.

    Args:
        anchor_generator_cfg: iterable of per-entry configs; each entry is
            indexed with 'feature_map_stride'.
        grid_size: sliceable grid size whose first two entries are floor-divided
            by each stride (presumably a numpy array -- TODO confirm).
        point_cloud_range: forwarded to AnchorGenerator as ``anchor_range``.
        anchor_ndim: required size of the last anchor dimension. Any value other
            than 7 makes each anchor tensor get ``anchor_ndim - 7`` trailing
            zero columns.

    Returns:
        Tuple ``(anchors_list, num_anchors_per_location_list)`` as returned by
        ``AnchorGenerator.generate_anchors``, with padded tensors written back
        into ``anchors_list`` when padding applies.
    """
    generator = AnchorGenerator(
        anchor_range=point_cloud_range,
        anchor_generator_config=anchor_generator_cfg,
    )

    # One feature-map size per config entry.
    map_sizes = []
    for cfg in anchor_generator_cfg:
        map_sizes.append(grid_size[:2] // cfg['feature_map_stride'])

    anchors_list, num_anchors_per_location_list = generator.generate_anchors(map_sizes)

    # Nothing to pad when the requested width is 7.
    if anchor_ndim == 7:
        return anchors_list, num_anchors_per_location_list

    extra_dims = anchor_ndim - 7
    for pos, base in enumerate(anchors_list):
        # new_zeros keeps the dtype/device of the anchors being padded.
        padding = base.new_zeros([*base.shape[0:-1], extra_dims])
        anchors_list[pos] = torch.cat((base, padding), dim=-1)

    return anchors_list, num_anchors_per_location_list
def get_target_assigner(self, anchor_target_cfg):
......@@ -65,9 +74,11 @@ class AnchorHeadTemplate(nn.Module):
'cls_loss_func',
loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
)
reg_loss_name = 'WeightedSmoothL1Loss' if losses_cfg.get('REG_LOSS_TYPE', None) is None \
else losses_cfg.REG_LOSS_TYPE
self.add_module(
'reg_loss_func',
loss_utils.WeightedSmoothL1Loss(code_weights=losses_cfg.LOSS_WEIGHTS['code_weights'])
getattr(loss_utils, reg_loss_name)(code_weights=losses_cfg.LOSS_WEIGHTS['code_weights'])
)
self.add_module(
'dir_loss_func',
......@@ -233,14 +244,17 @@ class AnchorHeadTemplate(nn.Module):
anchors = self.anchors
num_anchors = anchors.view(-1, anchors.shape[-1]).shape[0]
batch_anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
batch_cls_preds = cls_preds.view(batch_size, num_anchors, -1).float() if not isinstance(cls_preds, list) else cls_preds
batch_box_preds = box_preds.view(batch_size, num_anchors, -1) if not isinstance(box_preds, list) else torch.cat(box_preds, dim=1).view(batch_size, num_anchors, -1)
batch_cls_preds = cls_preds.view(batch_size, num_anchors, -1).float() \
if not isinstance(cls_preds, list) else cls_preds
batch_box_preds = box_preds.view(batch_size, num_anchors, -1) if not isinstance(box_preds, list) \
else torch.cat(box_preds, dim=1).view(batch_size, num_anchors, -1)
batch_box_preds = self.box_coder.decode_torch(batch_box_preds, batch_anchors)
if dir_cls_preds is not None:
dir_offset = self.model_cfg.DIR_OFFSET
dir_limit_offset = self.model_cfg.DIR_LIMIT_OFFSET
dir_cls_preds = dir_cls_preds.view(batch_size, num_anchors, -1) if not isinstance(dir_cls_preds, list) else torch.cat(dir_cls_preds, dim=1).view(batch_size, num_anchors, -1)
dir_cls_preds = dir_cls_preds.view(batch_size, num_anchors, -1) if not isinstance(dir_cls_preds, list) \
else torch.cat(dir_cls_preds, dim=1).view(batch_size, num_anchors, -1)
dir_labels = torch.max(dir_cls_preds, dim=-1)[1]
period = (2 * np.pi / self.model_cfg.NUM_DIR_BINS)
......
......@@ -6,7 +6,7 @@ from ..backbones_3d import vfe, pfe
from ..backbones_2d import map_to_bev
from ..model_utils.model_nms_utils import class_agnostic_nms
from ...ops.iou3d_nms import iou3d_nms_utils
import numpy as np
class Detector3DTemplate(nn.Module):
def __init__(self, model_cfg, num_class, dataset):
......@@ -170,7 +170,9 @@ class Detector3DTemplate(nn.Module):
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
cls_preds_normalized: indicate whether batch_cls_preds is normalized
batch_index: optional (N1+N2+...)
has_class_labels: True/False
roi_labels: (B, num_rois) 1 .. num_classes
batch_pred_labels: (B, num_boxes) 1 .. num_classes
Returns:
"""
......@@ -182,46 +184,28 @@ class Detector3DTemplate(nn.Module):
if batch_dict.get('batch_index', None) is not None:
assert batch_dict['batch_cls_preds'].shape.__len__() == 2
batch_mask = (batch_dict['batch_index'] == index)
else:
if isinstance(batch_dict['batch_cls_preds'], list):
assert batch_dict['batch_cls_preds'][0].shape.__len__() == 3
else:
assert batch_dict['batch_cls_preds'].shape.__len__() == 3
batch_mask = index
box_preds = batch_dict['batch_box_preds'][batch_mask]
cls_preds = batch_dict['batch_cls_preds'][batch_mask] if not isinstance(batch_dict['batch_cls_preds'], list) else [batch_cls_pred[batch_mask] for batch_cls_pred in batch_dict['batch_cls_preds']]
cls_preds = batch_dict['batch_cls_preds'][batch_mask]
src_cls_preds = cls_preds
src_box_preds = box_preds
if isinstance(cls_preds, list):
assert cls_preds[0].shape[1] in [1, self.num_class]
else:
assert cls_preds.shape[1] in [1, self.num_class]
if not batch_dict['cls_preds_normalized']:
cls_preds = torch.sigmoid(cls_preds) if not isinstance(cls_preds, list) else [torch.sigmoid(cls_pred) for cls_pred in cls_preds]
cls_preds = torch.sigmoid(cls_preds)
if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS:
raise NotImplementedError
else:
if isinstance(cls_preds, list):
all_cls_preds = []
label_preds = []
rpn_head_cfgs = self.model_cfg.DENSE_HEAD.RPN_HEAD_CFGS
head_cls_names = [np.array(rpn_head_cfg['HEAD_CLS_NAME']) for rpn_head_cfg in rpn_head_cfgs]
for idx, cls_pred in enumerate(cls_preds):
pred_score, pred_head_label = torch.max(cls_pred, dim=-1)
pred_class_names = head_cls_names[idx][pred_head_label.cpu().numpy().astype(int)]
label_pred = [self.class_names.index(cls_name)+1 for cls_name in pred_class_names]
label_pred = torch.from_numpy(np.array(label_pred)).to(cls_pred.device).int()
all_cls_preds.append(pred_score)
label_preds.append(label_pred)
cls_preds = torch.cat(all_cls_preds, dim=0)
label_preds = torch.cat(label_preds, dim=0)
else:
cls_preds, label_preds = torch.max(cls_preds, dim=-1)
label_preds = batch_dict['roi_labels'][index] if batch_dict.get('has_class_labels', False) else label_preds + 1
if batch_dict.get('has_class_labels', False):
    # Labels were already resolved upstream; prefer 'roi_labels' when present,
    # otherwise fall back to 'batch_pred_labels'.
    label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels'
    label_preds = batch_dict[label_key][index]
else:
    # torch.max yields 0-based indices while labels are 1-based; the shifted
    # value must be assigned back (a bare `label_preds + 1` is discarded).
    label_preds = label_preds + 1
selected, selected_scores = class_agnostic_nms(
box_scores=cls_preds, box_preds=box_preds,
......@@ -272,14 +256,14 @@ class Detector3DTemplate(nn.Module):
k -= 1
cur_gt = cur_gt[:k + 1]
if cur_gt.sum() > 0:
if cur_gt.shape[0] > 0:
if box_preds.shape[0] > 0:
iou3d_rcnn = iou3d_nms_utils.boxes_iou3d_gpu(box_preds, cur_gt[:, 0:7])
iou3d_rcnn = iou3d_nms_utils.boxes_iou3d_gpu(box_preds[:, 0:7], cur_gt[:, 0:7])
else:
iou3d_rcnn = torch.zeros((0, cur_gt.shape[0]))
if rois is not None:
iou3d_roi = iou3d_nms_utils.boxes_iou3d_gpu(rois, cur_gt[:, 0:7])
iou3d_roi = iou3d_nms_utils.boxes_iou3d_gpu(rois[:, 0:7], cur_gt[:, 0:7])
for cur_thresh in thresh_list:
if iou3d_rcnn.shape[0] == 0:
......
......@@ -74,27 +74,12 @@ MODEL:
RPN_HEAD_CFGS: [
{
'HEAD_CLS_NAME': ['Car'],
'LAYER_NUMS': [],
'LAYER_STRIDES': [],
'NUM_FILTERS': [],
'UPSAMPLE_STRIDES': [],
'NUM_UPSAMPLE_FILTERS': []
},
{
'HEAD_CLS_NAME': ['Pedestrian'],
'LAYER_NUMS': [],
'LAYER_STRIDES': [],
'NUM_FILTERS': [],
'UPSAMPLE_STRIDES': [],
'NUM_UPSAMPLE_FILTERS': []
},
{
'HEAD_CLS_NAME': ['Cyclist'],
'LAYER_NUMS': [],
'LAYER_STRIDES': [],
'NUM_FILTERS': [],
'UPSAMPLE_STRIDES': [],
'NUM_UPSAMPLE_FILTERS': []
}
]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment