ModelZoo / SOLOv2-pytorch · Commit 1a2ae138
Authored Mar 31, 2020 by WXinlong
Parent commit: 28b3c6ea

add two light-weight models

Showing 4 changed files with 735 additions and 1 deletion (+735, -1)
configs/solo/decoupled_solo_light_dcn_r50_fpn_8gpu_3x.py    +136  -0
configs/solo/decoupled_solo_light_r50_fpn_8gpu_3x.py        +129  -0
mmdet/models/anchor_heads/__init__.py                       +2    -1
mmdet/models/anchor_heads/decoupled_solo_light_head.py      +468  -0
configs/solo/decoupled_solo_light_dcn_r50_fpn_8gpu_3x.py (new file, mode 100644)
# model settings
model = dict(
    type='SOLO',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch',
        dcn=dict(
            type='DCN',
            deformable_groups=1,
            fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True)),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='DecoupledSOLOLightHead',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        use_dcn_in_tower=True,
        type_dcn='DCN',
        seg_feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cate_down_pos=0,
        loss_ins=dict(
            type='DiceLoss',
            use_sigmoid=True,
            loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
    ))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize',
         img_scale=[(852, 512), (852, 480), (852, 448),
                    (852, 416), (852, 384), (852, 352)],
         multiscale_mode='value',
         keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(852, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/decoupled_solo_light_dcn_release_r50_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
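
The file above is a standard mmdetection-1.x style config: model definition, train/test settings, COCO data pipelines, optimizer and schedule. As a quick sanity check it can be loaded and turned into a model object with mmcv's Config and mmdet's build_detector. The snippet below is a minimal sketch assuming the mmdetection 1.x API this repository builds on; it is not part of the commit.

from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('configs/solo/decoupled_solo_light_dcn_r50_fpn_8gpu_3x.py')
# wires together the ResNet-50 (with DCN) backbone, the FPN neck and the new DecoupledSOLOLightHead
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
print(model)  # inspect the assembled detector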
configs/solo/decoupled_solo_light_r50_fpn_8gpu_3x.py (new file, mode 100644)
# model settings
model = dict(
    type='SOLO',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='DecoupledSOLOLightHead',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        seg_feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cate_down_pos=0,
        loss_ins=dict(
            type='DiceLoss',
            use_sigmoid=True,
            loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
    ))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize',
         img_scale=[(852, 512), (852, 480), (852, 448),
                    (852, 416), (852, 384), (852, 352)],
         multiscale_mode='value',
         keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(852, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/decoupled_solo_light_release_r50_fpn_8gpu_3x'
load_from = None
resume_from = None
workflow = [('train', 1)]
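
Once a checkpoint has been trained with either light-weight config, a quick qualitative check can be run through mmdetection's high-level inference API. The sketch below assumes the mmdet 1.x API that this repository is based on; the checkpoint and image paths are placeholders, not files added by this commit.

from mmdet.apis import init_detector, inference_detector

config_file = 'configs/solo/decoupled_solo_light_r50_fpn_8gpu_3x.py'
checkpoint_file = 'work_dirs/decoupled_solo_light_release_r50_fpn_8gpu_3x/latest.pth'  # placeholder
model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo.jpg')  # placeholder image path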
mmdet/models/anchor_heads/__init__.py
@@ -13,10 +13,11 @@ from .rpn_head import RPNHead
 from .ssd_head import SSDHead
 from .solo_head import SOLOHead
 from .decoupled_solo_head import DecoupledSOLOHead
+from .decoupled_solo_light_head import DecoupledSOLOLightHead
 
 __all__ = [
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead', 'GARPNHead',
     'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead', 'FCOSHead',
     'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
-    'ATSSHead', 'SOLOHead', 'DecoupledSOLOHead'
+    'ATSSHead', 'SOLOHead', 'DecoupledSOLOHead', 'DecoupledSOLOLightHead'
 ]
mmdet/models/anchor_heads/decoupled_solo_light_head.py (new file, mode 100644)
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, bbox2roi, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule

INF = 1e8

from scipy import ndimage

def points_nms(heat, kernel=2):
    # kernel must be 2
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=1)
    keep = (hmax[:, :, :-1, :-1] == heat).float()
    return heat * keep
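
# Note (illustration, not part of the committed file): with kernel=2 the pooled
# map is sliced so that each cell is compared against the maximum of itself and
# its top, left and top-left neighbours; cells that are not such local maxima
# are zeroed, which suppresses duplicate category peaks on the S x S grid before
# ranking. A minimal shape/behaviour check:
#
#   heat = torch.rand(1, 80, 40, 40).sigmoid()
#   suppressed = points_nms(heat, kernel=2)
#   assert suppressed.shape == heat.shape    # shape is preserved
#   assert (suppressed <= heat).all()        # values are only kept or zeroed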

def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return 1 - d
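
# Note (illustration, not part of the committed file): dice_loss flattens each
# instance mask and returns one loss value per instance,
#   1 - 2 * sum(p * t) / (sum(p * p) + sum(t * t)),
# with a small 0.001 term for numerical stability, so a prediction matching its
# binary target gives a loss near 0 and a disjoint one a loss near 1.
# A toy check with two 4x4 instances:
#
#   pred = torch.rand(2, 4, 4)            # soft masks in [0, 1]
#   target = (torch.rand(2, 4, 4) > 0.5)  # binary ground-truth masks
#   per_instance = dice_loss(pred, target)  # tensor of shape (2,)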

@HEADS.register_module
class DecoupledSOLOLightHead(nn.Module):

    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.4,
                 num_grids=None,
                 cate_down_pos=0,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 use_dcn_in_tower=False,
                 type_dcn=None):
        super(DecoupledSOLOLightHead, self).__init__()
        self.num_classes = num_classes
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes - 1
        self.in_channels = in_channels
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.sigma = sigma
        self.cate_down_pos = cate_down_pos
        self.base_edge_list = base_edge_list
        self.scale_ranges = scale_ranges
        self.loss_cate = build_loss(loss_cate)
        self.ins_loss_weight = loss_ins['loss_weight']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.use_dcn_in_tower = use_dcn_in_tower
        self.type_dcn = type_dcn
        self._init_layers()

    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.ins_convs = nn.ModuleList()
        self.cate_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            if self.use_dcn_in_tower and i == self.stacked_convs - 1:
                cfg_conv = dict(type=self.type_dcn)
            else:
                cfg_conv = self.conv_cfg

            chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            self.ins_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=cfg_conv,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

            chn = self.in_channels if i == 0 else self.seg_feat_channels
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=cfg_conv,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

        self.dsolo_ins_list_x = nn.ModuleList()
        self.dsolo_ins_list_y = nn.ModuleList()
        for seg_num_grid in self.seg_num_grids:
            self.dsolo_ins_list_x.append(
                nn.Conv2d(self.seg_feat_channels, seg_num_grid, 3, padding=1))
            self.dsolo_ins_list_y.append(
                nn.Conv2d(self.seg_feat_channels, seg_num_grid, 3, padding=1))
        self.dsolo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)

    def init_weights(self):
        for m in self.ins_convs:
            normal_init(m.conv, std=0.01)
        for m in self.cate_convs:
            normal_init(m.conv, std=0.01)
        bias_ins = bias_init_with_prob(0.01)
        for m in self.dsolo_ins_list_x:
            normal_init(m, std=0.01, bias=bias_ins)
        for m in self.dsolo_ins_list_y:
            normal_init(m, std=0.01, bias=bias_ins)
        bias_cate = bias_init_with_prob(0.01)
        normal_init(self.dsolo_cate, std=0.01, bias=bias_cate)

    def forward(self, feats, eval=False):
        new_feats = self.split_feats(feats)
        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
        ins_pred_x, ins_pred_y, cate_pred = multi_apply(
            self.forward_single,
            new_feats,
            list(range(len(self.seg_num_grids))),
            eval=eval,
            upsampled_size=upsampled_size)
        return ins_pred_x, ins_pred_y, cate_pred

    def split_feats(self, feats):
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
                feats[1],
                feats[2],
                feats[3],
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
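
    # Note (illustration, not part of the committed file): split_feats remaps the
    # five FPN outputs (strides 4, 8, 16, 32, 64) onto the head strides used in
    # the configs above ([8, 8, 16, 32, 32]): the finest level is downsampled by
    # 2x and the coarsest is resized to the stride-32 level. A rough shape check
    # with dummy features (sizes are illustrative, and `head` is assumed to be an
    # already built DecoupledSOLOLightHead instance):
    #
    #   feats = [torch.rand(1, 256, 256 // s, 256 // s) for s in (4, 8, 16, 32, 64)]
    #   new_feats = head.split_feats(feats)
    #   # new_feats[0] is half the size of feats[0]; new_feats[4] matches feats[3]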

    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        ins_feat = x
        cate_feat = x
        # ins branch
        # concat coord
        x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
        y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_feat.shape[0], 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        ins_feat = torch.cat([ins_feat, coord_feat], 1)

        for ins_layer in self.ins_convs:
            ins_feat = ins_layer(ins_feat)

        ins_feat = F.interpolate(ins_feat, scale_factor=2, mode='bilinear')
        ins_pred_x = self.dsolo_ins_list_x[idx](ins_feat)
        ins_pred_y = self.dsolo_ins_list_y[idx](ins_feat)

        # cate branch
        for i, cate_layer in enumerate(self.cate_convs):
            if i == self.cate_down_pos:
                seg_num_grid = self.seg_num_grids[idx]
                cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
            cate_feat = cate_layer(cate_feat)

        cate_pred = self.dsolo_cate(cate_feat)

        if eval:
            ins_pred_x = F.interpolate(ins_pred_x.sigmoid(), size=upsampled_size, mode='bilinear')
            ins_pred_y = F.interpolate(ins_pred_y.sigmoid(), size=upsampled_size, mode='bilinear')
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return ins_pred_x, ins_pred_y, cate_pred

    def loss(self,
             ins_preds_x,
             ins_preds_y,
             cate_preds,
             gt_bbox_list,
             gt_label_list,
             gt_mask_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in ins_preds_x]

        ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy = multi_apply(
            self.solo_target_single,
            gt_bbox_list,
            gt_label_list,
            gt_mask_list,
            featmap_sizes=featmap_sizes)

        # ins
        ins_labels = [
            torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
                       for ins_labels_level_img, ins_ind_labels_level_img in
                       zip(ins_labels_level, ins_ind_labels_level)], 0)
            for ins_labels_level, ins_ind_labels_level in
            zip(zip(*ins_label_list), zip(*ins_ind_label_list))
        ]

        ins_preds_x_final = [
            torch.cat([ins_preds_level_img_x[ins_ind_labels_level_img[:, 1], ...]
                       for ins_preds_level_img_x, ins_ind_labels_level_img in
                       zip(ins_preds_level_x, ins_ind_labels_level)], 0)
            for ins_preds_level_x, ins_ind_labels_level in
            zip(ins_preds_x, zip(*ins_ind_label_list_xy))
        ]

        ins_preds_y_final = [
            torch.cat([ins_preds_level_img_y[ins_ind_labels_level_img[:, 0], ...]
                       for ins_preds_level_img_y, ins_ind_labels_level_img in
                       zip(ins_preds_level_y, ins_ind_labels_level)], 0)
            for ins_preds_level_y, ins_ind_labels_level in
            zip(ins_preds_y, zip(*ins_ind_label_list_xy))
        ]

        num_ins = 0.
        # dice loss
        loss_ins = []
        for input_x, input_y, target in zip(ins_preds_x_final, ins_preds_y_final, ins_labels):
            mask_n = input_x.size(0)
            if mask_n == 0:
                continue
            num_ins += mask_n
            input = (input_x.sigmoid()) * (input_y.sigmoid())
            loss_ins.append(dice_loss(input, target))

        loss_ins = torch.cat(loss_ins).mean() * self.ins_loss_weight

        # cate
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)

        loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
        return dict(loss_ins=loss_ins, loss_cate=loss_cate)

    def solo_target_single(self,
                           gt_bboxes_raw,
                           gt_labels_raw,
                           gt_masks_raw,
                           featmap_sizes=None):

        device = gt_labels_raw[0].device

        # ins
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
            gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        ins_ind_label_list_xy = []
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):

            ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]],
                                    dtype=torch.uint8, device=device)
            cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
            ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)

            hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()
            if len(hit_indices) == 0:
                ins_label = torch.zeros([1, featmap_size[0], featmap_size[1]],
                                        dtype=torch.uint8, device=device)
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label = torch.zeros([1], dtype=torch.bool, device=device)
                ins_ind_label_list.append(ins_ind_label)
                ins_ind_label_list_xy.append(cate_label.nonzero())
                continue
            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]

            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

            output_stride = stride / 2

            for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
                if seg_mask.sum() < 10:
                    continue
                # mass center
                upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
                center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down
                top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
                left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))

                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                # squared
                cate_label[top:(down + 1), left:(right + 1)] = gt_label
                # ins
                seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
                seg_mask = torch.Tensor(seg_mask)
                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        label = int(i * num_grid + j)
                        ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
                        ins_ind_label[label] = True

            ins_label = ins_label[ins_ind_label]
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label = ins_ind_label[ins_ind_label]
            ins_ind_label_list.append(ins_ind_label)
            ins_ind_label_list_xy.append(cate_label.nonzero())
        return ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy
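
    # Note (illustration, not part of the committed file): the assignment above
    # maps each mask's centre of mass, measured on the 4x-upsampled feature map,
    # to a cell of the num_grid x num_grid grid and then labels the cells covered
    # by the sigma-shrunk box around it. With the arithmetic used above, e.g. for
    # num_grid = 40, an upsampled width of 864 and center_w = 432 + 18 = 450:
    #
    #   coord_w = int((450 / 864) // (1. / 40))   # -> 20
    #
    # i.e. that object centre lands in grid column 20 of 40.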

    def get_seg(self, seg_preds_x, seg_preds_y, cate_preds, img_metas, cfg, rescale=None):
        assert len(seg_preds_x) == len(cate_preds)
        num_levels = len(cate_preds)
        featmap_size = seg_preds_x[0].size()[-2:]

        result_list = []
        for img_id in range(len(img_metas)):
            cate_pred_list = [
                cate_preds[i][img_id].view(-1, self.cate_out_channels).detach()
                for i in range(num_levels)
            ]
            seg_pred_list_x = [
                seg_preds_x[i][img_id].detach() for i in range(num_levels)
            ]
            seg_pred_list_y = [
                seg_preds_y[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            ori_shape = img_metas[img_id]['ori_shape']

            cate_pred_list = torch.cat(cate_pred_list, dim=0)
            seg_pred_list_x = torch.cat(seg_pred_list_x, dim=0)
            seg_pred_list_y = torch.cat(seg_pred_list_y, dim=0)

            result = self.get_seg_single(cate_pred_list, seg_pred_list_x, seg_pred_list_y,
                                         featmap_size, img_shape, ori_shape, scale_factor,
                                         cfg, rescale)
            result_list.append(result)
        return result_list

    def get_seg_single(self,
                       cate_preds,
                       seg_preds_x,
                       seg_preds_y,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False,
                       debug=False):
        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        # trans trans_diff.
        trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()
        trans_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        num_grids = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
        seg_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        strides = torch.ones(trans_size[-1].item(), device=cate_preds.device)

        n_stage = len(self.seg_num_grids)
        trans_diff[:trans_size[0]] *= 0
        seg_diff[:trans_size[0]] *= 0
        num_grids[:trans_size[0]] *= self.seg_num_grids[0]
        strides[:trans_size[0]] *= self.strides[0]

        for ind_ in range(1, n_stage):
            trans_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= trans_size[ind_ - 1]
            seg_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= seg_size[ind_ - 1]
            num_grids[trans_size[ind_ - 1]:trans_size[ind_]] *= self.seg_num_grids[ind_]
            strides[trans_size[ind_ - 1]:trans_size[ind_]] *= self.strides[ind_]

        # process.
        inds = (cate_preds > cfg.score_thr)
        cate_scores = cate_preds[inds]

        inds = inds.nonzero()
        trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
        seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
        num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
        strides = torch.index_select(strides, dim=0, index=inds[:, 0])

        y_inds = (inds[:, 0] - trans_diff) // num_grids
        x_inds = (inds[:, 0] - trans_diff) % num_grids
        y_inds += seg_diff
        x_inds += seg_diff

        cate_labels = inds[:, 1]
        seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]
        seg_masks = seg_masks_soft > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()
        keep = sum_masks > strides

        seg_masks_soft = seg_masks_soft[keep, ...]
        seg_masks = seg_masks[keep, ...]
        cate_scores = cate_scores[keep]
        sum_masks = sum_masks[keep]
        cate_labels = cate_labels[keep]
        # mask scoring
        seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_score

        if len(cate_scores) == 0:
            return None

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        seg_masks = seg_masks[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        sum_masks = sum_masks[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)

        keep = cate_scores >= cfg.update_thr
        seg_masks_soft = seg_masks_soft[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
                                       size=upsampled_size_out,
                                       mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_masks_soft,
                                  size=ori_shape[:2],
                                  mode='bilinear').squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores