Unverified Commit 274c90c5 authored by jihan.yang's avatar jihan.yang Committed by GitHub
Browse files

fixbug: miss cls loss weight in center head (#772)

parent fa78c4f1
...@@ -232,6 +232,7 @@ class CenterHead(nn.Module): ...@@ -232,6 +232,7 @@ class CenterHead(nn.Module):
for idx, pred_dict in enumerate(pred_dicts): for idx, pred_dict in enumerate(pred_dicts):
pred_dict['hm'] = self.sigmoid(pred_dict['hm']) pred_dict['hm'] = self.sigmoid(pred_dict['hm'])
hm_loss = self.hm_loss_func(pred_dict['hm'], target_dicts['heatmaps'][idx]) hm_loss = self.hm_loss_func(pred_dict['hm'], target_dicts['heatmaps'][idx])
hm_loss *= self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight']
target_boxes = target_dicts['target_boxes'][idx] target_boxes = target_dicts['target_boxes'][idx]
pred_boxes = torch.cat([pred_dict[head_name] for head_name in self.separate_head_cfg.HEAD_ORDER], dim=1) pred_boxes = torch.cat([pred_dict[head_name] for head_name in self.separate_head_cfg.HEAD_ORDER], dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment