Commit b00170f1 authored by WXinlong's avatar WXinlong
Browse files

Update head.

parent 51d8eb73
...@@ -441,7 +441,7 @@ class DecoupledSOLOHead(nn.Module): ...@@ -441,7 +441,7 @@ class DecoupledSOLOHead(nn.Module):
cate_scores = cate_scores[keep] cate_scores = cate_scores[keep]
sum_masks = sum_masks[keep] sum_masks = sum_masks[keep]
cate_labels = cate_labels[keep] cate_labels = cate_labels[keep]
# mask scoring # maskness
seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_score cate_scores *= seg_score
......
...@@ -436,7 +436,7 @@ class DecoupledSOLOLightHead(nn.Module): ...@@ -436,7 +436,7 @@ class DecoupledSOLOLightHead(nn.Module):
cate_scores = cate_scores[keep] cate_scores = cate_scores[keep]
sum_masks = sum_masks[keep] sum_masks = sum_masks[keep]
cate_labels = cate_labels[keep] cate_labels = cate_labels[keep]
# mask scoring # maskness
seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_score cate_scores *= seg_score
......
...@@ -389,7 +389,7 @@ class SOLOHead(nn.Module): ...@@ -389,7 +389,7 @@ class SOLOHead(nn.Module):
cate_scores = cate_scores[keep] cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep] cate_labels = cate_labels[keep]
# mask scoring. # maskness.
seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_scores cate_scores *= seg_scores
......
...@@ -439,7 +439,7 @@ class SOLOv2Head(nn.Module): ...@@ -439,7 +439,7 @@ class SOLOv2Head(nn.Module):
cate_scores = cate_scores[keep] cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep] cate_labels = cate_labels[keep]
# mask scoring. # maskness.
seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_scores cate_scores *= seg_scores
......
...@@ -438,7 +438,7 @@ class SOLOv2LightHead(nn.Module): ...@@ -438,7 +438,7 @@ class SOLOv2LightHead(nn.Module):
cate_scores = cate_scores[keep] cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep] cate_labels = cate_labels[keep]
# mask scoring. # maskness.
seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_scores cate_scores *= seg_scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment