Commit cb32ff13 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

set pos_cls_weight=1.0, neg_cls_weight=2.0 for NuScenes configs

parent 0dd5326e
...@@ -181,6 +181,13 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -181,6 +181,13 @@ class AnchorHeadMulti(AnchorHeadTemplate):
return data_dict return data_dict
def get_cls_layer_loss(self): def get_cls_layer_loss(self):
loss_weights = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
if 'pos_cls_weight' in loss_weights:
pos_cls_weight = loss_weights['pos_cls_weight']
neg_cls_weight = loss_weights['neg_cls_weight']
else:
pos_cls_weight = neg_cls_weight = 1.0
cls_preds = self.forward_ret_dict['cls_preds'] cls_preds = self.forward_ret_dict['cls_preds']
box_cls_labels = self.forward_ret_dict['box_cls_labels'] box_cls_labels = self.forward_ret_dict['box_cls_labels']
if not isinstance(cls_preds, list): if not isinstance(cls_preds, list):
...@@ -189,8 +196,10 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -189,8 +196,10 @@ class AnchorHeadMulti(AnchorHeadTemplate):
cared = box_cls_labels >= 0 # [N, num_anchors] cared = box_cls_labels >= 0 # [N, num_anchors]
positives = box_cls_labels > 0 positives = box_cls_labels > 0
negatives = box_cls_labels == 0 negatives = box_cls_labels == 0
negative_cls_weights = negatives * 1.0 negative_cls_weights = negatives * 1.0 * neg_cls_weight
cls_weights = (negative_cls_weights + 1.0 * positives).float()
cls_weights = (negative_cls_weights + pos_cls_weight * positives).float()
reg_weights = positives.float() reg_weights = positives.float()
if self.num_class == 1: if self.num_class == 1:
# class agnostic # class agnostic
...@@ -220,7 +229,7 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -220,7 +229,7 @@ class AnchorHeadMulti(AnchorHeadTemplate):
cls_weight = cls_weights[:, start_idx:start_idx + cls_pred.shape[1]] cls_weight = cls_weights[:, start_idx:start_idx + cls_pred.shape[1]]
cls_loss_src = self.cls_loss_func(cls_pred, one_hot_target, weights=cls_weight) # [N, M] cls_loss_src = self.cls_loss_func(cls_pred, one_hot_target, weights=cls_weight) # [N, M]
cls_loss = cls_loss_src.sum() / batch_size cls_loss = cls_loss_src.sum() / batch_size
cls_loss = cls_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight'] cls_loss = cls_loss * loss_weights['cls_weight']
cls_losses += cls_loss cls_losses += cls_loss
start_idx += cls_pred.shape[1] start_idx += cls_pred.shape[1]
assert start_idx == one_hot_targets.shape[1] assert start_idx == one_hot_targets.shape[1]
......
...@@ -196,6 +196,8 @@ MODEL: ...@@ -196,6 +196,8 @@ MODEL:
LOSS_CONFIG: LOSS_CONFIG:
REG_LOSS_TYPE: WeightedL1Loss REG_LOSS_TYPE: WeightedL1Loss
LOSS_WEIGHTS: { LOSS_WEIGHTS: {
'pos_cls_weight': 1.0,
'neg_cls_weight': 2.0,
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 0.25, 'loc_weight': 0.25,
'dir_weight': 0.2, 'dir_weight': 0.2,
......
...@@ -196,6 +196,8 @@ MODEL: ...@@ -196,6 +196,8 @@ MODEL:
LOSS_CONFIG: LOSS_CONFIG:
REG_LOSS_TYPE: WeightedL1Loss REG_LOSS_TYPE: WeightedL1Loss
LOSS_WEIGHTS: { LOSS_WEIGHTS: {
'pos_cls_weight': 1.0,
'neg_cls_weight': 2.0,
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 0.25, 'loc_weight': 0.25,
'dir_weight': 0.2, 'dir_weight': 0.2,
......
...@@ -214,6 +214,8 @@ MODEL: ...@@ -214,6 +214,8 @@ MODEL:
LOSS_CONFIG: LOSS_CONFIG:
REG_LOSS_TYPE: WeightedL1Loss REG_LOSS_TYPE: WeightedL1Loss
LOSS_WEIGHTS: { LOSS_WEIGHTS: {
'pos_cls_weight': 1.0,
'neg_cls_weight': 2.0,
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 0.25, 'loc_weight': 0.25,
'dir_weight': 0.2, 'dir_weight': 0.2,
......
...@@ -152,6 +152,8 @@ MODEL: ...@@ -152,6 +152,8 @@ MODEL:
LOSS_CONFIG: LOSS_CONFIG:
LOSS_WEIGHTS: { LOSS_WEIGHTS: {
'pos_cls_weight': 1.0,
'neg_cls_weight': 2.0,
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 2.0, 'loc_weight': 2.0,
'dir_weight': 0.2, 'dir_weight': 0.2,
......
...@@ -204,6 +204,8 @@ MODEL: ...@@ -204,6 +204,8 @@ MODEL:
LOSS_CONFIG: LOSS_CONFIG:
LOSS_WEIGHTS: { LOSS_WEIGHTS: {
'pos_cls_weight': 1.0,
'neg_cls_weight': 2.0,
'cls_weight': 1.0, 'cls_weight': 1.0,
'loc_weight': 2.0, 'loc_weight': 2.0,
'dir_weight': 0.2, 'dir_weight': 0.2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment