Commit 451933f7 authored by WXinlong's avatar WXinlong
Browse files

add solov2

parent d5398a0d
# model settings
model = dict(
type='SOLOv2',
pretrained='torchvision://resnet101',
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3), # C2, C3, C4, C5
frozen_stages=1,
style='pytorch',
dcn=dict(
type='DCNv2',
deformable_groups=1,
fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=0,
num_outs=5),
bbox_head=dict(
type='SOLOv2Head',
num_classes=81,
in_channels=256,
stacked_convs=4,
use_dcn_in_tower=True,
type_dcn='DCNv2',
seg_feat_channels=512,
strides=[8, 8, 16, 32, 32],
scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
sigma=0.2,
num_grids=[40, 36, 24, 16, 12],
ins_out_channels=256,
loss_ins=dict(
type='DiceLoss',
use_sigmoid=True,
loss_weight=3.0),
loss_cate=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0)),
mask_feat_head=dict(
type='MaskFeatHead',
in_channels=256,
out_channels=128,
start_level=0,
end_level=3,
num_classes=256,
conv_cfg=dict(type='DCNv2'),
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
)
# training and testing settings
train_cfg = dict()
test_cfg = dict(
nms_pre=500,
score_thr=0.1,
mask_thr=0.5,
update_thr=0.05,
kernel='gaussian', # gaussian/linear
sigma=2.0,
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize',
img_scale=[(1333, 800), (1333, 768), (1333, 736),
(1333, 704), (1333, 672), (1333, 640)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.01,
step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solov2_release_r101_dcn_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='SOLOv2',
pretrained='torchvision://resnet101',
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3), # C2, C3, C4, C5
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=0,
num_outs=5),
bbox_head=dict(
type='SOLOv2Head',
num_classes=81,
in_channels=256,
stacked_convs=4,
seg_feat_channels=512,
strides=[8, 8, 16, 32, 32],
scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
sigma=0.2,
num_grids=[40, 36, 24, 16, 12],
ins_out_channels=256,
loss_ins=dict(
type='DiceLoss',
use_sigmoid=True,
loss_weight=3.0),
loss_cate=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0)),
mask_feat_head=dict(
type='MaskFeatHead',
in_channels=256,
out_channels=128,
start_level=0,
end_level=3,
num_classes=256,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
)
# training and testing settings
train_cfg = dict()
test_cfg = dict(
nms_pre=500,
score_thr=0.1,
mask_thr=0.5,
update_thr=0.05,
kernel='gaussian', # gaussian/linear
sigma=2.0,
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize',
img_scale=[(1333, 800), (1333, 768), (1333, 736),
(1333, 704), (1333, 672), (1333, 640)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.01,
step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solov2_release_r101_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='SOLOv2',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3), # C2, C3, C4, C5
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=0,
num_outs=5),
bbox_head=dict(
type='SOLOv2Head',
num_classes=81,
in_channels=256,
stacked_convs=4,
seg_feat_channels=512,
strides=[8, 8, 16, 32, 32],
scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
sigma=0.2,
num_grids=[40, 36, 24, 16, 12],
ins_out_channels=256,
loss_ins=dict(
type='DiceLoss',
use_sigmoid=True,
loss_weight=3.0),
loss_cate=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0)),
mask_feat_head=dict(
type='MaskFeatHead',
in_channels=256,
out_channels=128,
start_level=0,
end_level=3,
num_classes=256,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
)
# training and testing settings
train_cfg = dict()
test_cfg = dict(
nms_pre=500,
score_thr=0.1,
mask_thr=0.5,
update_thr=0.05,
kernel='gaussian', # gaussian/linear
sigma=2.0,
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.01,
step=[9, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solov2_release_r50_fpn_8gpu_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='SOLOv2',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3), # C2, C3, C4, C5
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=0,
num_outs=5),
bbox_head=dict(
type='SOLOv2Head',
num_classes=81,
in_channels=256,
stacked_convs=4,
seg_feat_channels=512,
strides=[8, 8, 16, 32, 32],
scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
sigma=0.2,
num_grids=[40, 36, 24, 16, 12],
ins_out_channels=256,
loss_ins=dict(
type='DiceLoss',
use_sigmoid=True,
loss_weight=3.0),
loss_cate=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0)),
mask_feat_head=dict(
type='MaskFeatHead',
in_channels=256,
out_channels=128,
start_level=0,
end_level=3,
num_classes=256,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
)
# training and testing settings
train_cfg = dict()
test_cfg = dict(
nms_pre=500,
score_thr=0.1,
mask_thr=0.5,
update_thr=0.05,
kernel='gaussian', # gaussian/linear
sigma=2.0,
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize',
img_scale=[(1333, 800), (1333, 768), (1333, 736),
(1333, 704), (1333, 672), (1333, 640)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.01,
step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solov2_release_r50_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='SOLOv2',
pretrained='open-mmlab://resnext101_64x4d',
backbone=dict(
type='ResNeXt',
depth=101,
groups=64,
base_width=4,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch',
dcn=dict(
type='DCNv2',
deformable_groups=1,
fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=0,
num_outs=5),
bbox_head=dict(
type='SOLOv2Head',
num_classes=81,
in_channels=256,
stacked_convs=4,
use_dcn_in_tower=True,
type_dcn='DCNv2',
seg_feat_channels=512,
strides=[8, 8, 16, 32, 32],
scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
sigma=0.2,
num_grids=[40, 36, 24, 16, 12],
ins_out_channels=256,
loss_ins=dict(
type='DiceLoss',
use_sigmoid=True,
loss_weight=3.0),
loss_cate=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0)),
mask_feat_head=dict(
type='MaskFeatHead',
in_channels=256,
out_channels=128,
start_level=0,
end_level=3,
num_classes=256,
conv_cfg=dict(type='DCNv2'),
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
)
# training and testing settings
train_cfg = dict()
test_cfg = dict(
nms_pre=500,
score_thr=0.1,
mask_thr=0.5,
update_thr=0.05,
kernel='gaussian', # gaussian/linear
sigma=2.0,
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize',
img_scale=[(1333, 800), (1333, 768), (1333, 736),
(1333, 704), (1333, 672), (1333, 640)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.01,
step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solov2_release_x101_dcn_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
...@@ -12,6 +12,7 @@ from .retina_sepbn_head import RetinaSepBNHead ...@@ -12,6 +12,7 @@ from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead from .rpn_head import RPNHead
from .ssd_head import SSDHead from .ssd_head import SSDHead
from .solo_head import SOLOHead from .solo_head import SOLOHead
from .solov2_head import SOLOv2Head
from .decoupled_solo_head import DecoupledSOLOHead from .decoupled_solo_head import DecoupledSOLOHead
from .decoupled_solo_light_head import DecoupledSOLOLightHead from .decoupled_solo_light_head import DecoupledSOLOLightHead
...@@ -19,5 +20,5 @@ __all__ = [ ...@@ -19,5 +20,5 @@ __all__ = [
'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
'ATSSHead', 'SOLOHead', 'DecoupledSOLOHead', 'DecoupledSOLOLightHead' 'ATSSHead', 'SOLOHead', 'SOLOv2Head', 'DecoupledSOLOHead', 'DecoupledSOLOLightHead'
] ]
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
from scipy import ndimage
def points_nms(heat, kernel=2):
# kernel must be 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=1)
keep = (hmax[:, :, :-1, :-1] == heat).float()
return heat * keep
def dice_loss(input, target):
input = input.contiguous().view(input.size()[0], -1)
target = target.contiguous().view(target.size()[0], -1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + 0.001
c = torch.sum(target * target, 1) + 0.001
d = (2 * a) / (b + c)
return 1-d
@HEADS.register_module
class SOLOv2Head(nn.Module):
def __init__(self,
num_classes,
in_channels,
seg_feat_channels=256,
stacked_convs=4,
strides=(4, 8, 16, 32, 64),
base_edge_list=(16, 32, 64, 128, 256),
scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
sigma=0.2,
num_grids=None,
ins_out_channels=64,
loss_ins=None,
loss_cate=None,
conv_cfg=None,
norm_cfg=None,
use_dcn_in_tower=False,
type_dcn=None):
super(SOLOv2Head, self).__init__()
self.num_classes = num_classes
self.seg_num_grids = num_grids
self.cate_out_channels = self.num_classes - 1
self.ins_out_channels = ins_out_channels
self.in_channels = in_channels
self.seg_feat_channels = seg_feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.sigma = sigma
self.stacked_convs = stacked_convs
self.kernel_out_channels = self.ins_out_channels * 1 * 1
self.base_edge_list = base_edge_list
self.scale_ranges = scale_ranges
self.loss_cate = build_loss(loss_cate)
self.ins_loss_weight = loss_ins['loss_weight']
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.use_dcn_in_tower = use_dcn_in_tower
self.type_dcn = type_dcn
self._init_layers()
def _init_layers(self):
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
self.cate_convs = nn.ModuleList()
self.kernel_convs = nn.ModuleList()
for i in range(self.stacked_convs):
if self.use_dcn_in_tower:
cfg_conv = dict(type=self.type_dcn)
else:
cfg_conv = self.conv_cfg
chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
self.kernel_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
conv_cfg=cfg_conv,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
chn = self.in_channels if i == 0 else self.seg_feat_channels
self.cate_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
conv_cfg=cfg_conv,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
self.solo_cate = nn.Conv2d(
self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
self.solo_kernel = nn.Conv2d(
self.seg_feat_channels, self.kernel_out_channels, 3, padding=1)
def init_weights(self):
for m in self.cate_convs:
normal_init(m.conv, std=0.01)
for m in self.kernel_convs:
normal_init(m.conv, std=0.01)
bias_cate = bias_init_with_prob(0.01)
normal_init(self.solo_cate, std=0.01, bias=bias_cate)
normal_init(self.solo_kernel, std=0.01)
def forward(self, feats, eval=False):
new_feats = self.split_feats(feats)
featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
cate_pred, kernel_pred = multi_apply(self.forward_single, new_feats,
list(range(len(self.seg_num_grids))),
eval=eval, upsampled_size=upsampled_size)
return cate_pred, kernel_pred
def split_feats(self, feats):
return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
feats[1],
feats[2],
feats[3],
F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
def forward_single(self, x, idx, eval=False, upsampled_size=None):
ins_kernel_feat = x
# ins branch
# concat coord
x_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-1], device=ins_kernel_feat.device)
y_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-2], device=ins_kernel_feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([ins_kernel_feat.shape[0], 1, -1, -1])
x = x.expand([ins_kernel_feat.shape[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
ins_kernel_feat = torch.cat([ins_kernel_feat, coord_feat], 1)
# kernel branch
kernel_feat = ins_kernel_feat
seg_num_grid = self.seg_num_grids[idx]
kernel_feat = F.interpolate(kernel_feat, size=seg_num_grid, mode='bilinear')
cate_feat = kernel_feat[:, :-2, :, :]
kernel_feat = kernel_feat.contiguous()
for i, kernel_layer in enumerate(self.kernel_convs):
kernel_feat = kernel_layer(kernel_feat)
kernel_pred = self.solo_kernel(kernel_feat)
# cate branch
cate_feat = cate_feat.contiguous()
for i, cate_layer in enumerate(self.cate_convs):
cate_feat = cate_layer(cate_feat)
cate_pred = self.solo_cate(cate_feat)
if eval:
cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
return cate_pred, kernel_pred
def loss(self,
cate_preds,
kernel_preds,
ins_pred,
gt_bbox_list,
gt_label_list,
gt_mask_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
mask_feat_size = ins_pred.size()[-2:]
ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = multi_apply(
self.solov2_target_single,
gt_bbox_list,
gt_label_list,
gt_mask_list,
mask_feat_size=mask_feat_size)
# ins
ins_labels = [torch.cat([ins_labels_level_img
for ins_labels_level_img in ins_labels_level], 0)
for ins_labels_level in zip(*ins_label_list)]
kernel_preds = [[kernel_preds_level_img.view(kernel_preds_level_img.shape[0], -1)[:, grid_orders_level_img]
for kernel_preds_level_img, grid_orders_level_img in
zip(kernel_preds_level, grid_orders_level)]
for kernel_preds_level, grid_orders_level in zip(kernel_preds, zip(*grid_order_list))]
# generate masks
ins_pred = ins_pred
ins_pred_list = []
for b_kernel_pred in kernel_preds:
b_mask_pred = []
for idx, kernel_pred in enumerate(b_kernel_pred):
if kernel_pred.size()[-1] == 0:
continue
cur_ins_pred = ins_pred[idx, ...]
H, W = cur_ins_pred.shape[-2:]
N, I = kernel_pred.shape
cur_ins_pred = cur_ins_pred.unsqueeze(0)
kernel_pred = kernel_pred.permute(1, 0).view(I, -1, 1, 1)
cur_ins_pred = F.conv2d(cur_ins_pred, kernel_pred, stride=1).view(-1, H, W)
b_mask_pred.append(cur_ins_pred)
if len(b_mask_pred) == 0:
b_mask_pred = None
else:
b_mask_pred = torch.cat(b_mask_pred, 0)
ins_pred_list.append(b_mask_pred)
ins_ind_labels = [
torch.cat([ins_ind_labels_level_img.flatten()
for ins_ind_labels_level_img in ins_ind_labels_level])
for ins_ind_labels_level in zip(*ins_ind_label_list)
]
flatten_ins_ind_labels = torch.cat(ins_ind_labels)
num_ins = flatten_ins_ind_labels.sum()
# dice loss
loss_ins = []
for input, target in zip(ins_pred_list, ins_labels):
if input is None:
continue
input = torch.sigmoid(input)
loss_ins.append(dice_loss(input, target))
loss_ins = torch.cat(loss_ins).mean()
loss_ins = loss_ins * self.ins_loss_weight
# cate
cate_labels = [
torch.cat([cate_labels_level_img.flatten()
for cate_labels_level_img in cate_labels_level])
for cate_labels_level in zip(*cate_label_list)
]
flatten_cate_labels = torch.cat(cate_labels)
cate_preds = [
cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
for cate_pred in cate_preds
]
flatten_cate_preds = torch.cat(cate_preds)
loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
return dict(
loss_ins=loss_ins,
loss_cate=loss_cate)
def solov2_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
gt_masks_raw,
mask_feat_size):
device = gt_labels_raw[0].device
# ins
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
ins_label_list = []
cate_label_list = []
ins_ind_label_list = []
grid_order_list = []
for (lower_bound, upper_bound), stride, num_grid \
in zip(self.scale_ranges, self.strides, self.seg_num_grids):
hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()
num_ins = len(hit_indices)
ins_label = []
grid_order = []
cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)
if num_ins == 0:
ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
grid_order_list.append([])
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
output_stride = 4
for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
if seg_mask.sum() == 0:
continue
# mass center
upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
# left, top, right, down
top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))
top = max(top_box, coord_h-1)
down = min(down_box, coord_h+1)
left = max(coord_w-1, left_box)
right = min(right_box, coord_w+1)
cate_label[top:(down+1), left:(right+1)] = gt_label
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.Tensor(seg_mask)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
cur_ins_label = torch.zeros([mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8,
device=device)
cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
ins_label.append(cur_ins_label)
ins_ind_label[label] = True
grid_order.append(label)
ins_label = torch.stack(ins_label, 0)
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
grid_order_list.append(grid_order)
return ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list
def get_seg(self, cate_preds, kernel_preds, seg_pred, img_metas, cfg, rescale=None):
num_levels = len(cate_preds)
featmap_size = seg_pred.size()[-2:]
result_list = []
for img_id in range(len(img_metas)):
cate_pred_list = [
cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
]
seg_pred_list = seg_pred[img_id, ...].unsqueeze(0)
kernel_pred_list = [
kernel_preds[i][img_id].permute(1, 2, 0).view(-1, self.kernel_out_channels).detach()
for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
ori_shape = img_metas[img_id]['ori_shape']
cate_pred_list = torch.cat(cate_pred_list, dim=0)
kernel_pred_list = torch.cat(kernel_pred_list, dim=0)
result = self.get_seg_single(cate_pred_list, seg_pred_list, kernel_pred_list,
featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
result_list.append(result)
return result_list
def get_seg_single(self,
cate_preds,
seg_preds,
kernel_preds,
featmap_size,
img_shape,
ori_shape,
scale_factor,
cfg,
rescale=False, debug=False):
assert len(cate_preds) == len(kernel_preds)
# overall info.
h, w, _ = img_shape
upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
# process.
inds = (cate_preds > cfg.score_thr)
cate_scores = cate_preds[inds]
if len(cate_scores) == 0:
return None
# cate_labels & kernel_preds
inds = inds.nonzero()
cate_labels = inds[:, 1]
kernel_preds = kernel_preds[inds[:, 0]]
# trans vector.
size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
strides = kernel_preds.new_ones(size_trans[-1])
n_stage = len(self.seg_num_grids)
strides[:size_trans[0]] *= self.strides[0]
for ind_ in range(1, n_stage):
strides[size_trans[ind_-1]:size_trans[ind_]] *= self.strides[ind_]
strides = strides[inds[:, 0]]
# mask encoding.
I, N = kernel_preds.shape
kernel_preds = kernel_preds.view(I, N, 1, 1)
seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()
# mask.
seg_masks = seg_preds > 0.5
sum_masks = seg_masks.sum((1, 2)).float()
# filter.
keep = sum_masks > strides
if keep.sum() == 0:
return None
seg_masks = seg_masks[keep, ...]
seg_preds = seg_preds[keep, ...]
sum_masks = sum_masks[keep]
cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep]
# mask scoring.
seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_scores
# sort and keep top nms_pre
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.nms_pre:
sort_inds = sort_inds[:cfg.nms_pre]
seg_masks = seg_masks[sort_inds, :, :]
seg_preds = seg_preds[sort_inds, :, :]
sum_masks = sum_masks[sort_inds]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
# Matrix NMS
cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
kernel=cfg.kernel,sigma=cfg.sigma, sum_masks=sum_masks)
# filter.
keep = cate_scores >= cfg.update_thr
if keep.sum() == 0:
return None
seg_preds = seg_preds[keep, :, :]
cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep]
# sort and keep top_k
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.max_per_img:
sort_inds = sort_inds[:cfg.max_per_img]
seg_preds = seg_preds[sort_inds, :, :]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
seg_preds = F.interpolate(seg_preds.unsqueeze(0),
size=upsampled_size_out,
mode='bilinear')[:, :, :h, :w]
seg_masks = F.interpolate(seg_preds,
size=ori_shape[:2],
mode='bilinear').squeeze(0)
seg_masks = seg_masks > 0.5
return seg_masks, cate_labels, cate_scores
...@@ -17,10 +17,11 @@ from .single_stage import SingleStageDetector ...@@ -17,10 +17,11 @@ from .single_stage import SingleStageDetector
from .single_stage_ins import SingleStageInsDetector from .single_stage_ins import SingleStageInsDetector
from .two_stage import TwoStageDetector from .two_stage import TwoStageDetector
from .solo import SOLO from .solo import SOLO
from .solov2 import SOLOv2
__all__ = [ __all__ = [
'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN', 'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN',
'RepPointsDetector', 'FOVEA', 'SingleStageInsDetector', 'SOLO' 'RepPointsDetector', 'FOVEA', 'SingleStageInsDetector', 'SOLO', 'SOLOv2'
] ]
...@@ -20,6 +20,11 @@ class BaseDetector(nn.Module, metaclass=ABCMeta): ...@@ -20,6 +20,11 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
def with_neck(self): def with_neck(self):
return hasattr(self, 'neck') and self.neck is not None return hasattr(self, 'neck') and self.neck is not None
@property
def with_mask_feat_head(self):
return hasattr(self, 'mask_feat_head') and \
self.mask_feat_head is not None
@property @property
def with_shared_head(self): def with_shared_head(self):
return hasattr(self, 'shared_head') and self.shared_head is not None return hasattr(self, 'shared_head') and self.shared_head is not None
......
...@@ -13,6 +13,7 @@ class SingleStageInsDetector(BaseDetector): ...@@ -13,6 +13,7 @@ class SingleStageInsDetector(BaseDetector):
backbone, backbone,
neck=None, neck=None,
bbox_head=None, bbox_head=None,
mask_feat_head=None,
train_cfg=None, train_cfg=None,
test_cfg=None, test_cfg=None,
pretrained=None): pretrained=None):
...@@ -20,6 +21,9 @@ class SingleStageInsDetector(BaseDetector): ...@@ -20,6 +21,9 @@ class SingleStageInsDetector(BaseDetector):
self.backbone = builder.build_backbone(backbone) self.backbone = builder.build_backbone(backbone)
if neck is not None: if neck is not None:
self.neck = builder.build_neck(neck) self.neck = builder.build_neck(neck)
if mask_feat_head is not None:
self.mask_feat_head = builder.build_head(mask_feat_head)
self.bbox_head = builder.build_head(bbox_head) self.bbox_head = builder.build_head(bbox_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
...@@ -34,6 +38,12 @@ class SingleStageInsDetector(BaseDetector): ...@@ -34,6 +38,12 @@ class SingleStageInsDetector(BaseDetector):
m.init_weights() m.init_weights()
else: else:
self.neck.init_weights() self.neck.init_weights()
if self.with_mask_feat_head:
if isinstance(self.mask_feat_head, nn.Sequential):
for m in self.mask_feat_head:
m.init_weights()
else:
self.mask_feat_head.init_weights()
self.bbox_head.init_weights() self.bbox_head.init_weights()
def extract_feat(self, img): def extract_feat(self, img):
...@@ -56,7 +66,14 @@ class SingleStageInsDetector(BaseDetector): ...@@ -56,7 +66,14 @@ class SingleStageInsDetector(BaseDetector):
gt_masks=None): gt_masks=None):
x = self.extract_feat(img) x = self.extract_feat(img)
outs = self.bbox_head(x) outs = self.bbox_head(x)
loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
if self.with_mask_feat_head:
mask_feat_pred = self.mask_feat_head(
x[self.mask_feat_head.
start_level:self.mask_feat_head.end_level + 1])
loss_inputs = outs + (mask_feat_pred, gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
else:
loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
losses = self.bbox_head.loss( losses = self.bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses return losses
...@@ -64,7 +81,14 @@ class SingleStageInsDetector(BaseDetector): ...@@ -64,7 +81,14 @@ class SingleStageInsDetector(BaseDetector):
def simple_test(self, img, img_meta, rescale=False): def simple_test(self, img, img_meta, rescale=False):
x = self.extract_feat(img) x = self.extract_feat(img)
outs = self.bbox_head(x, eval=True) outs = self.bbox_head(x, eval=True)
seg_inputs = outs + (img_meta, self.test_cfg, rescale)
if self.with_mask_feat_head:
mask_feat_pred = self.mask_feat_head(
x[self.mask_feat_head.
start_level:self.mask_feat_head.end_level + 1])
seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg, rescale)
else:
seg_inputs = outs + (img_meta, self.test_cfg, rescale)
seg_result = self.bbox_head.get_seg(*seg_inputs) seg_result = self.bbox_head.get_seg(*seg_inputs)
return seg_result return seg_result
......
...@@ -12,5 +12,5 @@ class SOLO(SingleStageInsDetector): ...@@ -12,5 +12,5 @@ class SOLO(SingleStageInsDetector):
train_cfg=None, train_cfg=None,
test_cfg=None, test_cfg=None,
pretrained=None): pretrained=None):
super(SOLO, self).__init__(backbone, neck, bbox_head, train_cfg, super(SOLO, self).__init__(backbone, neck, bbox_head, None, train_cfg,
test_cfg, pretrained) test_cfg, pretrained)
from .single_stage_ins import SingleStageInsDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class SOLOv2(SingleStageInsDetector):
def __init__(self,
backbone,
neck,
bbox_head,
mask_feat_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(SOLOv2, self).__init__(backbone, neck, bbox_head, mask_feat_head, train_cfg,
test_cfg, pretrained)
...@@ -3,8 +3,9 @@ from .fused_semantic_head import FusedSemanticHead ...@@ -3,8 +3,9 @@ from .fused_semantic_head import FusedSemanticHead
from .grid_head import GridHead from .grid_head import GridHead
from .htc_mask_head import HTCMaskHead from .htc_mask_head import HTCMaskHead
from .maskiou_head import MaskIoUHead from .maskiou_head import MaskIoUHead
from .mask_feat_head import MaskFeatHead
__all__ = [ __all__ = [
'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead', 'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead',
'MaskIoUHead' 'MaskIoUHead', 'MaskFeatHead'
] ]
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, normal_init
from ..registry import HEADS
from ..builder import build_loss
from ..utils import ConvModule
import torch
import numpy as np
@HEADS.register_module
class MaskFeatHead(nn.Module):
def __init__(self,
in_channels,
out_channels,
start_level,
end_level,
num_classes,
conv_cfg=None,
norm_cfg=None):
super(MaskFeatHead, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.start_level = start_level
self.end_level = end_level
assert start_level >= 0 and end_level >= start_level
self.num_classes = num_classes
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.convs_all_levels = nn.ModuleList()
for i in range(self.start_level, self.end_level + 1):
convs_per_level = nn.Sequential()
if i == 0:
one_conv = ConvModule(
self.in_channels,
self.out_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=False)
convs_per_level.add_module('conv' + str(i), one_conv)
self.convs_all_levels.append(convs_per_level)
continue
for j in range(i):
if j == 0:
chn = self.in_channels+2 if i==3 else self.in_channels
one_conv = ConvModule(
chn,
self.out_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=False)
convs_per_level.add_module('conv' + str(j), one_conv)
one_upsample = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=False)
convs_per_level.add_module(
'upsample' + str(j), one_upsample)
continue
one_conv = ConvModule(
self.out_channels,
self.out_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=False)
convs_per_level.add_module('conv' + str(j), one_conv)
one_upsample = nn.Upsample(
scale_factor=2,
mode='bilinear',
align_corners=False)
convs_per_level.add_module('upsample' + str(j), one_upsample)
self.convs_all_levels.append(convs_per_level)
self.conv_pred = nn.Sequential(
ConvModule(
self.out_channels,
self.num_classes,
1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg),
)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, std=0.01)
def forward(self, inputs):
assert len(inputs) == (self.end_level - self.start_level + 1)
feature_add_all_level = self.convs_all_levels[0](inputs[0])
for i in range(1, len(inputs)):
input_p = inputs[i]
if i == 3:
input_feat = input_p
x_range = torch.linspace(-1, 1, input_feat.shape[-1], device=input_feat.device)
y_range = torch.linspace(-1, 1, input_feat.shape[-2], device=input_feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([input_feat.shape[0], 1, -1, -1])
x = x.expand([input_feat.shape[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
input_p = torch.cat([input_p, coord_feat], 1)
feature_add_all_level += self.convs_all_levels[i](input_p)
feature_pred = self.conv_pred(feature_add_all_level)
return feature_pred
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment