Commit 5efa3ac4 authored by WXinlong's avatar WXinlong
Browse files

update solo head

parent 5c6ad798
...@@ -307,7 +307,7 @@ class DecoupledSOLOHead(nn.Module): ...@@ -307,7 +307,7 @@ class DecoupledSOLOHead(nn.Module):
# mass center # mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device) gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt) center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) >= 10 valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = stride / 2 output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags): for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
......
...@@ -302,7 +302,7 @@ class DecoupledSOLOLightHead(nn.Module): ...@@ -302,7 +302,7 @@ class DecoupledSOLOLightHead(nn.Module):
# mass center # mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device) gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt) center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) >= 10 valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = stride / 2 output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags): for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
......
...@@ -279,7 +279,7 @@ class SOLOHead(nn.Module): ...@@ -279,7 +279,7 @@ class SOLOHead(nn.Module):
# mass center # mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device) gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt) center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) >= 10 valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = stride / 2 output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags): for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment