Commit 5c6ad798 authored by WXinlong's avatar WXinlong
Browse files

Update target generation

parent be36de16
......@@ -24,6 +24,7 @@ More code and models will be released soon. Stay tuned.
- **State-of-the-art performance:** Our best single model based on ResNet-101 and deformable convolutions achieves **41.7%** in AP on COCO test-dev (without multi-scale testing). A light-weight version of SOLOv2 executes at **31.3** FPS on a single V100 GPU and yields **37.1%** AP.
## Updates
- Training speeds up (~1.7x faster) for all models. (03/12/20)
- SOLOv2 is available. Code and trained models of SOLOv2 are released. (08/07/2020)
- Light-weight models and R101-based models are available. (31/03/2020)
- SOLOv1 is available. Code and trained models of SOLO and Decoupled SOLO are released. (28/03/2020)
......
......@@ -11,7 +11,17 @@ from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
from scipy import ndimage
def center_of_mass(bitmasks):
_, h, w = bitmasks.size()
ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
center_x = m10 / m00
center_y = m01 / m00
return center_x, center_y
def points_nms(heat, kernel=2):
# kernel must be 2
......@@ -294,15 +304,16 @@ class DecoupledSOLOHead(nn.Module):
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
output_stride = stride / 2
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) >= 10
for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
if seg_mask.sum() < 10:
output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
# mass center
upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
......@@ -321,7 +332,7 @@ class DecoupledSOLOHead(nn.Module):
cate_label[top:(down+1), left:(right+1)] = gt_label
# ins
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.Tensor(seg_mask)
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
......
......@@ -11,7 +11,17 @@ from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
from scipy import ndimage
def center_of_mass(bitmasks):
_, h, w = bitmasks.size()
ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
center_x = m10 / m00
center_y = m01 / m00
return center_x, center_y
def points_nms(heat, kernel=2):
# kernel must be 2
......@@ -289,15 +299,16 @@ class DecoupledSOLOLightHead(nn.Module):
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
output_stride = stride / 2
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) >= 10
for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
if seg_mask.sum() < 10:
output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
# mass center
upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
......@@ -316,7 +327,7 @@ class DecoupledSOLOLightHead(nn.Module):
cate_label[top:(down+1), left:(right+1)] = gt_label
# ins
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.Tensor(seg_mask)
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
......
......@@ -11,7 +11,17 @@ from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
from scipy import ndimage
def center_of_mass(bitmasks):
_, h, w = bitmasks.size()
ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
center_x = m10 / m00
center_y = m01 / m00
return center_x, center_y
def points_nms(heat, kernel=2):
# kernel must be 2
......@@ -266,14 +276,16 @@ class SOLOHead(nn.Module):
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
output_stride = stride / 2
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) >= 10
for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
if seg_mask.sum() < 10:
output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
# mass center
upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
......@@ -291,7 +303,7 @@ class SOLOHead(nn.Module):
cate_label[top:(down+1), left:(right+1)] = gt_label
# ins
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.Tensor(seg_mask)
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
......
......@@ -11,7 +11,17 @@ from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
from scipy import ndimage
def center_of_mass(bitmasks):
_, h, w = bitmasks.size()
ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
center_x = m10 / m00
center_y = m01 / m00
return center_x, center_y
def points_nms(heat, kernel=2):
# kernel must be 2
......@@ -299,13 +309,16 @@ class SOLOv2Head(nn.Module):
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = 4
for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
if seg_mask.sum() == 0:
continue
# mass center
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
......@@ -322,7 +335,7 @@ class SOLOv2Head(nn.Module):
cate_label[top:(down+1), left:(right+1)] = gt_label
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.Tensor(seg_mask)
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
......@@ -333,8 +346,10 @@ class SOLOv2Head(nn.Module):
ins_label.append(cur_ins_label)
ins_ind_label[label] = True
grid_order.append(label)
ins_label = torch.stack(ins_label, 0)
if len(ins_label) == 0:
ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
else:
ins_label = torch.stack(ins_label, 0)
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
......
......@@ -11,7 +11,17 @@ from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
from scipy import ndimage
def center_of_mass(bitmasks):
_, h, w = bitmasks.size()
ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
center_x = m10 / m00
center_y = m01 / m00
return center_x, center_y
def points_nms(heat, kernel=2):
# kernel must be 2
......@@ -299,13 +309,15 @@ class SOLOv2LightHead(nn.Module):
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = 4
for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
if seg_mask.sum() == 0:
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
# mass center
upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
......@@ -322,7 +334,7 @@ class SOLOv2LightHead(nn.Module):
cate_label[top:(down+1), left:(right+1)] = gt_label
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.Tensor(seg_mask)
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
......@@ -333,8 +345,10 @@ class SOLOv2LightHead(nn.Module):
ins_label.append(cur_ins_label)
ins_ind_label[label] = True
grid_order.append(label)
ins_label = torch.stack(ins_label, 0)
if len(ins_label) == 0:
ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
else:
ins_label = torch.stack(ins_label, 0)
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
......@@ -406,12 +420,11 @@ class SOLOv2LightHead(nn.Module):
strides = strides[inds[:, 0]]
# mask encoding.
N, I = kernel_preds.shape
kernel_preds = kernel_preds.view(N, I, 1, 1)
I, N = kernel_preds.shape
kernel_preds = kernel_preds.view(I, N, 1, 1)
seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()
# mask.
seg_masks = seg_preds > 0.5
seg_masks = seg_preds > cfg.mask_thr
sum_masks = seg_masks.sum((1, 2)).float()
# filter.
......@@ -465,6 +478,5 @@ class SOLOv2LightHead(nn.Module):
seg_masks = F.interpolate(seg_preds,
size=ori_shape[:2],
mode='bilinear').squeeze(0)
seg_masks = seg_masks > 0.5
seg_masks = seg_masks > cfg.mask_thr
return seg_masks, cate_labels, cate_scores
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment