Commit 0ac2901c authored by Shaoshuai Shi
Browse files

bugfixed: loss calculation for separated multi-head

parent 19c66f79
...@@ -78,19 +78,26 @@ class SingleHead(BaseBEVBackbone): ...@@ -78,19 +78,26 @@ class SingleHead(BaseBEVBackbone):
class AnchorHeadMulti(AnchorHeadTemplate): class AnchorHeadMulti(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training=True): def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
predict_boxes_when_training=True):
super().__init__( super().__init__(
model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size, point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size,
point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training
) )
self.model_cfg = model_cfg self.model_cfg = model_cfg
self.seperate_multihead = self.model_cfg.get('SEPERATE_MULTIHEAD', False) self.separate_multihead = self.model_cfg.get('SEPARATE_MULTIHEAD', False)
shared_conv_num_filter = self.model_cfg.SHARED_CONV_NUM_FILTER if self.model_cfg.get('SHARED_CONV_NUM_FILTER', None) is not None:
self.shared_conv = nn.Sequential( shared_conv_num_filter = self.model_cfg.SHARED_CONV_NUM_FILTER
nn.Conv2d(input_channels, shared_conv_num_filter, 3, stride=1, padding=1, bias=False), self.shared_conv = nn.Sequential(
nn.BatchNorm2d(shared_conv_num_filter, eps=1e-3, momentum=0.01), nn.Conv2d(input_channels, shared_conv_num_filter, 3, stride=1, padding=1, bias=False),
nn.ReLU(), nn.BatchNorm2d(shared_conv_num_filter, eps=1e-3, momentum=0.01),
) nn.ReLU(),
)
else:
self.shared_conv = None
shared_conv_num_filter = input_channels
self.rpn_heads = None
self.make_multihead(shared_conv_num_filter) self.make_multihead(shared_conv_num_filter)
def make_multihead(self, input_channels): def make_multihead(self, input_channels):
...@@ -99,15 +106,22 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -99,15 +106,22 @@ class AnchorHeadMulti(AnchorHeadTemplate):
class_names = [] class_names = []
for rpn_head_cfg in rpn_head_cfgs: for rpn_head_cfg in rpn_head_cfgs:
class_names.extend(rpn_head_cfg['HEAD_CLS_NAME']) class_names.extend(rpn_head_cfg['HEAD_CLS_NAME'])
for rpn_head_cfg in rpn_head_cfgs: for rpn_head_cfg in rpn_head_cfgs:
num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)] for head_cls in rpn_head_cfg['HEAD_CLS_NAME']]) num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)]
rpn_head = SingleHead(self.model_cfg, input_channels, len(rpn_head_cfg['HEAD_CLS_NAME']) if self.seperate_multihead else self.num_class, num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg) for head_cls in rpn_head_cfg['HEAD_CLS_NAME']])
rpn_head = SingleHead(
self.model_cfg, input_channels,
len(rpn_head_cfg['HEAD_CLS_NAME']) if self.separate_multihead else self.num_class,
num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg
)
rpn_heads.append(rpn_head) rpn_heads.append(rpn_head)
self.rpn_heads = nn.ModuleList(rpn_heads) self.rpn_heads = nn.ModuleList(rpn_heads)
def forward(self, data_dict): def forward(self, data_dict):
spatial_features_2d = data_dict['spatial_features_2d'] spatial_features_2d = data_dict['spatial_features_2d']
spatial_features_2d = self.shared_conv(spatial_features_2d) if self.shared_conv is not None:
spatial_features_2d = self.shared_conv(spatial_features_2d)
ret_dicts = [] ret_dicts = []
for rpn_head in self.rpn_heads: for rpn_head in self.rpn_heads:
...@@ -115,15 +129,14 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -115,15 +129,14 @@ class AnchorHeadMulti(AnchorHeadTemplate):
cls_preds = [ret_dict['cls_preds'] for ret_dict in ret_dicts] cls_preds = [ret_dict['cls_preds'] for ret_dict in ret_dicts]
box_preds = [ret_dict['box_preds'] for ret_dict in ret_dicts] box_preds = [ret_dict['box_preds'] for ret_dict in ret_dicts]
ret = { ret = {
'cls_preds': cls_preds if self.seperate_multihead else torch.cat(cls_preds, dim=1), 'cls_preds': cls_preds if self.separate_multihead else torch.cat(cls_preds, dim=1),
'box_preds': box_preds if self.seperate_multihead else torch.cat(box_preds, dim=1), 'box_preds': box_preds if self.separate_multihead else torch.cat(box_preds, dim=1),
} }
if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False): if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False):
dir_cls_preds = [ret_dict['dir_cls_preds'] for ret_dict in ret_dicts] dir_cls_preds = [ret_dict['dir_cls_preds'] for ret_dict in ret_dicts]
ret['dir_cls_preds'] = dir_cls_preds if self.seperate_multihead else torch.cat(dir_cls_preds, dim=1) ret['dir_cls_preds'] = dir_cls_preds if self.separate_multihead else torch.cat(dir_cls_preds, dim=1)
else: else:
dir_cls_preds = None dir_cls_preds = None
...@@ -161,25 +174,33 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -161,25 +174,33 @@ class AnchorHeadMulti(AnchorHeadTemplate):
# class agnostic # class agnostic
box_cls_labels[positives] = 1 box_cls_labels[positives] = 1
pos_normalizer = positives.sum(1, keepdim=True).float() pos_normalizer = positives.sum(1, keepdim=True).float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0) reg_weights /= torch.clamp(pos_normalizer, min=1.0)
cls_weights /= torch.clamp(pos_normalizer, min=1.0) cls_weights /= torch.clamp(pos_normalizer, min=1.0)
cls_targets = box_cls_labels * cared.type_as(box_cls_labels) cls_targets = box_cls_labels * cared.type_as(box_cls_labels)
one_hot_target = torch.zeros( one_hot_targets = torch.zeros(
*list(cls_targets.shape), cls_preds[0].shape[-1] + 1 if self.seperate_multihead else self.num_class + 1, dtype=cls_preds[0].dtype, device=cls_targets.device *list(cls_targets.shape), self.num_class + 1, dtype=cls_preds[0].dtype, device=cls_targets.device
) )
one_hot_target.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0) one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
one_hot_targets = one_hot_target[..., 1:] one_hot_targets = one_hot_targets[..., 1:]
start_idx = 0 start_idx = c_idx = 0
cls_losses = 0 cls_losses = 0
for cls_pred in cls_preds:
cls_pred = cls_pred.view(batch_size, -1, cls_pred.shape[-1]) for idx, cls_pred in enumerate(cls_preds):
one_hot_target = one_hot_targets[:, start_idx:start_idx+cls_pred.shape[1]] cur_num_class = self.rpn_heads[idx].num_class
cls_pred = cls_pred.view(batch_size, -1, cur_num_class)
if self.separate_multihead:
one_hot_target = one_hot_targets[:, start_idx:start_idx + cls_pred.shape[1], c_idx:c_idx+cur_num_class]
c_idx += cur_num_class
else:
one_hot_target = one_hot_targets[:, start_idx:start_idx+cls_pred.shape[1]]
cls_weight = cls_weights[:, start_idx:start_idx+cls_pred.shape[1]] cls_weight = cls_weights[:, start_idx:start_idx+cls_pred.shape[1]]
cls_loss_src = self.cls_loss_func(cls_pred, one_hot_target, weights=cls_weight) # [N, M] cls_loss_src = self.cls_loss_func(cls_pred, one_hot_target, weights=cls_weight) # [N, M]
cls_loss = cls_loss_src.sum() / batch_size cls_loss = cls_loss_src.sum() / batch_size
cls_loss = cls_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight'] cls_loss = cls_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight']
cls_losses += cls_loss cls_losses += cls_loss
start_idx += cls_pred.shape[1] start_idx += cls_pred.shape[1]
assert start_idx == one_hot_targets.shape[1]
tb_dict = { tb_dict = {
'rpn_loss_cls': cls_losses.item() 'rpn_loss_cls': cls_losses.item()
} }
......
...@@ -35,7 +35,7 @@ MODEL: ...@@ -35,7 +35,7 @@ MODEL:
NUM_DIR_BINS: 2 NUM_DIR_BINS: 2
USE_MULTIHEAD: True USE_MULTIHEAD: True
SEPERATE_MULTIHEAD: True SEPARATE_MULTIHEAD: True
ANCHOR_GENERATOR_CONFIG: [ ANCHOR_GENERATOR_CONFIG: [
{ {
'class_name': 'Car', 'class_name': 'Car',
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment