Commit 1a2ae138 authored by WXinlong

add two light-weight models

parent 28b3c6ea
# model settings
model = dict(
    type='SOLO',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch',
        dcn=dict(
            type='DCN',
            deformable_groups=1,
            fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True)),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='DecoupledSOLOLightHead',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        use_dcn_in_tower=True,
        type_dcn='DCN',
        seg_feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cate_down_pos=0,
        loss_ins=dict(
            type='DiceLoss',
            use_sigmoid=True,
            loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
    ))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize',
         img_scale=[(852, 512), (852, 480), (852, 448),
                    (852, 416), (852, 384), (852, 352)],
         multiscale_mode='value',
         keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(852, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/decoupled_solo_light_dcn_release_r50_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
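
For orientation, a config like the one above is consumed by building the detector from it; below is a minimal sketch using the mmdet 1.x-era API (the config path is an assumption, not part of the commit):

from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('configs/decoupled_solo_light_dcn_release_r50_fpn_8gpu_3x.py')  # assumed path
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

The second config, below, is identical except that it drops the DCN settings and writes to a different work_dir.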

# model settings
model = dict(
    type='SOLO',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='DecoupledSOLOLightHead',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        seg_feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cate_down_pos=0,
        loss_ins=dict(
            type='DiceLoss',
            use_sigmoid=True,
            loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
    ))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize',
         img_scale=[(852, 512), (852, 480), (852, 448),
                    (852, 416), (852, 384), (852, 352)],
         multiscale_mode='value',
         keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(852, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/decoupled_solo_light_release_r50_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
@@ -13,10 +13,11 @@ from .rpn_head import RPNHead
 from .ssd_head import SSDHead
 from .solo_head import SOLOHead
 from .decoupled_solo_head import DecoupledSOLOHead
+from .decoupled_solo_light_head import DecoupledSOLOLightHead

 __all__ = [
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
     'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
     'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
-    'ATSSHead', 'SOLOHead', 'DecoupledSOLOHead'
+    'ATSSHead', 'SOLOHead', 'DecoupledSOLOHead', 'DecoupledSOLOLightHead'
 ]

import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import ndimage

from mmcv.cnn import normal_init

from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, bbox2roi, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule

INF = 1e8
def points_nms(heat, kernel=2):
    # kernel must be 2: with stride 1 and padding 1 the pooled map is one
    # pixel larger, and the [:-1, :-1] crop aligns it back with `heat`.
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=1)
    keep = (hmax[:, :, :-1, :-1] == heat).float()
    return heat * keep
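
# Illustrative note (hypothetical values, not from the commit): with this crop,
# a score survives only if it is >= its top, left and top-left neighbours, a
# cheap local-peak test on the category map, e.g.
#   heat = torch.tensor([[[[0.1, 0.9], [0.8, 0.2]]]])
#   points_nms(heat)  # -> [[[[0.1, 0.9], [0.8, 0.0]]]]; 0.2 loses to both peaks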
def dice_loss(input, target):
    # Per-instance soft Dice loss; the 0.001 terms keep the ratio stable for
    # (near-)empty masks.
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return 1 - d
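
# Quick sanity check (hypothetical values, not from the commit): for a single
# instance with soft prediction [0.9, 0.1] against target [1, 0],
#   a = 0.9, b = 0.821, c = 1.001, d = 1.8 / 1.822 ~= 0.988,
# so dice_loss returns ~0.012; a perfect hard match returns close to 0.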

@HEADS.register_module
class DecoupledSOLOLightHead(nn.Module):

    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.4,
                 num_grids=None,
                 cate_down_pos=0,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 use_dcn_in_tower=False,
                 type_dcn=None):
        super(DecoupledSOLOLightHead, self).__init__()
        self.num_classes = num_classes
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes - 1
        self.in_channels = in_channels
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.sigma = sigma
        self.cate_down_pos = cate_down_pos
        self.base_edge_list = base_edge_list
        self.scale_ranges = scale_ranges
        self.loss_cate = build_loss(loss_cate)
        self.ins_loss_weight = loss_ins['loss_weight']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.use_dcn_in_tower = use_dcn_in_tower
        self.type_dcn = type_dcn
        self._init_layers()
    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.ins_convs = nn.ModuleList()
        self.cate_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # optionally swap the last tower conv for a deformable conv
            if self.use_dcn_in_tower and i == self.stacked_convs - 1:
                cfg_conv = dict(type=self.type_dcn)
            else:
                cfg_conv = self.conv_cfg

            # +2 input channels on the first ins conv for the coordinate maps
            chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            self.ins_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=cfg_conv,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

            chn = self.in_channels if i == 0 else self.seg_feat_channels
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=cfg_conv,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

        self.dsolo_ins_list_x = nn.ModuleList()
        self.dsolo_ins_list_y = nn.ModuleList()
        for seg_num_grid in self.seg_num_grids:
            self.dsolo_ins_list_x.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid, 3, padding=1))
            self.dsolo_ins_list_y.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid, 3, padding=1))
        self.dsolo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
    def init_weights(self):
        for m in self.ins_convs:
            normal_init(m.conv, std=0.01)
        for m in self.cate_convs:
            normal_init(m.conv, std=0.01)
        bias_ins = bias_init_with_prob(0.01)
        for m in self.dsolo_ins_list_x:
            normal_init(m, std=0.01, bias=bias_ins)
        for m in self.dsolo_ins_list_y:
            normal_init(m, std=0.01, bias=bias_ins)
        bias_cate = bias_init_with_prob(0.01)
        normal_init(self.dsolo_cate, std=0.01, bias=bias_cate)
    def forward(self, feats, eval=False):
        new_feats = self.split_feats(feats)
        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
        ins_pred_x, ins_pred_y, cate_pred = multi_apply(
            self.forward_single,
            new_feats,
            list(range(len(self.seg_num_grids))),
            eval=eval,
            upsampled_size=upsampled_size)
        return ins_pred_x, ins_pred_y, cate_pred

    def split_feats(self, feats):
        # downsample P2 and shrink P6 onto P5's size, so the five levels match
        # the configured strides (8, 8, 16, 32, 32)
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
                feats[1],
                feats[2],
                feats[3],
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        ins_feat = x
        cate_feat = x
        # ins branch
        # concat coord
        x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
        y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_feat.shape[0], 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        ins_feat = torch.cat([ins_feat, coord_feat], 1)

        for ins_layer in self.ins_convs:
            ins_feat = ins_layer(ins_feat)

        ins_feat = F.interpolate(ins_feat, scale_factor=2, mode='bilinear')
        ins_pred_x = self.dsolo_ins_list_x[idx](ins_feat)
        ins_pred_y = self.dsolo_ins_list_y[idx](ins_feat)

        # cate branch
        for i, cate_layer in enumerate(self.cate_convs):
            if i == self.cate_down_pos:
                seg_num_grid = self.seg_num_grids[idx]
                cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
            cate_feat = cate_layer(cate_feat)
        cate_pred = self.dsolo_cate(cate_feat)

        if eval:
            ins_pred_x = F.interpolate(ins_pred_x.sigmoid(), size=upsampled_size, mode='bilinear')
            ins_pred_y = F.interpolate(ins_pred_y.sigmoid(), size=upsampled_size, mode='bilinear')
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return ins_pred_x, ins_pred_y, cate_pred
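
    # Note on the decoupled design: a grid cell (i, j) has no dedicated mask
    # channel among S^2 outputs; its mask is recomposed downstream as
    # sigmoid(ins_pred_x[j]) * sigmoid(ins_pred_y[i]) (see `loss` and
    # `get_seg_single`), so each level predicts 2S mask maps instead of S^2.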
    def loss(self,
             ins_preds_x,
             ins_preds_y,
             cate_preds,
             gt_bbox_list,
             gt_label_list,
             gt_mask_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in ins_preds_x]

        ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy = multi_apply(
            self.solo_target_single,
            gt_bbox_list,
            gt_label_list,
            gt_mask_list,
            featmap_sizes=featmap_sizes)

        # ins: gather, per level, the GT masks and the matching x/y mask
        # predictions of every positive grid cell across the batch
        ins_labels = [
            torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
                       for ins_labels_level_img, ins_ind_labels_level_img in
                       zip(ins_labels_level, ins_ind_labels_level)], 0)
            for ins_labels_level, ins_ind_labels_level in
            zip(zip(*ins_label_list), zip(*ins_ind_label_list))
        ]

        ins_preds_x_final = [
            torch.cat([ins_preds_level_img_x[ins_ind_labels_level_img[:, 1], ...]
                       for ins_preds_level_img_x, ins_ind_labels_level_img in
                       zip(ins_preds_level_x, ins_ind_labels_level)], 0)
            for ins_preds_level_x, ins_ind_labels_level in
            zip(ins_preds_x, zip(*ins_ind_label_list_xy))
        ]

        ins_preds_y_final = [
            torch.cat([ins_preds_level_img_y[ins_ind_labels_level_img[:, 0], ...]
                       for ins_preds_level_img_y, ins_ind_labels_level_img in
                       zip(ins_preds_level_y, ins_ind_labels_level)], 0)
            for ins_preds_level_y, ins_ind_labels_level in
            zip(ins_preds_y, zip(*ins_ind_label_list_xy))
        ]

        num_ins = 0.
        # dice loss
        loss_ins = []
        for input_x, input_y, target in zip(ins_preds_x_final, ins_preds_y_final, ins_labels):
            mask_n = input_x.size(0)
            if mask_n == 0:
                continue
            num_ins += mask_n
            input = (input_x.sigmoid()) * (input_y.sigmoid())
            loss_ins.append(dice_loss(input, target))
        loss_ins = torch.cat(loss_ins).mean() * self.ins_loss_weight

        # cate
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)

        loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
        return dict(
            loss_ins=loss_ins,
            loss_cate=loss_cate)
    def solo_target_single(self,
                           gt_bboxes_raw,
                           gt_labels_raw,
                           gt_masks_raw,
                           featmap_sizes=None):
        device = gt_labels_raw[0].device
        # ins
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
            gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        ins_ind_label_list_xy = []
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):

            ins_label = torch.zeros([num_grid**2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device)
            cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
            ins_ind_label = torch.zeros([num_grid**2], dtype=torch.bool, device=device)

            # assign to this level only instances whose scale falls in range
            hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()
            if len(hit_indices) == 0:
                ins_label = torch.zeros([1, featmap_size[0], featmap_size[1]], dtype=torch.uint8,
                                        device=device)
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label = torch.zeros([1], dtype=torch.bool, device=device)
                ins_ind_label_list.append(ins_ind_label)
                ins_ind_label_list_xy.append(cate_label.nonzero())
                continue
            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]

            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

            output_stride = stride / 2

            for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
                if seg_mask.sum() < 10:
                    continue
                # mass center
                upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
                center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down
                top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
                left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))

                # clamp the sigma-scaled center region to at most 3x3 grid
                # cells around the mass center
                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                cate_label[top:(down + 1), left:(right + 1)] = gt_label
                # ins
                seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
                seg_mask = torch.Tensor(seg_mask)
                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        label = int(i * num_grid + j)
                        ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
                        ins_ind_label[label] = True
            ins_label = ins_label[ins_ind_label]
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label = ins_ind_label[ins_ind_label]
            ins_ind_label_list.append(ins_ind_label)
            ins_ind_label_list_xy.append(cate_label.nonzero())
        return ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy
    def get_seg(self, seg_preds_x, seg_preds_y, cate_preds, img_metas, cfg, rescale=None):
        assert len(seg_preds_x) == len(cate_preds)
        num_levels = len(cate_preds)
        featmap_size = seg_preds_x[0].size()[-2:]

        result_list = []
        for img_id in range(len(img_metas)):
            cate_pred_list = [
                cate_preds[i][img_id].view(-1, self.cate_out_channels).detach()
                for i in range(num_levels)
            ]
            seg_pred_list_x = [
                seg_preds_x[i][img_id].detach() for i in range(num_levels)
            ]
            seg_pred_list_y = [
                seg_preds_y[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            ori_shape = img_metas[img_id]['ori_shape']

            cate_pred_list = torch.cat(cate_pred_list, dim=0)
            seg_pred_list_x = torch.cat(seg_pred_list_x, dim=0)
            seg_pred_list_y = torch.cat(seg_pred_list_y, dim=0)

            result = self.get_seg_single(cate_pred_list, seg_pred_list_x, seg_pred_list_y,
                                         featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
            result_list.append(result)
        return result_list
    def get_seg_single(self,
                       cate_preds,
                       seg_preds_x,
                       seg_preds_y,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False, debug=False):
        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        # per-position lookup tables mapping a flattened category index back
        # to its level offset, grid size, stride and x/y index offsets
        trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()
        trans_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        num_grids = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
        seg_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        strides = torch.ones(trans_size[-1].item(), device=cate_preds.device)

        n_stage = len(self.seg_num_grids)
        trans_diff[:trans_size[0]] *= 0
        seg_diff[:trans_size[0]] *= 0
        num_grids[:trans_size[0]] *= self.seg_num_grids[0]
        strides[:trans_size[0]] *= self.strides[0]

        for ind_ in range(1, n_stage):
            trans_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= trans_size[ind_ - 1]
            seg_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= seg_size[ind_ - 1]
            num_grids[trans_size[ind_ - 1]:trans_size[ind_]] *= self.seg_num_grids[ind_]
            strides[trans_size[ind_ - 1]:trans_size[ind_]] *= self.strides[ind_]

        # process.
        inds = (cate_preds > cfg.score_thr)
        cate_scores = cate_preds[inds]

        inds = inds.nonzero()
        trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
        seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
        num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
        strides = torch.index_select(strides, dim=0, index=inds[:, 0])

        y_inds = (inds[:, 0] - trans_diff) // num_grids
        x_inds = (inds[:, 0] - trans_diff) % num_grids
        y_inds += seg_diff
        x_inds += seg_diff
        cate_labels = inds[:, 1]

        # recover each mask from its decoupled x/y components
        seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]
        seg_masks = seg_masks_soft > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()
        keep = sum_masks > strides
        seg_masks_soft = seg_masks_soft[keep, ...]
        seg_masks = seg_masks[keep, ...]
        cate_scores = cate_scores[keep]
        sum_masks = sum_masks[keep]
        cate_labels = cate_labels[keep]
        # mask scoring
        seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_score

        if len(cate_scores) == 0:
            return None

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        seg_masks = seg_masks[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        sum_masks = sum_masks[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)

        keep = cate_scores >= cfg.update_thr
        seg_masks_soft = seg_masks_soft[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
                                       size=upsampled_size_out,
                                       mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_masks_soft,
                                  size=ori_shape[:2],
                                  mode='bilinear').squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores
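
For quick orientation, a minimal shape check of the head (an editorial sketch with assumed toy inputs, not part of the commit; hyper-parameters copied from the configs above):

if __name__ == '__main__':
    # five FPN levels for a toy 256x256 input (strides 4, 8, 16, 32, 64)
    head = DecoupledSOLOLightHead(
        num_classes=81,
        in_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
        num_grids=[40, 36, 24, 16, 12],
        loss_ins=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cate=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0,
                       alpha=0.25, loss_weight=1.0))
    head.init_weights()
    feats = [torch.rand(1, 256, 256 // s, 256 // s) for s in (4, 8, 16, 32, 64)]
    ins_x, ins_y, cate = head(feats, eval=True)
    # per level i: ins_x[i] is (1, S_i, 64, 64), ins_y[i] likewise, and
    # cate[i] is (1, S_i, S_i, 80) after points_nms and the final permute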