ModelZoo / SOLOv2-pytorch · Commit 451933f7
Authored May 04, 2020 by WXinlong
Parent: d5398a0d

    add solov2

Showing 14 changed files with 1341 additions and 6 deletions (+1341 / -6)
configs/solov2/solov2_r101_dcn_fpn_8gpu_3x.py    +145  -0
configs/solov2/solov2_r101_fpn_8gpu_3x.py        +137  -0
configs/solov2/solov2_r50_fpn_8gpu_1x.py         +133  -0
configs/solov2/solov2_r50_fpn_8gpu_3x.py         +137  -0
configs/solov2/solov2_x101_dcn_fpn_8gpu_3x.py    +147  -0
mmdet/models/anchor_heads/__init__.py              +2  -1
mmdet/models/anchor_heads/solov2_head.py         +468  -0
mmdet/models/detectors/__init__.py                 +2  -1
mmdet/models/detectors/base.py                     +5  -0
mmdet/models/detectors/single_stage_ins.py        +26  -2
mmdet/models/detectors/solo.py                     +1  -1
mmdet/models/detectors/solov2.py                  +17  -0
mmdet/models/mask_heads/__init__.py                +2  -1
mmdet/models/mask_heads/mask_feat_head.py        +119  -0
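The five configs below are standard mmdetection v1.x flat Python configs. As a quick sanity check they can be parsed with mmcv before training; a minimal sketch, assuming mmcv is installed and the command runs from the repository root:

import mmcv

# Parse the plain-Python config into a dict-like Config object.
cfg = mmcv.Config.fromfile('configs/solov2/solov2_r50_fpn_8gpu_3x.py')

# Spot-check a few of the settings added in this commit.
print(cfg.model.type)            # 'SOLOv2'
print(cfg.model.mask_feat_head)  # the new MaskFeatHead block
print(cfg.total_epochs)          # 36 for the 3x schedule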
configs/solov2/solov2_r101_dcn_fpn_8gpu_3x.py  (new file, mode 100644)

# model settings
model = dict(
    type='SOLOv2',
    pretrained='torchvision://resnet101',
    backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch',
        dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True)),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='SOLOv2Head',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        use_dcn_in_tower=True,
        type_dcn='DCNv2',
        seg_feat_channels=512,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        ins_out_channels=256,
        loss_ins=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0)),
    mask_feat_head=dict(
        type='MaskFeatHead',
        in_channels=256,
        out_channels=128,
        start_level=0,
        end_level=3,
        num_classes=256,
        conv_cfg=dict(type='DCNv2'),
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
    )
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=[(1333, 800), (1333, 768), (1333, 736), (1333, 704),
                   (1333, 672), (1333, 640)],
        multiscale_mode='value',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.01,
    step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solov2_release_r101_dcn_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
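The optimizer block assumes the 8-GPU x 2-images-per-GPU setup in the file names (total batch size 16, lr 0.01). If you train with a different world size, a common mmdetection heuristic, an assumption here rather than something this commit states, is to scale the learning rate linearly with the total batch size:

# Hypothetical helper: linear LR scaling for a different GPU count.
def scaled_lr(base_lr=0.01, base_batch=16, num_gpus=4, imgs_per_gpu=2):
    return base_lr * (num_gpus * imgs_per_gpu) / base_batch

print(scaled_lr())  # 0.005 for 4 GPUs x 2 imgs/gpu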
configs/solov2/solov2_r101_fpn_8gpu_3x.py  (new file, mode 100644)

# model settings
model = dict(
    type='SOLOv2',
    pretrained='torchvision://resnet101',
    backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='SOLOv2Head',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        seg_feat_channels=512,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        ins_out_channels=256,
        loss_ins=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0)),
    mask_feat_head=dict(
        type='MaskFeatHead',
        in_channels=256,
        out_channels=128,
        start_level=0,
        end_level=3,
        num_classes=256,
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
    )
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=[(1333, 800), (1333, 768), (1333, 736), (1333, 704),
                   (1333, 672), (1333, 640)],
        multiscale_mode='value',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.01,
    step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solov2_release_r101_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
configs/solov2/solov2_r50_fpn_8gpu_1x.py  (new file, mode 100644)

# model settings
model = dict(
    type='SOLOv2',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='SOLOv2Head',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        seg_feat_channels=512,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        ins_out_channels=256,
        loss_ins=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0)),
    mask_feat_head=dict(
        type='MaskFeatHead',
        in_channels=256,
        out_channels=128,
        start_level=0,
        end_level=3,
        num_classes=256,
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
    )
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.01,
    step=[9, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solov2_release_r50_fpn_8gpu_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
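The 1x schedule above differs from the 3x configs only in single-scale training, step=[9, 11], and total_epochs = 12. All schedules share the 500-iteration linear warmup; the ramp that mmcv's LrUpdaterHook applies is roughly the following (a sketch of the mmcv 1.x behaviour as I understand it, worth verifying against your installed version):

def linear_warmup_lr(regular_lr, cur_iter, warmup_iters=500, warmup_ratio=0.01):
    # LR climbs from warmup_ratio * regular_lr to regular_lr over warmup_iters.
    k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
    return regular_lr * (1 - k)

print(linear_warmup_lr(0.01, 0))    # 0.0001 at the first iteration
print(linear_warmup_lr(0.01, 500))  # 0.01 once warmup finishes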
configs/solov2/solov2_r50_fpn_8gpu_3x.py  (new file, mode 100644)

# model settings
model = dict(
    type='SOLOv2',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='SOLOv2Head',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        seg_feat_channels=512,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        ins_out_channels=256,
        loss_ins=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0)),
    mask_feat_head=dict(
        type='MaskFeatHead',
        in_channels=256,
        out_channels=128,
        start_level=0,
        end_level=3,
        num_classes=256,
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
    )
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=[(1333, 800), (1333, 768), (1333, 736), (1333, 704),
                   (1333, 672), (1333, 640)],
        multiscale_mode='value',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.01,
    step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solov2_release_r50_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
configs/solov2/solov2_x101_dcn_fpn_8gpu_3x.py  (new file, mode 100644)

# model settings
model = dict(
    type='SOLOv2',
    pretrained='open-mmlab://resnext101_64x4d',
    backbone=dict(
        type='ResNeXt',
        depth=101,
        groups=64,
        base_width=4,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch',
        dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True)),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='SOLOv2Head',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        use_dcn_in_tower=True,
        type_dcn='DCNv2',
        seg_feat_channels=512,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        ins_out_channels=256,
        loss_ins=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0)),
    mask_feat_head=dict(
        type='MaskFeatHead',
        in_channels=256,
        out_channels=128,
        start_level=0,
        end_level=3,
        num_classes=256,
        conv_cfg=dict(type='DCNv2'),
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
    )
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=[(1333, 800), (1333, 768), (1333, 736), (1333, 704),
                   (1333, 672), (1333, 640)],
        multiscale_mode='value',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.01,
    step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solov2_release_x101_dcn_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
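Any of the five configs can be turned into a live model with the mmdetection v1.x builders; a minimal sketch, assuming this repository's mmdet package is importable:

from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('configs/solov2/solov2_x101_dcn_fpn_8gpu_3x.py')
# train_cfg/test_cfg are module-level variables in these configs.
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
model.init_weights()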
mmdet/models/anchor_heads/__init__.py

...
@@ -12,6 +12,7 @@ from .retina_sepbn_head import RetinaSepBNHead
 from .rpn_head import RPNHead
 from .ssd_head import SSDHead
 from .solo_head import SOLOHead
+from .solov2_head import SOLOv2Head
 from .decoupled_solo_head import DecoupledSOLOHead
 from .decoupled_solo_light_head import DecoupledSOLOLightHead
...
@@ -19,5 +20,5 @@ __all__ = [
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
     'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
     'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
-    'ATSSHead', 'SOLOHead', 'DecoupledSOLOHead', 'DecoupledSOLOLightHead'
+    'ATSSHead', 'SOLOHead', 'SOLOv2Head', 'DecoupledSOLOHead', 'DecoupledSOLOLightHead'
 ]
mmdet/models/anchor_heads/solov2_head.py  (new file, mode 100644)

import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule

INF = 1e8

from scipy import ndimage


def points_nms(heat, kernel=2):
    # kernel must be 2
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=1)
    keep = (hmax[:, :, :-1, :-1] == heat).float()
    return heat * keep


def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return 1 - d


@HEADS.register_module
class SOLOv2Head(nn.Module):

    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256),
                               (128, 512)),
                 sigma=0.2,
                 num_grids=None,
                 ins_out_channels=64,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 use_dcn_in_tower=False,
                 type_dcn=None):
        super(SOLOv2Head, self).__init__()
        self.num_classes = num_classes
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes - 1
        self.ins_out_channels = ins_out_channels
        self.in_channels = in_channels
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.sigma = sigma
        self.kernel_out_channels = self.ins_out_channels * 1 * 1
        self.base_edge_list = base_edge_list
        self.scale_ranges = scale_ranges
        self.loss_cate = build_loss(loss_cate)
        self.ins_loss_weight = loss_ins['loss_weight']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.use_dcn_in_tower = use_dcn_in_tower
        self.type_dcn = type_dcn
        self._init_layers()

    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.cate_convs = nn.ModuleList()
        self.kernel_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            if self.use_dcn_in_tower:
                cfg_conv = dict(type=self.type_dcn)
            else:
                cfg_conv = self.conv_cfg

            chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            self.kernel_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=cfg_conv,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

            chn = self.in_channels if i == 0 else self.seg_feat_channels
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=cfg_conv,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

        self.solo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)

        self.solo_kernel = nn.Conv2d(
            self.seg_feat_channels, self.kernel_out_channels, 3, padding=1)

    def init_weights(self):
        for m in self.cate_convs:
            normal_init(m.conv, std=0.01)
        for m in self.kernel_convs:
            normal_init(m.conv, std=0.01)
        bias_cate = bias_init_with_prob(0.01)
        normal_init(self.solo_cate, std=0.01, bias=bias_cate)
        normal_init(self.solo_kernel, std=0.01)

    def forward(self, feats, eval=False):
        new_feats = self.split_feats(feats)
        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
        cate_pred, kernel_pred = multi_apply(
            self.forward_single, new_feats,
            list(range(len(self.seg_num_grids))),
            eval=eval,
            upsampled_size=upsampled_size)
        return cate_pred, kernel_pred

    def split_feats(self, feats):
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
                feats[1],
                feats[2],
                feats[3],
                F.interpolate(
                    feats[4], size=feats[3].shape[-2:], mode='bilinear'))

    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        ins_kernel_feat = x
        # ins branch
        # concat coord
        x_range = torch.linspace(
            -1, 1, ins_kernel_feat.shape[-1], device=ins_kernel_feat.device)
        y_range = torch.linspace(
            -1, 1, ins_kernel_feat.shape[-2], device=ins_kernel_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_kernel_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_kernel_feat.shape[0], 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        ins_kernel_feat = torch.cat([ins_kernel_feat, coord_feat], 1)

        # kernel branch
        kernel_feat = ins_kernel_feat
        seg_num_grid = self.seg_num_grids[idx]
        kernel_feat = F.interpolate(
            kernel_feat, size=seg_num_grid, mode='bilinear')

        cate_feat = kernel_feat[:, :-2, :, :]

        kernel_feat = kernel_feat.contiguous()
        for i, kernel_layer in enumerate(self.kernel_convs):
            kernel_feat = kernel_layer(kernel_feat)
        kernel_pred = self.solo_kernel(kernel_feat)

        # cate branch
        cate_feat = cate_feat.contiguous()
        for i, cate_layer in enumerate(self.cate_convs):
            cate_feat = cate_layer(cate_feat)
        cate_pred = self.solo_cate(cate_feat)

        if eval:
            cate_pred = points_nms(
                cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return cate_pred, kernel_pred

    def loss(self,
             cate_preds,
             kernel_preds,
             ins_pred,
             gt_bbox_list,
             gt_label_list,
             gt_mask_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        mask_feat_size = ins_pred.size()[-2:]
        ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = multi_apply(
            self.solov2_target_single,
            gt_bbox_list,
            gt_label_list,
            gt_mask_list,
            mask_feat_size=mask_feat_size)

        # ins
        ins_labels = [
            torch.cat([ins_labels_level_img
                       for ins_labels_level_img in ins_labels_level], 0)
            for ins_labels_level in zip(*ins_label_list)
        ]

        kernel_preds = [
            [kernel_preds_level_img.view(
                kernel_preds_level_img.shape[0], -1)[:, grid_orders_level_img]
             for kernel_preds_level_img, grid_orders_level_img
             in zip(kernel_preds_level, grid_orders_level)]
            for kernel_preds_level, grid_orders_level
            in zip(kernel_preds, zip(*grid_order_list))
        ]
        # generate masks
        ins_pred_list = []
        for b_kernel_pred in kernel_preds:
            b_mask_pred = []
            for idx, kernel_pred in enumerate(b_kernel_pred):
                if kernel_pred.size()[-1] == 0:
                    continue
                cur_ins_pred = ins_pred[idx, ...]
                H, W = cur_ins_pred.shape[-2:]
                N, I = kernel_pred.shape
                cur_ins_pred = cur_ins_pred.unsqueeze(0)
                kernel_pred = kernel_pred.permute(1, 0).view(I, -1, 1, 1)
                cur_ins_pred = F.conv2d(
                    cur_ins_pred, kernel_pred, stride=1).view(-1, H, W)
                b_mask_pred.append(cur_ins_pred)
            if len(b_mask_pred) == 0:
                b_mask_pred = None
            else:
                b_mask_pred = torch.cat(b_mask_pred, 0)
            ins_pred_list.append(b_mask_pred)

        ins_ind_labels = [
            torch.cat([ins_ind_labels_level_img.flatten()
                       for ins_ind_labels_level_img in ins_ind_labels_level])
            for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        flatten_ins_ind_labels = torch.cat(ins_ind_labels)

        num_ins = flatten_ins_ind_labels.sum()

        # dice loss
        loss_ins = []
        for input, target in zip(ins_pred_list, ins_labels):
            if input is None:
                continue
            input = torch.sigmoid(input)
            loss_ins.append(dice_loss(input, target))
        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.ins_loss_weight

        # cate
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)

        loss_cate = self.loss_cate(
            flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
        return dict(loss_ins=loss_ins, loss_cate=loss_cate)

    def solov2_target_single(self,
                             gt_bboxes_raw,
                             gt_labels_raw,
                             gt_masks_raw,
                             mask_feat_size):

        device = gt_labels_raw[0].device

        # ins
        gt_areas = torch.sqrt(
            (gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
            (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        grid_order_list = []
        for (lower_bound, upper_bound), stride, num_grid \
                in zip(self.scale_ranges, self.strides, self.seg_num_grids):

            hit_indices = ((gt_areas >= lower_bound) &
                           (gt_areas <= upper_bound)).nonzero().flatten()
            num_ins = len(hit_indices)

            ins_label = []
            grid_order = []
            cate_label = torch.zeros(
                [num_grid, num_grid], dtype=torch.int64, device=device)
            ins_ind_label = torch.zeros(
                [num_grid ** 2], dtype=torch.bool, device=device)

            if num_ins == 0:
                ins_label = torch.zeros(
                    [0, mask_feat_size[0], mask_feat_size[1]],
                    dtype=torch.uint8, device=device)
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                grid_order_list.append([])
                continue
            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]

            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

            output_stride = 4
            for seg_mask, gt_label, half_h, half_w in zip(
                    gt_masks, gt_labels, half_hs, half_ws):
                if seg_mask.sum() == 0:
                    continue
                # mass center
                upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
                center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down
                top_box = max(0, int(((center_h - half_h) /
                                      upsampled_size[0]) // (1. / num_grid)))
                down_box = min(num_grid - 1,
                               int(((center_h + half_h) /
                                    upsampled_size[0]) // (1. / num_grid)))
                left_box = max(0, int(((center_w - half_w) /
                                       upsampled_size[1]) // (1. / num_grid)))
                right_box = min(num_grid - 1,
                                int(((center_w + half_w) /
                                     upsampled_size[1]) // (1. / num_grid)))

                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                cate_label[top:(down + 1), left:(right + 1)] = gt_label
                seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
                seg_mask = torch.Tensor(seg_mask)
                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        label = int(i * num_grid + j)

                        cur_ins_label = torch.zeros(
                            [mask_feat_size[0], mask_feat_size[1]],
                            dtype=torch.uint8, device=device)
                        cur_ins_label[:seg_mask.shape[0],
                                      :seg_mask.shape[1]] = seg_mask
                        ins_label.append(cur_ins_label)
                        ins_ind_label[label] = True
                        grid_order.append(label)
            ins_label = torch.stack(ins_label, 0)

            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)
            grid_order_list.append(grid_order)
        return ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list

    def get_seg(self, cate_preds, kernel_preds, seg_pred, img_metas, cfg,
                rescale=None):
        num_levels = len(cate_preds)
        featmap_size = seg_pred.size()[-2:]

        result_list = []
        for img_id in range(len(img_metas)):
            cate_pred_list = [
                cate_preds[i][img_id].view(
                    -1, self.cate_out_channels).detach()
                for i in range(num_levels)
            ]
            seg_pred_list = seg_pred[img_id, ...].unsqueeze(0)
            kernel_pred_list = [
                kernel_preds[i][img_id].permute(1, 2, 0).view(
                    -1, self.kernel_out_channels).detach()
                for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            ori_shape = img_metas[img_id]['ori_shape']

            cate_pred_list = torch.cat(cate_pred_list, dim=0)
            kernel_pred_list = torch.cat(kernel_pred_list, dim=0)

            result = self.get_seg_single(
                cate_pred_list, seg_pred_list, kernel_pred_list,
                featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
            result_list.append(result)
        return result_list

    def get_seg_single(self,
                       cate_preds,
                       seg_preds,
                       kernel_preds,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False,
                       debug=False):

        assert len(cate_preds) == len(kernel_preds)

        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        # process.
        inds = (cate_preds > cfg.score_thr)
        cate_scores = cate_preds[inds]
        if len(cate_scores) == 0:
            return None

        # cate_labels & kernel_preds
        inds = inds.nonzero()
        cate_labels = inds[:, 1]
        kernel_preds = kernel_preds[inds[:, 0]]

        # trans vector.
        size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
        strides = kernel_preds.new_ones(size_trans[-1])

        n_stage = len(self.seg_num_grids)
        strides[:size_trans[0]] *= self.strides[0]
        for ind_ in range(1, n_stage):
            strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.strides[ind_]
        strides = strides[inds[:, 0]]

        # mask encoding.
        I, N = kernel_preds.shape
        kernel_preds = kernel_preds.view(I, N, 1, 1)
        seg_preds = F.conv2d(
            seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()

        # mask.
        seg_masks = seg_preds > 0.5
        sum_masks = seg_masks.sum((1, 2)).float()

        # filter.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return None

        seg_masks = seg_masks[keep, ...]
        seg_preds = seg_preds[keep, ...]
        sum_masks = sum_masks[keep]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # mask scoring.
        seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_scores

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(
            seg_masks, cate_labels, cate_scores,
            kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)

        # filter.
        keep = cate_scores >= cfg.update_thr
        if keep.sum() == 0:
            return None
        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        seg_preds = F.interpolate(
            seg_preds.unsqueeze(0),
            size=upsampled_size_out,
            mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(
            seg_preds, size=ori_shape[:2], mode='bilinear').squeeze(0)
        seg_masks = seg_masks > 0.5
        return seg_masks, cate_labels, cate_scores
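The core of SOLOv2Head is that solo_kernel predicts, for every grid cell, the weights of a 1x1 convolution that is then applied to the shared mask feature map (the F.conv2d calls in loss and get_seg_single above). A self-contained sketch of that dynamic-convolution step, with made-up shapes:

import torch
import torch.nn.functional as F

E = 256                               # ins_out_channels: kernel/feature depth
feat = torch.randn(1, E, 100, 152)    # shared mask features for one image
kernels = torch.randn(7, E)           # 7 predicted kernels (one per kept grid cell)

# Each row becomes the weight of a 1x1 conv: shape (out=7, in=E, 1, 1).
weight = kernels.view(7, E, 1, 1)
masks = F.conv2d(feat, weight, stride=1).squeeze(0).sigmoid()
print(masks.shape)  # torch.Size([7, 100, 152]) -- one soft mask per kernel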
mmdet/models/detectors/__init__.py

...
@@ -17,10 +17,11 @@ from .single_stage import SingleStageDetector
 from .single_stage_ins import SingleStageInsDetector
 from .two_stage import TwoStageDetector
 from .solo import SOLO
+from .solov2 import SOLOv2

 __all__ = [
     'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
     'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
     'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN',
-    'RepPointsDetector', 'FOVEA', 'SingleStageInsDetector', 'SOLO'
+    'RepPointsDetector', 'FOVEA', 'SingleStageInsDetector', 'SOLO', 'SOLOv2'
 ]
mmdet/models/detectors/base.py

...
@@ -20,6 +20,11 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
     def with_neck(self):
         return hasattr(self, 'neck') and self.neck is not None

+    @property
+    def with_mask_feat_head(self):
+        return hasattr(self, 'mask_feat_head') and \
+            self.mask_feat_head is not None
+
     @property
     def with_shared_head(self):
         return hasattr(self, 'shared_head') and self.shared_head is not None
...
mmdet/models/detectors/single_stage_ins.py

...
@@ -13,6 +13,7 @@ class SingleStageInsDetector(BaseDetector):
                  backbone,
                  neck=None,
                  bbox_head=None,
+                 mask_feat_head=None,
                  train_cfg=None,
                  test_cfg=None,
                  pretrained=None):
...
@@ -20,6 +21,9 @@ class SingleStageInsDetector(BaseDetector):
         self.backbone = builder.build_backbone(backbone)
         if neck is not None:
             self.neck = builder.build_neck(neck)
+        if mask_feat_head is not None:
+            self.mask_feat_head = builder.build_head(mask_feat_head)
+
         self.bbox_head = builder.build_head(bbox_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
...
@@ -34,6 +38,12 @@ class SingleStageInsDetector(BaseDetector):
                 m.init_weights()
             else:
                 self.neck.init_weights()
+        if self.with_mask_feat_head:
+            if isinstance(self.mask_feat_head, nn.Sequential):
+                for m in self.mask_feat_head:
+                    m.init_weights()
+            else:
+                self.mask_feat_head.init_weights()
         self.bbox_head.init_weights()

     def extract_feat(self, img):
...
@@ -56,7 +66,14 @@ class SingleStageInsDetector(BaseDetector):
                       gt_masks=None):
         x = self.extract_feat(img)
         outs = self.bbox_head(x)
-        loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
+        if self.with_mask_feat_head:
+            mask_feat_pred = self.mask_feat_head(
+                x[self.mask_feat_head.start_level:self.mask_feat_head.end_level + 1])
+            loss_inputs = outs + (mask_feat_pred, gt_bboxes, gt_labels,
+                                  gt_masks, img_metas, self.train_cfg)
+        else:
+            loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks, img_metas,
+                                  self.train_cfg)
         losses = self.bbox_head.loss(
             *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
         return losses
...
@@ -64,7 +81,14 @@ class SingleStageInsDetector(BaseDetector):
     def simple_test(self, img, img_meta, rescale=False):
         x = self.extract_feat(img)
         outs = self.bbox_head(x, eval=True)
-        seg_inputs = outs + (img_meta, self.test_cfg, rescale)
+        if self.with_mask_feat_head:
+            mask_feat_pred = self.mask_feat_head(
+                x[self.mask_feat_head.start_level:self.mask_feat_head.end_level + 1])
+            seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg,
+                                 rescale)
+        else:
+            seg_inputs = outs + (img_meta, self.test_cfg, rescale)
         seg_result = self.bbox_head.get_seg(*seg_inputs)
         return seg_result
...
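Both forward paths slice the FPN outputs with x[start_level:end_level + 1] before calling the mask feat head; with the configs above (start_level=0, end_level=3, num_outs=5) that hands the first four pyramid levels to MaskFeatHead and leaves the coarsest one out. A toy illustration of the slice, with hypothetical level names standing in for the real tensors:

# Toy stand-in for the 5 FPN outputs; only the slicing logic is real.
fpn_outs = ['P2', 'P3', 'P4', 'P5', 'P6']
start_level, end_level = 0, 3               # values from the configs above
print(fpn_outs[start_level:end_level + 1])  # ['P2', 'P3', 'P4', 'P5']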
mmdet/models/detectors/solo.py

...
@@ -12,5 +12,5 @@ class SOLO(SingleStageInsDetector):
                  train_cfg=None,
                  test_cfg=None,
                  pretrained=None):
-        super(SOLO, self).__init__(backbone, neck, bbox_head, train_cfg,
+        super(SOLO, self).__init__(backbone, neck, bbox_head, None, train_cfg,
                                    test_cfg, pretrained)
mmdet/models/detectors/solov2.py  (new file, mode 100644)

from .single_stage_ins import SingleStageInsDetector
from ..registry import DETECTORS


@DETECTORS.register_module
class SOLOv2(SingleStageInsDetector):

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 mask_feat_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SOLOv2, self).__init__(backbone, neck, bbox_head,
                                     mask_feat_head, train_cfg, test_cfg,
                                     pretrained)
mmdet/models/mask_heads/__init__.py

...
@@ -3,8 +3,9 @@ from .fused_semantic_head import FusedSemanticHead
 from .grid_head import GridHead
 from .htc_mask_head import HTCMaskHead
 from .maskiou_head import MaskIoUHead
+from .mask_feat_head import MaskFeatHead

 __all__ = [
     'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead',
-    'MaskIoUHead'
+    'MaskIoUHead', 'MaskFeatHead'
 ]
mmdet/models/mask_heads/mask_feat_head.py  (new file, mode 100644)

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, normal_init

from ..registry import HEADS
from ..builder import build_loss
from ..utils import ConvModule
import torch
import numpy as np


@HEADS.register_module
class MaskFeatHead(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 start_level,
                 end_level,
                 num_classes,
                 conv_cfg=None,
                 norm_cfg=None):
        super(MaskFeatHead, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.start_level = start_level
        self.end_level = end_level
        assert start_level >= 0 and end_level >= start_level
        self.num_classes = num_classes
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.convs_all_levels = nn.ModuleList()
        for i in range(self.start_level, self.end_level + 1):
            convs_per_level = nn.Sequential()
            if i == 0:
                one_conv = ConvModule(
                    self.in_channels,
                    self.out_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=False)
                convs_per_level.add_module('conv' + str(i), one_conv)
                self.convs_all_levels.append(convs_per_level)
                continue

            for j in range(i):
                if j == 0:
                    chn = self.in_channels + 2 if i == 3 else self.in_channels
                    one_conv = ConvModule(
                        chn,
                        self.out_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        inplace=False)
                    convs_per_level.add_module('conv' + str(j), one_conv)
                    one_upsample = nn.Upsample(
                        scale_factor=2, mode='bilinear', align_corners=False)
                    convs_per_level.add_module(
                        'upsample' + str(j), one_upsample)
                    continue

                one_conv = ConvModule(
                    self.out_channels,
                    self.out_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=False)
                convs_per_level.add_module('conv' + str(j), one_conv)
                one_upsample = nn.Upsample(
                    scale_factor=2, mode='bilinear', align_corners=False)
                convs_per_level.add_module('upsample' + str(j), one_upsample)

            self.convs_all_levels.append(convs_per_level)

        self.conv_pred = nn.Sequential(
            ConvModule(
                self.out_channels,
                self.num_classes,
                1,
                padding=0,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg),
        )

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.01)

    def forward(self, inputs):
        assert len(inputs) == (self.end_level - self.start_level + 1)

        feature_add_all_level = self.convs_all_levels[0](inputs[0])
        for i in range(1, len(inputs)):
            input_p = inputs[i]
            if i == 3:
                input_feat = input_p
                x_range = torch.linspace(
                    -1, 1, input_feat.shape[-1], device=input_feat.device)
                y_range = torch.linspace(
                    -1, 1, input_feat.shape[-2], device=input_feat.device)
                y, x = torch.meshgrid(y_range, x_range)
                y = y.expand([input_feat.shape[0], 1, -1, -1])
                x = x.expand([input_feat.shape[0], 1, -1, -1])
                coord_feat = torch.cat([x, y], 1)
                input_p = torch.cat([input_p, coord_feat], 1)

            feature_add_all_level += self.convs_all_levels[i](input_p)

        feature_pred = self.conv_pred(feature_add_all_level)
        return feature_pred
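At the deepest input level (i == 3) the head concatenates two normalized coordinate channels before convolving, the same CoordConv-style trick used in SOLOv2Head.forward_single. A standalone sketch of how those channels are built, with a toy feature map:

import torch

feat = torch.randn(2, 256, 25, 38)  # toy deepest-level feature map (N, C, H, W)

# Normalized x/y coordinates in [-1, 1], broadcast over the batch.
x_range = torch.linspace(-1, 1, feat.shape[-1], device=feat.device)
y_range = torch.linspace(-1, 1, feat.shape[-2], device=feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([feat.shape[0], 1, -1, -1])
x = x.expand([feat.shape[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)

feat_with_coords = torch.cat([feat, coord_feat], 1)
print(feat_with_coords.shape)  # torch.Size([2, 258, 25, 38])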