Commit 5efa3ac4 authored by WXinlong's avatar WXinlong
Browse files

update solo head

parent 5c6ad798
......@@ -307,7 +307,7 @@ class DecoupledSOLOHead(nn.Module):
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) >= 10
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
......
......@@ -302,7 +302,7 @@ class DecoupledSOLOLightHead(nn.Module):
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) >= 10
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
......
......@@ -279,7 +279,7 @@ class SOLOHead(nn.Module):
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) >= 10
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment