Unverified commit 4c0c4c26, authored by Xinlong Wang and committed by GitHub

Merge pull request #131 from WXinlong/gt_gen

Update target generation
parents be36de16 5efa3ac4
@@ -24,6 +24,7 @@ More code and models will be released soon. Stay tuned.
 - **State-of-the-art performance:** Our best single model based on ResNet-101 and deformable convolutions achieves **41.7%** in AP on COCO test-dev (without multi-scale testing). A light-weight version of SOLOv2 executes at **31.3** FPS on a single V100 GPU and yields **37.1%** AP.
 ## Updates
+- Training speeds up (~1.7x faster) for all models. (03/12/20)
 - SOLOv2 is available. Code and trained models of SOLOv2 are released. (08/07/2020)
 - Light-weight models and R101-based models are available. (31/03/2020)
 - SOLOv1 is available. Code and trained models of SOLO and Decoupled SOLO are released. (28/03/2020)
......
@@ -11,7 +11,17 @@ from ..utils import bias_init_with_prob, ConvModule
 INF = 1e8
-from scipy import ndimage
+def center_of_mass(bitmasks):
+    _, h, w = bitmasks.size()
+    ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
+    xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
+    m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
+    m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
+    m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
+    center_x = m10 / m00
+    center_y = m01 / m00
+    return center_x, center_y
 def points_nms(heat, kernel=2):
     # kernel must be 2
@@ -294,15 +304,16 @@ class DecoupledSOLOHead(nn.Module):
         half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
         half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
-        output_stride = stride / 2
-        for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
-            if seg_mask.sum() < 10:
+        # mass center
+        gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
+        center_ws, center_hs = center_of_mass(gt_masks_pt)
+        valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
+        output_stride = stride / 2
+        for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
+            if not valid_mask_flag:
                 continue
-            # mass center
             upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
-            center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
             coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
             coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
@@ -321,7 +332,7 @@ class DecoupledSOLOHead(nn.Module):
             cate_label[top:(down+1), left:(right+1)] = gt_label
             # ins
             seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
-            seg_mask = torch.Tensor(seg_mask)
+            seg_mask = torch.from_numpy(seg_mask).to(device=device)
             for i in range(top, down+1):
                 for j in range(left, right+1):
                     label = int(i * num_grid + j)
......
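With this change the heads compute the mass centers of all ground-truth masks in one batched call to the torch-based center_of_mass helper, instead of calling scipy.ndimage.measurements.center_of_mass per mask inside the Python loop, which is in line with the ~1.7x training speed-up noted in the README update. A minimal sketch, not part of the diff and using made-up rectangular masks, showing that the moment-based helper agrees with the scipy routine it replaces:

import numpy as np
import torch
from scipy import ndimage

def center_of_mass(bitmasks):
    # batched centroids via image moments: x = m10 / m00, y = m01 / m00
    _, h, w = bitmasks.size()
    ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
    xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
    m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
    m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
    m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
    return m10 / m00, m01 / m00

masks = np.zeros((2, 64, 64), dtype=np.uint8)
masks[0, 10:20, 30:50] = 1   # centroid at (row 14.5, col 39.5)
masks[1, 5:15, 5:15] = 1     # centroid at (row 9.5, col 9.5)

center_xs, center_ys = center_of_mass(torch.from_numpy(masks).float())
for k in range(2):
    ref_y, ref_x = ndimage.center_of_mass(masks[k])   # scipy returns (row, col)
    assert abs(center_xs[k].item() - ref_x) < 1e-4
    assert abs(center_ys[k].item() - ref_y) < 1e-4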
@@ -11,7 +11,17 @@ from ..utils import bias_init_with_prob, ConvModule
 INF = 1e8
-from scipy import ndimage
+def center_of_mass(bitmasks):
+    _, h, w = bitmasks.size()
+    ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
+    xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
+    m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
+    m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
+    m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
+    center_x = m10 / m00
+    center_y = m01 / m00
+    return center_x, center_y
 def points_nms(heat, kernel=2):
     # kernel must be 2
@@ -289,15 +299,16 @@ class DecoupledSOLOLightHead(nn.Module):
         half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
         half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
-        output_stride = stride / 2
-        for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
-            if seg_mask.sum() < 10:
+        # mass center
+        gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
+        center_ws, center_hs = center_of_mass(gt_masks_pt)
+        valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
+        output_stride = stride / 2
+        for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
+            if not valid_mask_flag:
                 continue
-            # mass center
             upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
-            center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
             coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
             coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
@@ -316,7 +327,7 @@ class DecoupledSOLOLightHead(nn.Module):
             cate_label[top:(down+1), left:(right+1)] = gt_label
             # ins
             seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
-            seg_mask = torch.Tensor(seg_mask)
+            seg_mask = torch.from_numpy(seg_mask).to(device=device)
             for i in range(top, down+1):
                 for j in range(left, right+1):
                     label = int(i * num_grid + j)
......
@@ -11,7 +11,17 @@ from ..utils import bias_init_with_prob, ConvModule
 INF = 1e8
-from scipy import ndimage
+def center_of_mass(bitmasks):
+    _, h, w = bitmasks.size()
+    ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
+    xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
+    m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
+    m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
+    m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
+    center_x = m10 / m00
+    center_y = m01 / m00
+    return center_x, center_y
 def points_nms(heat, kernel=2):
     # kernel must be 2
@@ -266,14 +276,16 @@ class SOLOHead(nn.Module):
         half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
         half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
-        output_stride = stride / 2
-        for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
-            if seg_mask.sum() < 10:
+        # mass center
+        gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
+        center_ws, center_hs = center_of_mass(gt_masks_pt)
+        valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
+        output_stride = stride / 2
+        for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
+            if not valid_mask_flag:
                 continue
-            # mass center
             upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
-            center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
             coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
             coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
@@ -291,7 +303,7 @@ class SOLOHead(nn.Module):
             cate_label[top:(down+1), left:(right+1)] = gt_label
             # ins
             seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
-            seg_mask = torch.Tensor(seg_mask)
+            seg_mask = torch.from_numpy(seg_mask).to(device=device)
             for i in range(top, down+1):
                 for j in range(left, right+1):
                     label = int(i * num_grid + j)
......
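For orientation, the loop then maps each mass center (expressed in input-image coordinates, i.e. featmap_sizes[0] * 4) onto the S x S category grid and paints gt_label into cate_label[top:(down+1), left:(right+1)] around that cell. A small sketch with made-up numbers (num_grid, image size, and the center itself are hypothetical; only the two index expressions come from the head):

num_grid = 40                       # S for this FPN level (hypothetical)
upsampled_size = (800, 1344)        # featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4 (hypothetical)
center_h, center_w = 415.0, 700.0   # mass center in image coordinates (hypothetical)

# normalized center -> grid cell index, exactly as in the heads above
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
print(coord_h, coord_w)             # 20 20: the instance is assigned around cell (20, 20)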
@@ -11,7 +11,17 @@ from ..utils import bias_init_with_prob, ConvModule
 INF = 1e8
-from scipy import ndimage
+def center_of_mass(bitmasks):
+    _, h, w = bitmasks.size()
+    ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
+    xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
+    m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
+    m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
+    m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
+    center_x = m10 / m00
+    center_y = m01 / m00
+    return center_x, center_y
 def points_nms(heat, kernel=2):
     # kernel must be 2
@@ -299,13 +309,16 @@ class SOLOv2Head(nn.Module):
         half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
         half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
+        # mass center
+        gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
+        center_ws, center_hs = center_of_mass(gt_masks_pt)
+        valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
         output_stride = 4
-        for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
-            if seg_mask.sum() == 0:
+        for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
+            if not valid_mask_flag:
                 continue
-            # mass center
             upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
-            center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
             coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
             coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
@@ -322,7 +335,7 @@ class SOLOv2Head(nn.Module):
             cate_label[top:(down+1), left:(right+1)] = gt_label
             seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
-            seg_mask = torch.Tensor(seg_mask)
+            seg_mask = torch.from_numpy(seg_mask).to(device=device)
             for i in range(top, down+1):
                 for j in range(left, right+1):
                     label = int(i * num_grid + j)
@@ -333,8 +346,10 @@ class SOLOv2Head(nn.Module):
                     ins_label.append(cur_ins_label)
                     ins_ind_label[label] = True
                     grid_order.append(label)
-            ins_label = torch.stack(ins_label, 0)
+            if len(ins_label) == 0:
+                ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
+            else:
+                ins_label = torch.stack(ins_label, 0)
             ins_label_list.append(ins_label)
             cate_label_list.append(cate_label)
             ins_ind_label_list.append(ins_ind_label)
......
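The other fix in the SOLOv2 heads guards the per-level stacking: torch.stack raises on an empty list, so a grid level to which no instance is assigned would crash, and the updated code substitutes an explicitly empty (0, H, W) tensor instead. A standalone sketch of the failure and the fallback (mask_feat_size and device are placeholders, not values from the repo):

import torch

mask_feat_size = (104, 152)   # placeholder H, W of the mask feature map
device = 'cpu'                # placeholder device
ins_label = []                # no instance fell on this grid level

try:
    torch.stack(ins_label, 0)           # what the old code did unconditionally
except RuntimeError as err:
    print('stack on an empty list fails:', err)

# fallback from the updated head: an empty tensor with the right trailing shape
if len(ins_label) == 0:
    ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]],
                            dtype=torch.uint8, device=device)
else:
    ins_label = torch.stack(ins_label, 0)
print(ins_label.shape)                  # torch.Size([0, 104, 152])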
@@ -11,7 +11,17 @@ from ..utils import bias_init_with_prob, ConvModule
 INF = 1e8
-from scipy import ndimage
+def center_of_mass(bitmasks):
+    _, h, w = bitmasks.size()
+    ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
+    xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
+    m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
+    m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
+    m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
+    center_x = m10 / m00
+    center_y = m01 / m00
+    return center_x, center_y
 def points_nms(heat, kernel=2):
     # kernel must be 2
@@ -299,13 +309,15 @@ class SOLOv2LightHead(nn.Module):
         half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
         half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
+        # mass center
+        gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
+        center_ws, center_hs = center_of_mass(gt_masks_pt)
+        valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
         output_stride = 4
-        for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
-            if seg_mask.sum() == 0:
+        for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
+            if not valid_mask_flag:
                 continue
-            # mass center
             upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
-            center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
             coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
             coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
@@ -322,7 +334,7 @@ class SOLOv2LightHead(nn.Module):
             cate_label[top:(down+1), left:(right+1)] = gt_label
             seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
-            seg_mask = torch.Tensor(seg_mask)
+            seg_mask = torch.from_numpy(seg_mask).to(device=device)
             for i in range(top, down+1):
                 for j in range(left, right+1):
                     label = int(i * num_grid + j)
@@ -333,8 +345,10 @@ class SOLOv2LightHead(nn.Module):
                     ins_label.append(cur_ins_label)
                     ins_ind_label[label] = True
                     grid_order.append(label)
-            ins_label = torch.stack(ins_label, 0)
+            if len(ins_label) == 0:
+                ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
+            else:
+                ins_label = torch.stack(ins_label, 0)
             ins_label_list.append(ins_label)
             cate_label_list.append(cate_label)
             ins_ind_label_list.append(ins_ind_label)
@@ -406,12 +420,11 @@ class SOLOv2LightHead(nn.Module):
         strides = strides[inds[:, 0]]
         # mask encoding.
-        N, I = kernel_preds.shape
-        kernel_preds = kernel_preds.view(N, I, 1, 1)
+        I, N = kernel_preds.shape
+        kernel_preds = kernel_preds.view(I, N, 1, 1)
         seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()
         # mask.
-        seg_masks = seg_preds > 0.5
+        seg_masks = seg_preds > cfg.mask_thr
         sum_masks = seg_masks.sum((1, 2)).float()
         # filter.
@@ -465,6 +478,5 @@ class SOLOv2LightHead(nn.Module):
         seg_masks = F.interpolate(seg_preds,
                                   size=ori_shape[:2],
                                   mode='bilinear').squeeze(0)
-        seg_masks = seg_masks > 0.5
+        seg_masks = seg_masks > cfg.mask_thr
         return seg_masks, cate_labels, cate_scores
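In the light head's inference path, each predicted instance kernel is applied to the shared mask features as a dynamic 1x1 convolution; viewing the kernels as (I, N, 1, 1) matches F.conv2d's (out_channels, in_channels, kH, kW) weight layout, so each of the I instances gets its own output mask channel, and the hard-coded 0.5 binarization is replaced by the configurable cfg.mask_thr. A self-contained sketch with made-up shapes and an assumed mask_thr of 0.5 (none of these values come from the repo):

import torch
import torch.nn.functional as F

I, N, H, W = 6, 128, 104, 152           # instances, feature channels, mask feature size (made up)
seg_feats = torch.randn(1, N, H, W)     # shared mask feature map, batch of 1
kernel_preds = torch.randn(I, N)        # one 1x1xN kernel per kept instance
mask_thr = 0.5                          # stands in for cfg.mask_thr

# F.conv2d weights are (out_channels, in_channels, kH, kW), hence the (I, N, 1, 1) view:
# one output channel per instance kernel.
weights = kernel_preds.view(I, N, 1, 1)
seg_preds = F.conv2d(seg_feats, weights, stride=1).squeeze(0).sigmoid()   # (I, H, W)

seg_masks = seg_preds > mask_thr        # one binary mask per instance
sum_masks = seg_masks.sum((1, 2)).float()
print(seg_preds.shape, sum_masks.shape) # torch.Size([6, 104, 152]) torch.Size([6])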