ModelZoo / SOLOv2-pytorch · Commits

Commit 57f6da5c, authored Nov 20, 2025 by bailuo
Commit message: readme

Changes: 345 files. Showing 20 changed files with 5773 additions and 0 deletions (+5773, -0).
mmdet/models/anchor_heads/fcos_head.py                   +408  -0
mmdet/models/anchor_heads/fovea_head.py                  +387  -0
mmdet/models/anchor_heads/free_anchor_retina_head.py     +188  -0
mmdet/models/anchor_heads/ga_retina_head.py              +107  -0
mmdet/models/anchor_heads/ga_rpn_head.py                 +127  -0
mmdet/models/anchor_heads/guided_anchor_head.py          +621  -0
mmdet/models/anchor_heads/reppoints_head.py              +596  -0
mmdet/models/anchor_heads/retina_head.py                 +103  -0
mmdet/models/anchor_heads/retina_sepbn_head.py           +105  -0
mmdet/models/anchor_heads/rpn_head.py                    +104  -0
mmdet/models/anchor_heads/solo_head.py                   +433  -0
mmdet/models/anchor_heads/solov2_head.py                 +483  -0
mmdet/models/anchor_heads/solov2_light_head.py           +482  -0
mmdet/models/anchor_heads/ssd_head.py                    +201  -0
mmdet/models/backbones/__init__.py                       +6    -0
mmdet/models/backbones/hrnet.py                          +524  -0
mmdet/models/backbones/resnet.py                         +516  -0
mmdet/models/backbones/resnext.py                        +222  -0
mmdet/models/backbones/ssd_vgg.py                        +153  -0
mmdet/models/bbox_heads/__init__.py                      +7    -0
Too many changes to show. To preserve performance only 345 of 345+ files are displayed.
mmdet/models/anchor_heads/fcos_head.py  (new file, 0 → 100644)

import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from mmdet.core import distance2bbox, force_fp32, multi_apply, multiclass_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import ConvModule, Scale, bias_init_with_prob

INF = 1e8


@HEADS.register_module
class FCOSHead(nn.Module):
    """
    Fully Convolutional One-Stage Object Detection head from [1]_.

    The FCOS head does not use anchor boxes. Instead, bounding boxes are
    predicted at each pixel and a centerness measure is used to suppress
    low-quality predictions.

    References:
        .. [1] https://arxiv.org/abs/1904.01355

    Example:
        >>> self = FCOSHead(11, 7)
        >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
        >>> cls_score, bbox_pred, centerness = self.forward(feats)
        >>> assert len(cls_score) == len(self.scales)
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512),
                                 (512, INF)),
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 loss_centerness=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)):
        super(FCOSHead, self).__init__()

        self.num_classes = num_classes
        self.cls_out_channels = num_classes - 1
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.regress_ranges = regress_ranges
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.loss_centerness = build_loss(loss_centerness)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False

        self._init_layers()

    def _init_layers(self):
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.norm_cfg is None))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.norm_cfg is None))
        self.fcos_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.fcos_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
        self.fcos_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)

        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

    def init_weights(self):
        for m in self.cls_convs:
            normal_init(m.conv, std=0.01)
        for m in self.reg_convs:
            normal_init(m.conv, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.fcos_cls, std=0.01, bias=bias_cls)
        normal_init(self.fcos_reg, std=0.01)
        normal_init(self.fcos_centerness, std=0.01)

    def forward(self, feats):
        return multi_apply(self.forward_single, feats, self.scales)

    def forward_single(self, x, scale):
        cls_feat = x
        reg_feat = x

        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.fcos_cls(cls_feat)
        centerness = self.fcos_centerness(cls_feat)

        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        # scale the bbox_pred of different level
        # float to avoid overflow when enabling FP16
        bbox_pred = scale(self.fcos_reg(reg_feat)).float().exp()
        return cls_score, bbox_pred, centerness

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                                gt_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        pos_inds = flatten_labels.nonzero().reshape(-1)
        num_pos = len(pos_inds)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels,
            avg_factor=num_pos + num_imgs)  # avoid num_pos is 0

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]

        if num_pos > 0:
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_centerness_targets = self.centerness_target(pos_bbox_targets)
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)
            # centerness weighted iou loss
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=pos_centerness_targets.sum())
            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness)

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   centernesses,
                   img_metas,
                   cfg,
                   rescale=None):
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                      bbox_preds[0].device)
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            centerness_pred_list = [
                centernesses[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            det_bboxes = self.get_bboxes_single(cls_score_list,
                                                bbox_pred_list,
                                                centerness_pred_list,
                                                mlvl_points, img_shape,
                                                scale_factor, cfg, rescale)
            result_list.append(det_bboxes)
        return result_list

    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          centernesses,
                          mlvl_points,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_centerness = []
        for cls_score, bbox_pred, centerness, points in zip(
                cls_scores, bbox_preds, centernesses, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        mlvl_centerness = torch.cat(mlvl_centerness)
        det_bboxes, det_labels = multiclass_nms(
            mlvl_bboxes,
            mlvl_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=mlvl_centerness)
        return det_bboxes, det_labels

    def get_points(self, featmap_sizes, dtype, device):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.

        Returns:
            tuple: points of each image.
        """
        mlvl_points = []
        for i in range(len(featmap_sizes)):
            mlvl_points.append(
                self.get_points_single(featmap_sizes[i], self.strides[i],
                                       dtype, device))
        return mlvl_points

    def get_points_single(self, featmap_size, stride, dtype, device):
        h, w = featmap_size
        x_range = torch.arange(
            0, w * stride, stride, dtype=dtype, device=device)
        y_range = torch.arange(
            0, h * stride, stride, dtype=dtype, device=device)
        y, x = torch.meshgrid(y_range, x_range)
        points = torch.stack(
            (x.reshape(-1), y.reshape(-1)), dim=-1) + stride // 2
        return points

    def fcos_target(self, points, gt_bboxes_list, gt_labels_list):
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)
        # expand regress ranges to align with points
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
                points[i]) for i in range(num_levels)
        ]
        # concat all levels points and regress ranges
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        concat_points = torch.cat(points, dim=0)
        # get labels and bbox_targets of each image
        labels_list, bbox_targets_list = multi_apply(
            self.fcos_target_single,
            gt_bboxes_list,
            gt_labels_list,
            points=concat_points,
            regress_ranges=concat_regress_ranges)

        # split to per img, per level
        num_points = [center.size(0) for center in points]
        labels_list = [labels.split(num_points, 0) for labels in labels_list]
        bbox_targets_list = [
            bbox_targets.split(num_points, 0)
            for bbox_targets in bbox_targets_list
        ]

        # concat per level image
        concat_lvl_labels = []
        concat_lvl_bbox_targets = []
        for i in range(num_levels):
            concat_lvl_labels.append(
                torch.cat([labels[i] for labels in labels_list]))
            concat_lvl_bbox_targets.append(
                torch.cat(
                    [bbox_targets[i] for bbox_targets in bbox_targets_list]))
        return concat_lvl_labels, concat_lvl_bbox_targets

    def fcos_target_single(self, gt_bboxes, gt_labels, points,
                           regress_ranges):
        num_points = points.size(0)
        num_gts = gt_labels.size(0)
        if num_gts == 0:
            return gt_labels.new_zeros(num_points), \
                   gt_bboxes.new_zeros((num_points, 4))

        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * (
            gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1)
        # TODO: figure out why these two are different
        # areas = areas[None].expand(num_points, num_gts)
        areas = areas[None].repeat(num_points, 1)
        regress_ranges = regress_ranges[:, None, :].expand(
            num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)

        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)

        # condition1: inside a gt bbox
        inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0

        # condition2: limit the regression range for each location
        max_regress_distance = bbox_targets.max(-1)[0]
        inside_regress_range = (
            max_regress_distance >= regress_ranges[..., 0]) & (
                max_regress_distance <= regress_ranges[..., 1])

        # if there are still more than one objects for a location,
        # we choose the one with minimal area
        areas[inside_gt_bbox_mask == 0] = INF
        areas[inside_regress_range == 0] = INF
        min_area, min_area_inds = areas.min(dim=1)

        labels = gt_labels[min_area_inds]
        labels[min_area == INF] = 0
        bbox_targets = bbox_targets[range(num_points), min_area_inds]

        return labels, bbox_targets

    def centerness_target(self, pos_bbox_targets):
        # only calculate pos centerness targets, otherwise there may be nan
        left_right = pos_bbox_targets[:, [0, 2]]
        top_bottom = pos_bbox_targets[:, [1, 3]]
        centerness_targets = (
            left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
                top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness_targets)
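The centerness_target method above is the FCOS centerness measure: for a location with distances (l, t, r, b) to the four box sides it computes sqrt(min(l,r)/max(l,r) * min(t,b)/max(t,b)), which is 1 at the box center and decays toward the border. A minimal standalone sketch of the same computation on toy targets (the tensor values below are illustrative, not from the repo):

import torch

# Toy (l, t, r, b) regression targets for three positive locations.
pos_bbox_targets = torch.tensor([[4., 4., 4., 4.],    # perfectly centered
                                 [1., 4., 7., 4.],    # off-center horizontally
                                 [1., 1., 7., 7.]])   # off-center both ways

left_right = pos_bbox_targets[:, [0, 2]]
top_bottom = pos_bbox_targets[:, [1, 3]]
# centerness = sqrt( min(l,r)/max(l,r) * min(t,b)/max(t,b) )
centerness = torch.sqrt(
    (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
    (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
print(centerness)  # tensor([1.0000, 0.3780, 0.1429])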
mmdet/models/anchor_heads/fovea_head.py  (new file, 0 → 100644)

import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from mmdet.core import multi_apply, multiclass_nms
from mmdet.ops import DeformConv
from ..builder import build_loss
from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob

INF = 1e8


class FeatureAlign(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 deformable_groups=4):
        super(FeatureAlign, self).__init__()
        offset_channels = kernel_size * kernel_size * 2
        self.conv_offset = nn.Conv2d(
            4, deformable_groups * offset_channels, 1, bias=False)
        self.conv_adaption = DeformConv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
            deformable_groups=deformable_groups)
        self.relu = nn.ReLU(inplace=True)

    def init_weights(self):
        normal_init(self.conv_offset, std=0.1)
        normal_init(self.conv_adaption, std=0.01)

    def forward(self, x, shape):
        offset = self.conv_offset(shape)
        x = self.relu(self.conv_adaption(x, offset))
        return x


@HEADS.register_module
class FoveaHead(nn.Module):
    """FoveaBox: Beyond Anchor-based Object Detector
    https://arxiv.org/abs/1904.03797
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256),
                               (128, 512)),
                 sigma=0.4,
                 with_deform=False,
                 deformable_groups=4,
                 loss_cls=None,
                 loss_bbox=None,
                 conv_cfg=None,
                 norm_cfg=None):
        super(FoveaHead, self).__init__()
        self.num_classes = num_classes
        self.cls_out_channels = num_classes - 1
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.base_edge_list = base_edge_list
        self.scale_ranges = scale_ranges
        self.sigma = sigma
        self.with_deform = with_deform
        self.deformable_groups = deformable_groups
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self):
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        # box branch
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.norm_cfg is None))
        self.fovea_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
        # cls branch
        if not self.with_deform:
            for i in range(self.stacked_convs):
                chn = self.in_channels if i == 0 else self.feat_channels
                self.cls_convs.append(
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        bias=self.norm_cfg is None))
            self.fovea_cls = nn.Conv2d(
                self.feat_channels, self.cls_out_channels, 3, padding=1)
        else:
            self.cls_convs.append(
                ConvModule(
                    self.feat_channels, (self.feat_channels * 4),
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.norm_cfg is None))
            self.cls_convs.append(
                ConvModule((self.feat_channels * 4), (self.feat_channels * 4),
                           1,
                           stride=1,
                           padding=0,
                           conv_cfg=self.conv_cfg,
                           norm_cfg=self.norm_cfg,
                           bias=self.norm_cfg is None))
            self.feature_adaption = FeatureAlign(
                self.feat_channels,
                self.feat_channels,
                kernel_size=3,
                deformable_groups=self.deformable_groups)
            self.fovea_cls = nn.Conv2d(
                int(self.feat_channels * 4),
                self.cls_out_channels,
                3,
                padding=1)

    def init_weights(self):
        for m in self.cls_convs:
            normal_init(m.conv, std=0.01)
        for m in self.reg_convs:
            normal_init(m.conv, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.fovea_cls, std=0.01, bias=bias_cls)
        normal_init(self.fovea_reg, std=0.01)
        if self.with_deform:
            self.feature_adaption.init_weights()

    def forward(self, feats):
        return multi_apply(self.forward_single, feats)

    def forward_single(self, x):
        cls_feat = x
        reg_feat = x
        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.fovea_reg(reg_feat)
        if self.with_deform:
            cls_feat = self.feature_adaption(cls_feat, bbox_pred.exp())
        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.fovea_cls(cls_feat)
        return cls_score, bbox_pred

    def get_points(self, featmap_sizes, dtype, device, flatten=False):
        points = []
        for featmap_size in featmap_sizes:
            x_range = torch.arange(
                featmap_size[1], dtype=dtype, device=device) + 0.5
            y_range = torch.arange(
                featmap_size[0], dtype=dtype, device=device) + 0.5
            y, x = torch.meshgrid(y_range, x_range)
            if flatten:
                points.append((y.flatten(), x.flatten()))
            else:
                points.append((y, x))
        return points

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bbox_list,
             gt_label_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        assert len(cls_scores) == len(bbox_preds)

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                 bbox_preds[0].device)
        num_imgs = cls_scores[0].size(0)
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_labels, flatten_bbox_targets = self.fovea_target(
            gt_bbox_list, gt_label_list, featmap_sizes, points)
        pos_inds = (flatten_labels > 0).nonzero().view(-1)
        num_pos = len(pos_inds)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels, avg_factor=num_pos + num_imgs)
        if num_pos > 0:
            pos_bbox_preds = flatten_bbox_preds[pos_inds]
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_weights = pos_bbox_targets.new_zeros(
                pos_bbox_targets.size()) + 1.0
            loss_bbox = self.loss_bbox(
                pos_bbox_preds,
                pos_bbox_targets,
                pos_weights,
                avg_factor=num_pos)
        else:
            loss_bbox = torch.tensor(
                [0],
                dtype=flatten_bbox_preds.dtype,
                device=flatten_bbox_preds.device)
        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)

    def fovea_target(self, gt_bbox_list, gt_label_list, featmap_sizes,
                     points):
        label_list, bbox_target_list = multi_apply(
            self.fovea_target_single,
            gt_bbox_list,
            gt_label_list,
            featmap_size_list=featmap_sizes,
            point_list=points)
        flatten_labels = [
            torch.cat([
                labels_level_img.flatten()
                for labels_level_img in labels_level
            ]) for labels_level in zip(*label_list)
        ]
        flatten_bbox_targets = [
            torch.cat([
                bbox_targets_level_img.reshape(-1, 4)
                for bbox_targets_level_img in bbox_targets_level
            ]) for bbox_targets_level in zip(*bbox_target_list)
        ]
        flatten_labels = torch.cat(flatten_labels)
        flatten_bbox_targets = torch.cat(flatten_bbox_targets)
        return flatten_labels, flatten_bbox_targets

    def fovea_target_single(self,
                            gt_bboxes_raw,
                            gt_labels_raw,
                            featmap_size_list=None,
                            point_list=None):
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
                              (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
        label_list = []
        bbox_target_list = []
        # for each pyramid, find the cls and box target
        for base_len, (lower_bound, upper_bound), stride, featmap_size, \
            (y, x) in zip(self.base_edge_list, self.scale_ranges,
                          self.strides, featmap_size_list, point_list):
            labels = gt_labels_raw.new_zeros(featmap_size)
            bbox_targets = gt_bboxes_raw.new(featmap_size[0],
                                             featmap_size[1], 4) + 1
            # scale assignment
            hit_indices = ((gt_areas >= lower_bound) &
                           (gt_areas <= upper_bound)).nonzero().flatten()
            if len(hit_indices) == 0:
                label_list.append(labels)
                bbox_target_list.append(torch.log(bbox_targets))
                continue
            _, hit_index_order = torch.sort(-gt_areas[hit_indices])
            hit_indices = hit_indices[hit_index_order]
            gt_bboxes = gt_bboxes_raw[hit_indices, :] / stride
            gt_labels = gt_labels_raw[hit_indices]
            half_w = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0])
            half_h = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1])
            # valid fovea area: left, right, top, down
            pos_left = torch.ceil(
                gt_bboxes[:, 0] + (1 - self.sigma) * half_w - 0.5).long().\
                clamp(0, featmap_size[1] - 1)
            pos_right = torch.floor(
                gt_bboxes[:, 0] + (1 + self.sigma) * half_w - 0.5).long().\
                clamp(0, featmap_size[1] - 1)
            pos_top = torch.ceil(
                gt_bboxes[:, 1] + (1 - self.sigma) * half_h - 0.5).long().\
                clamp(0, featmap_size[0] - 1)
            pos_down = torch.floor(
                gt_bboxes[:, 1] + (1 + self.sigma) * half_h - 0.5).long().\
                clamp(0, featmap_size[0] - 1)
            for px1, py1, px2, py2, label, (gt_x1, gt_y1, gt_x2, gt_y2) in \
                    zip(pos_left, pos_top, pos_right, pos_down, gt_labels,
                        gt_bboxes_raw[hit_indices, :]):
                labels[py1:py2 + 1, px1:px2 + 1] = label
                bbox_targets[py1:py2 + 1, px1:px2 + 1, 0] = \
                    (stride * x[py1:py2 + 1, px1:px2 + 1] - gt_x1) / base_len
                bbox_targets[py1:py2 + 1, px1:px2 + 1, 1] = \
                    (stride * y[py1:py2 + 1, px1:px2 + 1] - gt_y1) / base_len
                bbox_targets[py1:py2 + 1, px1:px2 + 1, 2] = \
                    (gt_x2 - stride * x[py1:py2 + 1, px1:px2 + 1]) / base_len
                bbox_targets[py1:py2 + 1, px1:px2 + 1, 3] = \
                    (gt_y2 - stride * y[py1:py2 + 1, px1:px2 + 1]) / base_len
            bbox_targets = bbox_targets.clamp(min=1. / 16, max=16.)
            label_list.append(labels)
            bbox_target_list.append(torch.log(bbox_targets))
        return label_list, bbox_target_list

    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg,
                   rescale=None):
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        points = self.get_points(
            featmap_sizes,
            bbox_preds[0].dtype,
            bbox_preds[0].device,
            flatten=True)
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            det_bboxes = self.get_bboxes_single(cls_score_list,
                                                bbox_pred_list, featmap_sizes,
                                                points, img_shape,
                                                scale_factor, cfg, rescale)
            result_list.append(det_bboxes)
        return result_list

    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          featmap_sizes,
                          point_list,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(point_list)
        det_bboxes = []
        det_scores = []
        for cls_score, bbox_pred, featmap_size, stride, base_len, (y, x) \
                in zip(cls_scores, bbox_preds, featmap_sizes, self.strides,
                       self.base_edge_list, point_list):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4).exp()
            nms_pre = cfg.get('nms_pre', -1)
            if (nms_pre > 0) and (scores.shape[0] > nms_pre):
                max_scores, _ = scores.max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                y = y[topk_inds]
                x = x[topk_inds]
            x1 = (stride * x - base_len * bbox_pred[:, 0]).\
                clamp(min=0, max=img_shape[1] - 1)
            y1 = (stride * y - base_len * bbox_pred[:, 1]).\
                clamp(min=0, max=img_shape[0] - 1)
            x2 = (stride * x + base_len * bbox_pred[:, 2]).\
                clamp(min=0, max=img_shape[1] - 1)
            y2 = (stride * y + base_len * bbox_pred[:, 3]).\
                clamp(min=0, max=img_shape[0] - 1)
            bboxes = torch.stack([x1, y1, x2, y2], -1)
            det_bboxes.append(bboxes)
            det_scores.append(scores)
        det_bboxes = torch.cat(det_bboxes)
        if rescale:
            det_bboxes /= det_bboxes.new_tensor(scale_factor)
        det_scores = torch.cat(det_scores)
        padding = det_scores.new_zeros(det_scores.shape[0], 1)
        det_scores = torch.cat([padding, det_scores], dim=1)
        det_bboxes, det_labels = multiclass_nms(det_bboxes, det_scores,
                                                cfg.score_thr, cfg.nms,
                                                cfg.max_per_img)
        return det_bboxes, det_labels
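As a quick illustration of the scale assignment in fovea_target_single above: ground-truth boxes are routed to pyramid levels by comparing the square root of their area against scale_ranges, and because adjacent ranges overlap, one box can be assigned to two levels. A small sketch with made-up boxes (values illustrative only, using the default scale_ranges from the head):

import torch

# Hypothetical ground-truth boxes (x1, y1, x2, y2).
gt_bboxes = torch.tensor([[0., 0., 24., 24.], [0., 0., 100., 200.]])
scale_ranges = ((8, 32), (16, 64), (32, 128), (64, 256), (128, 512))

gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                      (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
for lvl, (lower, upper) in enumerate(scale_ranges):
    hit = ((gt_areas >= lower) & (gt_areas <= upper)).nonzero().flatten()
    print(lvl, hit.tolist())
# box 0 (sqrt-area 24) hits levels 0 and 1;
# box 1 (sqrt-area ~141) hits levels 3 and 4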
mmdet/models/anchor_heads/free_anchor_retina_head.py  (new file, 0 → 100644)

import torch
import torch.nn.functional as F

from mmdet.core import bbox2delta, bbox_overlaps, delta2bbox
from ..registry import HEADS
from .retina_head import RetinaHead


@HEADS.register_module
class FreeAnchorRetinaHead(RetinaHead):

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 octave_base_scale=4,
                 scales_per_octave=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 pre_anchor_topk=50,
                 bbox_thr=0.6,
                 gamma=2.0,
                 alpha=0.5,
                 **kwargs):
        super(FreeAnchorRetinaHead,
              self).__init__(num_classes, in_channels, stacked_convs,
                             octave_base_scale, scales_per_octave, conv_cfg,
                             norm_cfg, **kwargs)
        self.pre_anchor_topk = pre_anchor_topk
        self.bbox_thr = bbox_thr
        self.gamma = gamma
        self.alpha = alpha

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, _ = self.get_anchors(featmap_sizes, img_metas)
        anchors = [torch.cat(anchor) for anchor in anchor_list]

        # concatenate each level
        cls_scores = [
            cls.permute(0, 2, 3,
                        1).reshape(cls.size(0), -1, self.cls_out_channels)
            for cls in cls_scores
        ]
        bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4)
            for bbox_pred in bbox_preds
        ]
        cls_scores = torch.cat(cls_scores, dim=1)
        bbox_preds = torch.cat(bbox_preds, dim=1)

        cls_prob = torch.sigmoid(cls_scores)
        box_prob = []
        num_pos = 0
        positive_losses = []
        for _, (anchors_, gt_labels_, gt_bboxes_, cls_prob_,
                bbox_preds_) in enumerate(
                    zip(anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds)):
            gt_labels_ -= 1

            with torch.no_grad():
                # box_localization: a_{j}^{loc}, shape: [j, 4]
                pred_boxes = delta2bbox(anchors_, bbox_preds_,
                                        self.target_means, self.target_stds)

                # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
                object_box_iou = bbox_overlaps(gt_bboxes_, pred_boxes)

                # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
                t1 = self.bbox_thr
                t2 = object_box_iou.max(
                    dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
                object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(
                    min=0, max=1)

                # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
                num_obj = gt_labels_.size(0)
                indices = torch.stack(
                    [torch.arange(num_obj).type_as(gt_labels_), gt_labels_],
                    dim=0)
                object_cls_box_prob = torch.sparse_coo_tensor(
                    indices, object_box_prob)

                # image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j]
                """
                from "start" to "end" implement:
                image_box_iou = torch.sparse.max(object_cls_box_prob,
                                                 dim=0).t()
                """
                # start
                box_cls_prob = torch.sparse.sum(
                    object_cls_box_prob, dim=0).to_dense()

                indices = torch.nonzero(box_cls_prob).t_()
                if indices.numel() == 0:
                    image_box_prob = torch.zeros(
                        anchors_.size(0),
                        self.cls_out_channels).type_as(object_box_prob)
                else:
                    nonzero_box_prob = torch.where(
                        (gt_labels_.unsqueeze(dim=-1) == indices[0]),
                        object_box_prob[:, indices[1]],
                        torch.tensor(
                            [0]).type_as(object_box_prob)).max(dim=0).values

                    # upmap to shape [j, c]
                    image_box_prob = torch.sparse_coo_tensor(
                        indices.flip([0]),
                        nonzero_box_prob,
                        size=(anchors_.size(0),
                              self.cls_out_channels)).to_dense()
                # end

                box_prob.append(image_box_prob)

            # construct bags for objects
            match_quality_matrix = bbox_overlaps(gt_bboxes_, anchors_)
            _, matched = torch.topk(
                match_quality_matrix,
                self.pre_anchor_topk,
                dim=1,
                sorted=False)
            del match_quality_matrix

            # matched_cls_prob: P_{ij}^{cls}
            matched_cls_prob = torch.gather(
                cls_prob_[matched], 2,
                gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
                                                 1)).squeeze(2)

            # matched_box_prob: P_{ij}^{loc}
            matched_anchors = anchors_[matched]
            matched_object_targets = bbox2delta(
                matched_anchors,
                gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors),
                self.target_means, self.target_stds)
            loss_bbox = self.loss_bbox(
                bbox_preds_[matched],
                matched_object_targets,
                reduction_override='none').sum(-1)
            matched_box_prob = torch.exp(-loss_bbox)

            # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
            num_pos += len(gt_bboxes_)
            positive_losses.append(
                self.positive_bag_loss(matched_cls_prob, matched_box_prob))
        positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)

        # box_prob: P{a_{j} \in A_{+}}
        box_prob = torch.stack(box_prob, dim=0)

        # negative_loss:
        # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
        negative_loss = self.negative_bag_loss(cls_prob, box_prob).sum() / max(
            1, num_pos * self.pre_anchor_topk)

        losses = {
            'positive_bag_loss': positive_loss,
            'negative_bag_loss': negative_loss
        }
        return losses

    def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
        # bag_prob = Mean-max(matched_prob)
        matched_prob = matched_cls_prob * matched_box_prob
        weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
        weight /= weight.sum(dim=1).unsqueeze(dim=-1)
        bag_prob = (weight * matched_prob).sum(dim=1)
        # positive_bag_loss = -self.alpha * log(bag_prob)
        return self.alpha * F.binary_cross_entropy(
            bag_prob, torch.ones_like(bag_prob), reduction='none')

    def negative_bag_loss(self, cls_prob, box_prob):
        prob = cls_prob * (1 - box_prob)
        negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
            prob, torch.zeros_like(prob), reduction='none')
        return (1 - self.alpha) * negative_bag_loss
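positive_bag_loss above implements the Mean-max function from the FreeAnchor paper: each anchor in a bag is weighted by 1/(1-p), normalized over the bag, so the bag probability approaches the maximum as the best anchor's probability approaches 1, and approaches the mean when all probabilities are small. A toy sketch with made-up probabilities (values illustrative only):

import torch

# Toy bag of per-anchor joint probabilities P_cls * P_loc for one object.
matched_prob = torch.tensor([[0.2, 0.5, 0.9]])

# Mean-max: weights 1/(1-p), normalized over the bag, then a weighted mean.
weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
weight /= weight.sum(dim=1).unsqueeze(dim=-1)
bag_prob = (weight * matched_prob).sum(dim=1)
print(bag_prob)  # ~0.77: dominated by the best anchor, but not a hard max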
mmdet/models/anchor_heads/ga_retina_head.py  (new file, 0 → 100644)

import torch.nn as nn
from mmcv.cnn import normal_init

from mmdet.ops import MaskedConv2d
from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead


@HEADS.register_module
class GARetinaHead(GuidedAnchorHead):
    """Guided-Anchor-based RetinaNet head."""

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 **kwargs):
        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        super(GARetinaHead, self).__init__(num_classes, in_channels, **kwargs)

    def _init_layers(self):
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))

        self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1)
        self.conv_shape = nn.Conv2d(self.feat_channels, self.num_anchors * 2,
                                    1)
        self.feature_adaption_cls = FeatureAdaption(
            self.feat_channels,
            self.feat_channels,
            kernel_size=3,
            deformable_groups=self.deformable_groups)
        self.feature_adaption_reg = FeatureAdaption(
            self.feat_channels,
            self.feat_channels,
            kernel_size=3,
            deformable_groups=self.deformable_groups)
        self.retina_cls = MaskedConv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        self.retina_reg = MaskedConv2d(
            self.feat_channels, self.num_anchors * 4, 3, padding=1)

    def init_weights(self):
        for m in self.cls_convs:
            normal_init(m.conv, std=0.01)
        for m in self.reg_convs:
            normal_init(m.conv, std=0.01)

        self.feature_adaption_cls.init_weights()
        self.feature_adaption_reg.init_weights()

        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.conv_loc, std=0.01, bias=bias_cls)
        normal_init(self.conv_shape, std=0.01)
        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
        normal_init(self.retina_reg, std=0.01)

    def forward_single(self, x):
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)

        loc_pred = self.conv_loc(cls_feat)
        shape_pred = self.conv_shape(reg_feat)

        cls_feat = self.feature_adaption_cls(cls_feat, shape_pred)
        reg_feat = self.feature_adaption_reg(reg_feat, shape_pred)

        if not self.training:
            mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
        else:
            mask = None
        cls_score = self.retina_cls(cls_feat, mask)
        bbox_pred = self.retina_reg(reg_feat, mask)
        return cls_score, bbox_pred, shape_pred, loc_pred
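In forward_single above, the mask is built only at inference so that MaskedConv2d skips positions whose location score falls below loc_filter_thr. A rough dense emulation of the masking effect (illustrative only; the real MaskedConv2d op from mmdet.ops saves computation by convolving only the kept positions rather than zeroing the output afterwards, and the shapes below are made up):

import torch
import torch.nn as nn

feat = torch.rand(1, 256, 32, 32)     # adapted cls_feat for one image
loc_pred = torch.randn(1, 1, 32, 32)  # raw output of conv_loc
loc_filter_thr = 0.01

# keep-mask over spatial positions, as in forward_single
mask = loc_pred.sigmoid()[0] >= loc_filter_thr  # shape (1, 32, 32)

conv = nn.Conv2d(256, 9, 3, padding=1)  # stand-in for retina_cls
cls_score = conv(feat) * mask           # zero out the filtered positions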
mmdet/models/anchor_heads/ga_rpn_head.py  (new file, 0 → 100644)

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init

from mmdet.core import delta2bbox
from mmdet.ops import nms
from ..registry import HEADS
from .guided_anchor_head import GuidedAnchorHead


@HEADS.register_module
class GARPNHead(GuidedAnchorHead):
    """Guided-Anchor-based RPN head."""

    def __init__(self, in_channels, **kwargs):
        super(GARPNHead, self).__init__(2, in_channels, **kwargs)

    def _init_layers(self):
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        super(GARPNHead, self)._init_layers()

    def init_weights(self):
        normal_init(self.rpn_conv, std=0.01)
        super(GARPNHead, self).init_weights()

    def forward_single(self, x):
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        (cls_score, bbox_pred, shape_pred,
         loc_pred) = super(GARPNHead, self).forward_single(x)
        return cls_score, bbox_pred, shape_pred, loc_pred

    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        losses = super(GARPNHead, self).loss(
            cls_scores,
            bbox_preds,
            shape_preds,
            loc_preds,
            gt_bboxes,
            None,
            img_metas,
            cfg,
            gt_bboxes_ignore=gt_bboxes_ignore)
        return dict(
            loss_rpn_cls=losses['loss_cls'],
            loss_rpn_bbox=losses['loss_bbox'],
            loss_anchor_shape=losses['loss_shape'],
            loss_anchor_loc=losses['loss_loc'])

    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          mlvl_masks,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        mlvl_proposals = []
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            anchors = mlvl_anchors[idx]
            mask = mlvl_masks[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            # if no location is kept, end.
            if mask.sum() == 0:
                continue
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                scores = rpn_cls_score.softmax(dim=1)[:, 1]
            # filter scores, bbox_pred w.r.t. mask.
            # anchors are filtered in get_anchors() beforehand.
            scores = scores[mask]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2,
                                                  0).reshape(-1, 4)[mask, :]
            if scores.dim() == 0:
                rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0)
                anchors = anchors.unsqueeze(0)
                scores = scores.unsqueeze(0)
            # filter anchors, bbox_pred, scores w.r.t. scores
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                _, topk_inds = scores.topk(cfg.nms_pre)
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
                scores = scores[topk_inds]
            # get proposals w.r.t. anchors and rpn_bbox_pred
            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                   self.target_stds, img_shape)
            # filter out too small bboxes
            if cfg.min_bbox_size > 0:
                w = proposals[:, 2] - proposals[:, 0] + 1
                h = proposals[:, 3] - proposals[:, 1] + 1
                valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                           (h >= cfg.min_bbox_size)).squeeze()
                proposals = proposals[valid_inds, :]
                scores = scores[valid_inds]
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            # NMS in current level
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            # NMS across multi levels
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            scores = proposals[:, 4]
            num = min(cfg.max_num, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]
        return proposals
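One detail of get_bboxes_single above worth noting: the min-size filter measures proposal width and height with a +1 pixel convention before comparing against cfg.min_bbox_size. A standalone sketch of just that step, with hypothetical proposals and scores (values illustrative only):

import torch

proposals = torch.tensor([[0., 0., 3., 40.],    # w = 4: too narrow
                          [0., 0., 60., 80.],   # kept
                          [5., 5., 20., 20.]])  # kept
scores = torch.tensor([0.9, 0.8, 0.7])
min_bbox_size = 8

w = proposals[:, 2] - proposals[:, 0] + 1
h = proposals[:, 3] - proposals[:, 1] + 1
valid_inds = torch.nonzero(
    (w >= min_bbox_size) & (h >= min_bbox_size)).squeeze()
proposals = proposals[valid_inds, :]  # drops the 4-pixel-wide box
scores = scores[valid_inds]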
mmdet/models/anchor_heads/guided_anchor_head.py
0 → 100644
View file @
57f6da5c
from
__future__
import
division
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
normal_init
from
mmdet.core
import
(
AnchorGenerator
,
anchor_inside_flags
,
anchor_target
,
delta2bbox
,
force_fp32
,
ga_loc_target
,
ga_shape_target
,
multi_apply
,
multiclass_nms
)
from
mmdet.ops
import
DeformConv
,
MaskedConv2d
from
..builder
import
build_loss
from
..registry
import
HEADS
from
..utils
import
bias_init_with_prob
from
.anchor_head
import
AnchorHead
class
FeatureAdaption
(
nn
.
Module
):
"""Feature Adaption Module.
Feature Adaption Module is implemented based on DCN v1.
It uses anchor shape prediction rather than feature map to
predict offsets of deformable conv layer.
Args:
in_channels (int): Number of channels in the input feature map.
out_channels (int): Number of channels in the output feature map.
kernel_size (int): Deformable conv kernel size.
deformable_groups (int): Deformable conv group size.
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
=
3
,
deformable_groups
=
4
):
super
(
FeatureAdaption
,
self
).
__init__
()
offset_channels
=
kernel_size
*
kernel_size
*
2
self
.
conv_offset
=
nn
.
Conv2d
(
2
,
deformable_groups
*
offset_channels
,
1
,
bias
=
False
)
self
.
conv_adaption
=
DeformConv
(
in_channels
,
out_channels
,
kernel_size
=
kernel_size
,
padding
=
(
kernel_size
-
1
)
//
2
,
deformable_groups
=
deformable_groups
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
def
init_weights
(
self
):
normal_init
(
self
.
conv_offset
,
std
=
0.1
)
normal_init
(
self
.
conv_adaption
,
std
=
0.01
)
def
forward
(
self
,
x
,
shape
):
offset
=
self
.
conv_offset
(
shape
.
detach
())
x
=
self
.
relu
(
self
.
conv_adaption
(
x
,
offset
))
return
x
@
HEADS
.
register_module
class
GuidedAnchorHead
(
AnchorHead
):
"""Guided-Anchor-based head (GA-RPN, GA-RetinaNet, etc.).
This GuidedAnchorHead will predict high-quality feature guided
anchors and locations where anchors will be kept in inference.
There are mainly 3 categories of bounding-boxes.
- Sampled (9) pairs for target assignment. (approxes)
- The square boxes where the predicted anchors are based on.
(squares)
- Guided anchors.
Please refer to https://arxiv.org/abs/1901.03278 for more details.
Args:
num_classes (int): Number of classes.
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of hidden channels.
octave_base_scale (int): Base octave scale of each level of
feature map.
scales_per_octave (int): Number of octave scales in each level of
feature map
octave_ratios (Iterable): octave aspect ratios.
anchor_strides (Iterable): Anchor strides.
anchor_base_sizes (Iterable): Anchor base sizes.
anchoring_means (Iterable): Mean values of anchoring targets.
anchoring_stds (Iterable): Std values of anchoring targets.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
deformable_groups: (int): Group number of DCN in
FeatureAdaption module.
loc_filter_thr (float): Threshold to filter out unconcerned regions.
loss_loc (dict): Config of location loss.
loss_shape (dict): Config of anchor shape loss.
loss_cls (dict): Config of classification loss.
loss_bbox (dict): Config of bbox regression loss.
"""
def
__init__
(
self
,
num_classes
,
in_channels
,
feat_channels
=
256
,
octave_base_scale
=
8
,
scales_per_octave
=
3
,
octave_ratios
=
[
0.5
,
1.0
,
2.0
],
anchor_strides
=
[
4
,
8
,
16
,
32
,
64
],
anchor_base_sizes
=
None
,
anchoring_means
=
(.
0
,
.
0
,
.
0
,
.
0
),
anchoring_stds
=
(
1.0
,
1.0
,
1.0
,
1.0
),
target_means
=
(.
0
,
.
0
,
.
0
,
.
0
),
target_stds
=
(
1.0
,
1.0
,
1.0
,
1.0
),
deformable_groups
=
4
,
loc_filter_thr
=
0.01
,
loss_loc
=
dict
(
type
=
'FocalLoss'
,
use_sigmoid
=
True
,
gamma
=
2.0
,
alpha
=
0.25
,
loss_weight
=
1.0
),
loss_shape
=
dict
(
type
=
'BoundedIoULoss'
,
beta
=
0.2
,
loss_weight
=
1.0
),
loss_cls
=
dict
(
type
=
'CrossEntropyLoss'
,
use_sigmoid
=
True
,
loss_weight
=
1.0
),
loss_bbox
=
dict
(
type
=
'SmoothL1Loss'
,
beta
=
1.0
,
loss_weight
=
1.0
)):
# yapf: disable
super
(
AnchorHead
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
num_classes
=
num_classes
self
.
feat_channels
=
feat_channels
self
.
octave_base_scale
=
octave_base_scale
self
.
scales_per_octave
=
scales_per_octave
self
.
octave_scales
=
octave_base_scale
*
np
.
array
(
[
2
**
(
i
/
scales_per_octave
)
for
i
in
range
(
scales_per_octave
)])
self
.
approxs_per_octave
=
len
(
self
.
octave_scales
)
*
len
(
octave_ratios
)
self
.
octave_ratios
=
octave_ratios
self
.
anchor_strides
=
anchor_strides
self
.
anchor_base_sizes
=
list
(
anchor_strides
)
if
anchor_base_sizes
is
None
else
anchor_base_sizes
self
.
anchoring_means
=
anchoring_means
self
.
anchoring_stds
=
anchoring_stds
self
.
target_means
=
target_means
self
.
target_stds
=
target_stds
self
.
deformable_groups
=
deformable_groups
self
.
loc_filter_thr
=
loc_filter_thr
self
.
approx_generators
=
[]
self
.
square_generators
=
[]
for
anchor_base
in
self
.
anchor_base_sizes
:
# Generators for approxs
self
.
approx_generators
.
append
(
AnchorGenerator
(
anchor_base
,
self
.
octave_scales
,
self
.
octave_ratios
))
# Generators for squares
self
.
square_generators
.
append
(
AnchorGenerator
(
anchor_base
,
[
self
.
octave_base_scale
],
[
1.0
]))
# one anchor per location
self
.
num_anchors
=
1
self
.
use_sigmoid_cls
=
loss_cls
.
get
(
'use_sigmoid'
,
False
)
self
.
cls_focal_loss
=
loss_cls
[
'type'
]
in
[
'FocalLoss'
]
self
.
loc_focal_loss
=
loss_loc
[
'type'
]
in
[
'FocalLoss'
]
if
self
.
use_sigmoid_cls
:
self
.
cls_out_channels
=
self
.
num_classes
-
1
else
:
self
.
cls_out_channels
=
self
.
num_classes
# build losses
self
.
loss_loc
=
build_loss
(
loss_loc
)
self
.
loss_shape
=
build_loss
(
loss_shape
)
self
.
loss_cls
=
build_loss
(
loss_cls
)
self
.
loss_bbox
=
build_loss
(
loss_bbox
)
self
.
fp16_enabled
=
False
self
.
_init_layers
()
def
_init_layers
(
self
):
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
conv_loc
=
nn
.
Conv2d
(
self
.
in_channels
,
1
,
1
)
self
.
conv_shape
=
nn
.
Conv2d
(
self
.
in_channels
,
self
.
num_anchors
*
2
,
1
)
self
.
feature_adaption
=
FeatureAdaption
(
self
.
in_channels
,
self
.
feat_channels
,
kernel_size
=
3
,
deformable_groups
=
self
.
deformable_groups
)
self
.
conv_cls
=
MaskedConv2d
(
self
.
feat_channels
,
self
.
num_anchors
*
self
.
cls_out_channels
,
1
)
self
.
conv_reg
=
MaskedConv2d
(
self
.
feat_channels
,
self
.
num_anchors
*
4
,
1
)
def
init_weights
(
self
):
normal_init
(
self
.
conv_cls
,
std
=
0.01
)
normal_init
(
self
.
conv_reg
,
std
=
0.01
)
bias_cls
=
bias_init_with_prob
(
0.01
)
normal_init
(
self
.
conv_loc
,
std
=
0.01
,
bias
=
bias_cls
)
normal_init
(
self
.
conv_shape
,
std
=
0.01
)
self
.
feature_adaption
.
init_weights
()
def
forward_single
(
self
,
x
):
loc_pred
=
self
.
conv_loc
(
x
)
shape_pred
=
self
.
conv_shape
(
x
)
x
=
self
.
feature_adaption
(
x
,
shape_pred
)
# masked conv is only used during inference for speed-up
if
not
self
.
training
:
mask
=
loc_pred
.
sigmoid
()[
0
]
>=
self
.
loc_filter_thr
else
:
mask
=
None
cls_score
=
self
.
conv_cls
(
x
,
mask
)
bbox_pred
=
self
.
conv_reg
(
x
,
mask
)
return
cls_score
,
bbox_pred
,
shape_pred
,
loc_pred
def
forward
(
self
,
feats
):
return
multi_apply
(
self
.
forward_single
,
feats
)
def
get_sampled_approxs
(
self
,
featmap_sizes
,
img_metas
,
cfg
,
device
=
'cuda'
):
"""Get sampled approxs and inside flags according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
device (torch.device | str): device for returned tensors
Returns:
tuple: approxes of each image, inside flags of each image
"""
num_imgs
=
len
(
img_metas
)
num_levels
=
len
(
featmap_sizes
)
# since feature map sizes of all images are the same, we only compute
# approxes for one time
multi_level_approxs
=
[]
for
i
in
range
(
num_levels
):
approxs
=
self
.
approx_generators
[
i
].
grid_anchors
(
featmap_sizes
[
i
],
self
.
anchor_strides
[
i
],
device
=
device
)
multi_level_approxs
.
append
(
approxs
)
approxs_list
=
[
multi_level_approxs
for
_
in
range
(
num_imgs
)]
# for each image, we compute inside flags of multi level approxes
inside_flag_list
=
[]
for
img_id
,
img_meta
in
enumerate
(
img_metas
):
multi_level_flags
=
[]
multi_level_approxs
=
approxs_list
[
img_id
]
for
i
in
range
(
num_levels
):
approxs
=
multi_level_approxs
[
i
]
anchor_stride
=
self
.
anchor_strides
[
i
]
feat_h
,
feat_w
=
featmap_sizes
[
i
]
h
,
w
=
img_meta
[
'pad_shape'
][:
2
]
valid_feat_h
=
min
(
int
(
np
.
ceil
(
h
/
anchor_stride
)),
feat_h
)
valid_feat_w
=
min
(
int
(
np
.
ceil
(
w
/
anchor_stride
)),
feat_w
)
flags
=
self
.
approx_generators
[
i
].
valid_flags
(
(
feat_h
,
feat_w
),
(
valid_feat_h
,
valid_feat_w
),
device
=
device
)
inside_flags_list
=
[]
for
i
in
range
(
self
.
approxs_per_octave
):
split_valid_flags
=
flags
[
i
::
self
.
approxs_per_octave
]
split_approxs
=
approxs
[
i
::
self
.
approxs_per_octave
,
:]
inside_flags
=
anchor_inside_flags
(
split_approxs
,
split_valid_flags
,
img_meta
[
'img_shape'
][:
2
],
cfg
.
allowed_border
)
inside_flags_list
.
append
(
inside_flags
)
# inside_flag for a position is true if any anchor in this
# position is true
inside_flags
=
(
torch
.
stack
(
inside_flags_list
,
0
).
sum
(
dim
=
0
)
>
0
)
multi_level_flags
.
append
(
inside_flags
)
inside_flag_list
.
append
(
multi_level_flags
)
return
approxs_list
,
inside_flag_list
def
get_anchors
(
self
,
featmap_sizes
,
shape_preds
,
loc_preds
,
img_metas
,
use_loc_filter
=
False
,
device
=
'cuda'
):
"""Get squares according to feature map sizes and guided
anchors.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
shape_preds (list[tensor]): Multi-level shape predictions.
loc_preds (list[tensor]): Multi-level location predictions.
img_metas (list[dict]): Image meta info.
use_loc_filter (bool): Use loc filter or not.
device (torch.device | str): device for returned tensors
Returns:
tuple: square approxs of each image, guided anchors of each image,
loc masks of each image
"""
num_imgs
=
len
(
img_metas
)
num_levels
=
len
(
featmap_sizes
)
# since feature map sizes of all images are the same, we only compute
# squares for one time
multi_level_squares
=
[]
for
i
in
range
(
num_levels
):
squares
=
self
.
square_generators
[
i
].
grid_anchors
(
featmap_sizes
[
i
],
self
.
anchor_strides
[
i
],
device
=
device
)
multi_level_squares
.
append
(
squares
)
squares_list
=
[
multi_level_squares
for
_
in
range
(
num_imgs
)]
# for each image, we compute multi level guided anchors
guided_anchors_list
=
[]
loc_mask_list
=
[]
for
img_id
,
img_meta
in
enumerate
(
img_metas
):
multi_level_guided_anchors
=
[]
multi_level_loc_mask
=
[]
for
i
in
range
(
num_levels
):
squares
=
squares_list
[
img_id
][
i
]
shape_pred
=
shape_preds
[
i
][
img_id
]
loc_pred
=
loc_preds
[
i
][
img_id
]
guided_anchors
,
loc_mask
=
self
.
get_guided_anchors_single
(
squares
,
shape_pred
,
loc_pred
,
use_loc_filter
=
use_loc_filter
)
multi_level_guided_anchors
.
append
(
guided_anchors
)
multi_level_loc_mask
.
append
(
loc_mask
)
guided_anchors_list
.
append
(
multi_level_guided_anchors
)
loc_mask_list
.
append
(
multi_level_loc_mask
)
return
squares_list
,
guided_anchors_list
,
loc_mask_list
def
get_guided_anchors_single
(
self
,
squares
,
shape_pred
,
loc_pred
,
use_loc_filter
=
False
):
"""Get guided anchors and loc masks for a single level.
Args:
square (tensor): Squares of a single level.
shape_pred (tensor): Shape predections of a single level.
loc_pred (tensor): Loc predections of a single level.
use_loc_filter (list[tensor]): Use loc filter or not.
Returns:
tuple: guided anchors, location masks
"""
        # calculate location filtering mask
        loc_pred = loc_pred.sigmoid().detach()
        if use_loc_filter:
            loc_mask = loc_pred >= self.loc_filter_thr
        else:
            loc_mask = loc_pred >= 0.0
        mask = loc_mask.permute(1, 2, 0).expand(-1, -1, self.num_anchors)
        mask = mask.contiguous().view(-1)
        # calculate guided anchors
        squares = squares[mask]
        anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view(
            -1, 2).detach()[mask]
        bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
        bbox_deltas[:, 2:] = anchor_deltas
        guided_anchors = delta2bbox(
            squares,
            bbox_deltas,
            self.anchoring_means,
            self.anchoring_stds,
            wh_ratio_clip=1e-6)
        return guided_anchors, mask

    def loss_shape_single(self, shape_pred, bbox_anchors, bbox_gts,
                          anchor_weights, anchor_total_num):
        shape_pred = shape_pred.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        bbox_anchors = bbox_anchors.contiguous().view(-1, 4)
        bbox_gts = bbox_gts.contiguous().view(-1, 4)
        anchor_weights = anchor_weights.contiguous().view(-1, 4)
        bbox_deltas = bbox_anchors.new_full(bbox_anchors.size(), 0)
        bbox_deltas[:, 2:] += shape_pred
        # filter out negative samples to speed-up weighted_bounded_iou_loss
        inds = torch.nonzero(anchor_weights[:, 0] > 0).squeeze(1)
        bbox_deltas_ = bbox_deltas[inds]
        bbox_anchors_ = bbox_anchors[inds]
        bbox_gts_ = bbox_gts[inds]
        anchor_weights_ = anchor_weights[inds]
        pred_anchors_ = delta2bbox(
            bbox_anchors_,
            bbox_deltas_,
            self.anchoring_means,
            self.anchoring_stds,
            wh_ratio_clip=1e-6)
        loss_shape = self.loss_shape(
            pred_anchors_,
            bbox_gts_,
            anchor_weights_,
            avg_factor=anchor_total_num)
        return loss_shape

    def loss_loc_single(self, loc_pred, loc_target, loc_weight,
                        loc_avg_factor, cfg):
        loss_loc = self.loss_loc(
            loc_pred.reshape(-1, 1),
            loc_target.reshape(-1, 1).long(),
            loc_weight.reshape(-1, 1),
            avg_factor=loc_avg_factor)
        return loss_loc

    @force_fp32(
        apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.approx_generators)

        device = cls_scores[0].device

        # get loc targets
        loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
            gt_bboxes,
            featmap_sizes,
            self.octave_base_scale,
            self.anchor_strides,
            center_ratio=cfg.center_ratio,
            ignore_ratio=cfg.ignore_ratio)

        # get sampled approxes
        approxs_list, inside_flag_list = self.get_sampled_approxs(
            featmap_sizes, img_metas, cfg, device=device)
        # get squares and guided anchors
        squares_list, guided_anchors_list, _ = self.get_anchors(
            featmap_sizes, shape_preds, loc_preds, img_metas, device=device)

        # get shape targets
        sampling = False if not hasattr(cfg, 'ga_sampler') else True
        shape_targets = ga_shape_target(
            approxs_list,
            inside_flag_list,
            squares_list,
            gt_bboxes,
            img_metas,
            self.approxs_per_octave,
            cfg,
            sampling=sampling)
        if shape_targets is None:
            return None
        (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
         anchor_bg_num) = shape_targets
        anchor_total_num = (
            anchor_fg_num if not sampling else anchor_fg_num + anchor_bg_num)

        # get anchor targets
        sampling = False if self.cls_focal_loss else True
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = anchor_target(
            guided_anchors_list,
            inside_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos
            if self.cls_focal_loss else num_total_pos + num_total_neg)

        # get classification and bbox regression losses
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)

        # get anchor location loss
        losses_loc = []
        for i in range(len(loc_preds)):
            loss_loc = self.loss_loc_single(
                loc_preds[i],
                loc_targets[i],
                loc_weights[i],
                loc_avg_factor=loc_avg_factor,
                cfg=cfg)
            losses_loc.append(loss_loc)

        # get anchor shape loss
        losses_shape = []
        for i in range(len(shape_preds)):
            loss_shape = self.loss_shape_single(
                shape_preds[i],
                bbox_anchors_list[i],
                bbox_gts_list[i],
                anchor_weights_list[i],
                anchor_total_num=anchor_total_num)
            losses_shape.append(loss_shape)

        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_shape=losses_shape,
            loss_loc=losses_loc)

    @force_fp32(
        apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   shape_preds,
                   loc_preds,
                   img_metas,
                   cfg,
                   rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(shape_preds) == len(
            loc_preds)
        num_levels = len(cls_scores)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        device = cls_scores[0].device
        # get guided anchors
        _, guided_anchors, loc_masks = self.get_anchors(
            featmap_sizes,
            shape_preds,
            loc_preds,
            img_metas,
            use_loc_filter=not self.training,
            device=device)
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            guided_anchor_list = [
                guided_anchors[img_id][i].detach() for i in range(num_levels)
            ]
            loc_mask_list = [
                loc_masks[img_id][i].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
                                               guided_anchor_list,
                                               loc_mask_list, img_shape,
                                               scale_factor, cfg, rescale)
            result_list.append(proposals)
        return result_list

    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          mlvl_masks,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        mlvl_bboxes = []
        mlvl_scores = []
        for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds,
                                                       mlvl_anchors,
                                                       mlvl_masks):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            # if no location is kept, end.
            if mask.sum() == 0:
                continue
            # reshape scores and bbox_pred
            cls_score = cls_score.permute(
                1, 2, 0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            # filter scores, bbox_pred w.r.t. mask.
            # anchors are filtered in get_anchors() beforehand.
            scores = scores[mask, :]
            bbox_pred = bbox_pred[mask, :]
            if scores.dim() == 0:
                anchors = anchors.unsqueeze(0)
                scores = scores.unsqueeze(0)
                bbox_pred = bbox_pred.unsqueeze(0)
            # filter anchors, bbox_pred, scores w.r.t. scores
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
                    max_scores, _ = scores[:, 1:].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
            bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
                                self.target_stds, img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        if self.use_sigmoid_cls:
            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
            mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        # multi class NMS
        det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                cfg.score_thr, cfg.nms,
                                                cfg.max_per_img)
        return det_bboxes, det_labels
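
# Illustrative sketch, not part of the original file: a standalone rehearsal
# of the location filtering used in _get_guided_anchors_single, with made-up
# shapes (one anchor per position, a 2x2 feature map). Positions whose
# objectness sigmoid clears the threshold survive; the per-position mask is
# broadcast over anchors and flattened, exactly as above.
if __name__ == '__main__':
    loc_logits = torch.tensor([[[2.0, -2.0], [0.5, -0.5]]])  # (1, H, W)
    loc_filter_thr, num_anchors = 0.5, 1
    loc_mask = loc_logits.sigmoid() >= loc_filter_thr
    mask = loc_mask.permute(1, 2, 0).expand(-1, -1, num_anchors)
    print(mask.contiguous().view(-1))  # tensor([True, False, True, False])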
mmdet/models/anchor_heads/reppoints_head.py
0 → 100644
View file @
57f6da5c
from __future__ import division

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from mmdet.core import (PointGenerator, multi_apply, multiclass_nms,
                        point_target)
from mmdet.ops import DeformConv
from ..builder import build_loss
from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob


@HEADS.register_module
class RepPointsHead(nn.Module):
    """RepPoint head.

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels of the feature map.
        point_feat_channels (int): Number of channels of points features.
        stacked_convs (int): How many conv layers are used.
        gradient_mul (float): The multiplier to gradients from
            points refinement and recognition.
        point_strides (Iterable): points strides.
        point_base_scale (int): bbox scale for assigning labels.
        loss_cls (dict): Config of classification loss.
        loss_bbox_init (dict): Config of initial points loss.
        loss_bbox_refine (dict): Config of points loss in refinement.
        use_grid_points (bool): If we use bounding box representation, the
            reppoints is represented as grid points on the bounding box.
        center_init (bool): Whether to use center point assignment.
        transform_method (str): The methods to transform RepPoints to bbox.
    """  # noqa: W605

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 point_feat_channels=256,
                 stacked_convs=3,
                 num_points=9,
                 gradient_mul=0.1,
                 point_strides=[8, 16, 32, 64, 128],
                 point_base_scale=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox_init=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
                 loss_bbox_refine=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 use_grid_points=False,
                 center_init=True,
                 transform_method='moment',
                 moment_mul=0.01):
        super(RepPointsHead, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.point_feat_channels = point_feat_channels
        self.stacked_convs = stacked_convs
        self.num_points = num_points
        self.gradient_mul = gradient_mul
        self.point_base_scale = point_base_scale
        self.point_strides = point_strides
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        self.sampling = loss_cls['type'] not in ['FocalLoss']
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox_init = build_loss(loss_bbox_init)
        self.loss_bbox_refine = build_loss(loss_bbox_refine)
        self.use_grid_points = use_grid_points
        self.center_init = center_init
        self.transform_method = transform_method
        if self.transform_method == 'moment':
            self.moment_transfer = nn.Parameter(
                data=torch.zeros(2), requires_grad=True)
            self.moment_mul = moment_mul
        if self.use_sigmoid_cls:
            self.cls_out_channels = self.num_classes - 1
        else:
            self.cls_out_channels = self.num_classes
        self.point_generators = [PointGenerator() for _ in self.point_strides]
        # we use deformable conv to extract points features
        self.dcn_kernel = int(np.sqrt(num_points))
        self.dcn_pad = int((self.dcn_kernel - 1) / 2)
        assert self.dcn_kernel * self.dcn_kernel == num_points, \
            "The points number should be a square number."
        assert self.dcn_kernel % 2 == 1, \
            "The points number should be an odd square number."
        dcn_base = np.arange(-self.dcn_pad,
                             self.dcn_pad + 1).astype(np.float64)
        dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
        dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
            (-1))
        self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
        self._init_layers()

    def _init_layers(self):
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
        self.reppoints_cls_conv = DeformConv(self.feat_channels,
                                             self.point_feat_channels,
                                             self.dcn_kernel, 1, self.dcn_pad)
        self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
                                           self.cls_out_channels, 1, 1, 0)
        self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
                                                 self.point_feat_channels, 3,
                                                 1, 1)
        self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
                                                pts_out_dim, 1, 1, 0)
        self.reppoints_pts_refine_conv = DeformConv(self.feat_channels,
                                                    self.point_feat_channels,
                                                    self.dcn_kernel, 1,
                                                    self.dcn_pad)
        self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
                                                  pts_out_dim, 1, 1, 0)

    def init_weights(self):
        for m in self.cls_convs:
            normal_init(m.conv, std=0.01)
        for m in self.reg_convs:
            normal_init(m.conv, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.reppoints_cls_conv, std=0.01)
        normal_init(self.reppoints_cls_out, std=0.01, bias=bias_cls)
        normal_init(self.reppoints_pts_init_conv, std=0.01)
        normal_init(self.reppoints_pts_init_out, std=0.01)
        normal_init(self.reppoints_pts_refine_conv, std=0.01)
        normal_init(self.reppoints_pts_refine_out, std=0.01)

    def points2bbox(self, pts, y_first=True):
        """Convert the points set into bounding box.

        :param pts: the input points sets (fields), each points
            set (fields) is represented as 2n scalar.
        :param y_first: if y_first=True, the point set is represented as
            [y1, x1, y2, x2 ... yn, xn], otherwise the point set is
            represented as [x1, y1, x2, y2 ... xn, yn].
        :return: each points set is converted to a bbox [x1, y1, x2, y2].
        """
        pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
        pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1,
                                                                      ...]
        pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0,
                                                                      ...]
        if self.transform_method == 'minmax':
            bbox_left = pts_x.min(dim=1, keepdim=True)[0]
            bbox_right = pts_x.max(dim=1, keepdim=True)[0]
            bbox_up = pts_y.min(dim=1, keepdim=True)[0]
            bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
            bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
                             dim=1)
        elif self.transform_method == 'partial_minmax':
            pts_y = pts_y[:, :4, ...]
            pts_x = pts_x[:, :4, ...]
            bbox_left = pts_x.min(dim=1, keepdim=True)[0]
            bbox_right = pts_x.max(dim=1, keepdim=True)[0]
            bbox_up = pts_y.min(dim=1, keepdim=True)[0]
            bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
            bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
                             dim=1)
        elif self.transform_method == 'moment':
            pts_y_mean = pts_y.mean(dim=1, keepdim=True)
            pts_x_mean = pts_x.mean(dim=1, keepdim=True)
            pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
            pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
            moment_transfer = (self.moment_transfer * self.moment_mul) + (
                self.moment_transfer.detach() * (1 - self.moment_mul))
            moment_width_transfer = moment_transfer[0]
            moment_height_transfer = moment_transfer[1]
            half_width = pts_x_std * torch.exp(moment_width_transfer)
            half_height = pts_y_std * torch.exp(moment_height_transfer)
            bbox = torch.cat([
                pts_x_mean - half_width, pts_y_mean - half_height,
                pts_x_mean + half_width, pts_y_mean + half_height
            ],
                             dim=1)
        else:
            raise NotImplementedError
        return bbox
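
    # Illustrative note, not part of the original file: in the 'moment'
    # transform above, the box is centered on the points' mean with
    # half-extents std * exp(moment_transfer); the moment_mul blend keeps the
    # learned moment_transfer parameter updating with only a small fraction
    # (moment_mul) of its gradient on each step.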
    def gen_grid_from_reg(self, reg, previous_boxes):
        """Based on the previous bboxes and regression values, we compute the
        regressed bboxes and generate the grids on the bboxes.

        :param reg: the regression value to previous bboxes.
        :param previous_boxes: previous bboxes.
        :return: generate grids on the regressed bboxes.
        """
        b, _, h, w = reg.shape
        bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
        bwh = (previous_boxes[:, 2:, ...] -
               previous_boxes[:, :2, ...]).clamp(min=1e-6)
        grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
            reg[:, 2:, ...])
        grid_wh = bwh * torch.exp(reg[:, 2:, ...])
        grid_left = grid_topleft[:, [0], ...]
        grid_top = grid_topleft[:, [1], ...]
        grid_width = grid_wh[:, [0], ...]
        grid_height = grid_wh[:, [1], ...]
        intervel = torch.linspace(0., 1., self.dcn_kernel).view(
            1, self.dcn_kernel, 1, 1).type_as(reg)
        grid_x = grid_left + grid_width * intervel
        grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
        grid_x = grid_x.view(b, -1, h, w)
        grid_y = grid_top + grid_height * intervel
        grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
        grid_y = grid_y.view(b, -1, h, w)
        grid_yx = torch.stack([grid_y, grid_x], dim=2)
        grid_yx = grid_yx.view(b, -1, h, w)
        regressed_bbox = torch.cat([
            grid_left, grid_top, grid_left + grid_width,
            grid_top + grid_height
        ], 1)
        return grid_yx, regressed_bbox

    def forward_single(self, x):
        dcn_base_offset = self.dcn_base_offset.type_as(x)
        # If we use center_init, the initial reppoints is from center points.
        # If we use bounding bbox representation, the initial reppoints is
        # from regular grid placed on a pre-defined bbox.
        if self.use_grid_points or not self.center_init:
            scale = self.point_base_scale / 2
            points_init = dcn_base_offset / dcn_base_offset.max() * scale
            bbox_init = x.new_tensor([-scale, -scale, scale,
                                      scale]).view(1, 4, 1, 1)
        else:
            points_init = 0
        cls_feat = x
        pts_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            pts_feat = reg_conv(pts_feat)
        # initialize reppoints
        pts_out_init = self.reppoints_pts_init_out(
            self.relu(self.reppoints_pts_init_conv(pts_feat)))
        if self.use_grid_points:
            pts_out_init, bbox_out_init = self.gen_grid_from_reg(
                pts_out_init, bbox_init.detach())
        else:
            pts_out_init = pts_out_init + points_init
        # refine and classify reppoints
        pts_out_init_grad_mul = (
            1 - self.gradient_mul) * pts_out_init.detach(
            ) + self.gradient_mul * pts_out_init
        dcn_offset = pts_out_init_grad_mul - dcn_base_offset
        cls_out = self.reppoints_cls_out(
            self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
        pts_out_refine = self.reppoints_pts_refine_out(
            self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
        if self.use_grid_points:
            pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(
                pts_out_refine, bbox_out_init.detach())
        else:
            pts_out_refine = pts_out_refine + pts_out_init.detach()
        return cls_out, pts_out_init, pts_out_refine

    def forward(self, feats):
        return multi_apply(self.forward_single, feats)

    def get_points(self, featmap_sizes, img_metas):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.

        Returns:
            tuple: points of each image, valid flags of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # points center for one time
        multi_level_points = []
        for i in range(num_levels):
            points = self.point_generators[i].grid_points(
                featmap_sizes[i], self.point_strides[i])
            multi_level_points.append(points)
        points_list = [[point.clone() for point in multi_level_points]
                       for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level grids
        valid_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = []
            for i in range(num_levels):
                point_stride = self.point_strides[i]
                feat_h, feat_w = featmap_sizes[i]
                h, w = img_meta['pad_shape'][:2]
                valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w)
                flags = self.point_generators[i].valid_flags(
                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                multi_level_flags.append(flags)
            valid_flag_list.append(multi_level_flags)

        return points_list, valid_flag_list

    def centers_to_bboxes(self, point_list):
        """Get bboxes according to center points. Only used in MaxIOUAssigner.
        """
        bbox_list = []
        for i_img, point in enumerate(point_list):
            bbox = []
            for i_lvl in range(len(self.point_strides)):
                scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
                bbox_shift = torch.Tensor(
                    [-scale, -scale, scale, scale]).view(1, 4).type_as(
                        point[0])
                bbox_center = torch.cat(
                    [point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
                bbox.append(bbox_center + bbox_shift)
            bbox_list.append(bbox)
        return bbox_list

    def offset_to_pts(self, center_list, pred_list):
        """Change from point offset to point coordinate.
        """
        pts_list = []
        for i_lvl in range(len(self.point_strides)):
            pts_lvl = []
            for i_img in range(len(center_list)):
                pts_center = center_list[i_img][i_lvl][:, :2].repeat(
                    1, self.num_points)
                pts_shift = pred_list[i_lvl][i_img]
                yx_pts_shift = pts_shift.permute(1, 2, 0).view(
                    -1, 2 * self.num_points)
                y_pts_shift = yx_pts_shift[..., 0::2]
                x_pts_shift = yx_pts_shift[..., 1::2]
                xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
                xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
                pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
                pts_lvl.append(pts)
            pts_lvl = torch.stack(pts_lvl, 0)
            pts_list.append(pts_lvl)
        return pts_list

    def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels,
                    label_weights, bbox_gt_init, bbox_weights_init,
                    bbox_gt_refine, bbox_weights_refine, stride,
                    num_total_samples_init, num_total_samples_refine):
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score,
            labels,
            label_weights,
            avg_factor=num_total_samples_refine)

        # points loss
        bbox_gt_init = bbox_gt_init.reshape(-1, 4)
        bbox_weights_init = bbox_weights_init.reshape(-1, 4)
        bbox_pred_init = self.points2bbox(
            pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
        bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
        bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
        bbox_pred_refine = self.points2bbox(
            pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
        normalize_term = self.point_base_scale * stride
        loss_pts_init = self.loss_bbox_init(
            bbox_pred_init / normalize_term,
            bbox_gt_init / normalize_term,
            bbox_weights_init,
            avg_factor=num_total_samples_init)
        loss_pts_refine = self.loss_bbox_refine(
            bbox_pred_refine / normalize_term,
            bbox_gt_refine / normalize_term,
            bbox_weights_refine,
            avg_factor=num_total_samples_refine)
        return loss_cls, loss_pts_init, loss_pts_refine

    def loss(self,
             cls_scores,
             pts_preds_init,
             pts_preds_refine,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.point_generators)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        # target for initial stage
        center_list, valid_flag_list = self.get_points(featmap_sizes,
                                                       img_metas)
        pts_coordinate_preds_init = self.offset_to_pts(center_list,
                                                       pts_preds_init)
        if cfg.init.assigner['type'] == 'PointAssigner':
            # Assign target for center list
            candidate_list = center_list
        else:
            # transform center list to bbox list and
            # assign target for bbox list
            bbox_list = self.centers_to_bboxes(center_list)
            candidate_list = bbox_list
        cls_reg_targets_init = point_target(
            candidate_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            cfg.init,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling)
        (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
         num_total_pos_init, num_total_neg_init) = cls_reg_targets_init
        num_total_samples_init = (
            num_total_pos_init + num_total_neg_init
            if self.sampling else num_total_pos_init)

        # target for refinement stage
        center_list, valid_flag_list = self.get_points(featmap_sizes,
                                                       img_metas)
        pts_coordinate_preds_refine = self.offset_to_pts(
            center_list, pts_preds_refine)
        bbox_list = []
        for i_img, center in enumerate(center_list):
            bbox = []
            for i_lvl in range(len(pts_preds_refine)):
                bbox_preds_init = self.points2bbox(
                    pts_preds_init[i_lvl].detach())
                bbox_shift = bbox_preds_init * self.point_strides[i_lvl]
                bbox_center = torch.cat(
                    [center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1)
                bbox.append(bbox_center +
                            bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
            bbox_list.append(bbox)
        cls_reg_targets_refine = point_target(
            bbox_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            cfg.refine,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling)
        (labels_list, label_weights_list, bbox_gt_list_refine,
         candidate_list_refine, bbox_weights_list_refine,
         num_total_pos_refine, num_total_neg_refine) = cls_reg_targets_refine
        num_total_samples_refine = (
            num_total_pos_refine + num_total_neg_refine
            if self.sampling else num_total_pos_refine)

        # compute loss
        losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
            self.loss_single,
            cls_scores,
            pts_coordinate_preds_init,
            pts_coordinate_preds_refine,
            labels_list,
            label_weights_list,
            bbox_gt_list_init,
            bbox_weights_list_init,
            bbox_gt_list_refine,
            bbox_weights_list_refine,
            self.point_strides,
            num_total_samples_init=num_total_samples_init,
            num_total_samples_refine=num_total_samples_refine)
        loss_dict_all = {
            'loss_cls': losses_cls,
            'loss_pts_init': losses_pts_init,
            'loss_pts_refine': losses_pts_refine
        }
        return loss_dict_all

    def get_bboxes(self,
                   cls_scores,
                   pts_preds_init,
                   pts_preds_refine,
                   img_metas,
                   cfg,
                   rescale=False,
                   nms=True):
        assert len(cls_scores) == len(pts_preds_refine)
        bbox_preds_refine = [
            self.points2bbox(pts_pred_refine)
            for pts_pred_refine in pts_preds_refine
        ]
        num_levels = len(cls_scores)
        mlvl_points = [
            self.point_generators[i].grid_points(cls_scores[i].size()[-2:],
                                                 self.point_strides[i])
            for i in range(num_levels)
        ]
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds_refine[i][img_id].detach()
                for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
                                               mlvl_points, img_shape,
                                               scale_factor, cfg, rescale,
                                               nms)
            result_list.append(proposals)
        return result_list

    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_points,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False,
                          nms=True):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        for i_lvl, (cls_score, bbox_pred, points) in enumerate(
                zip(cls_scores, bbox_preds, mlvl_points)):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            cls_score = cls_score.permute(
                1, 2, 0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
                    max_scores, _ = scores[:, 1:].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
            bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
            bboxes = bbox_pred * self.point_strides[i_lvl] + bbox_pos_center
            x1 = bboxes[:, 0].clamp(min=0, max=img_shape[1])
            y1 = bboxes[:, 1].clamp(min=0, max=img_shape[0])
            x2 = bboxes[:, 2].clamp(min=0, max=img_shape[1])
            y2 = bboxes[:, 3].clamp(min=0, max=img_shape[0])
            bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        if self.use_sigmoid_cls:
            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
            mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        if nms:
            det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                    cfg.score_thr, cfg.nms,
                                                    cfg.max_per_img)
            return det_bboxes, det_labels
        else:
            return mlvl_bboxes, mlvl_scores
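
# Illustrative sketch, not part of the original file: the 'minmax' transform
# in points2bbox reduces each point set to the extremes of its coordinates.
# A standalone rehearsal on one made-up set of nine (x, y) points, laid out
# [x1, y1, x2, y2, ...] as in the y_first=False path:
if __name__ == '__main__':
    pts = torch.tensor([[1., 2., 4., 0., 3., 5., 2., 2., 0., 1., 5., 4., 1.,
                         3., 2., 0., 4., 4.]])
    xs, ys = pts[:, 0::2], pts[:, 1::2]
    bbox = torch.cat([
        xs.min(1, keepdim=True)[0], ys.min(1, keepdim=True)[0],
        xs.max(1, keepdim=True)[0], ys.max(1, keepdim=True)[0]
    ], 1)
    print(bbox)  # tensor([[0., 0., 5., 5.]])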
mmdet/models/anchor_heads/retina_head.py
0 → 100644
View file @
57f6da5c
import numpy as np
import torch.nn as nn
from mmcv.cnn import normal_init

from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob
from .anchor_head import AnchorHead


@HEADS.register_module
class RetinaHead(AnchorHead):
    """
    An anchor-based head used in [1]_.

    The head contains two subnetworks. The first classifies anchor boxes and
    the second regresses deltas for the anchors.

    References:
        .. [1] https://arxiv.org/pdf/1708.02002.pdf

    Example:
        >>> import torch
        >>> self = RetinaHead(11, 7)
        >>> x = torch.rand(1, 7, 32, 32)
        >>> cls_score, bbox_pred = self.forward_single(x)
        >>> # Each anchor predicts a score for each class except background
        >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
        >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
        >>> assert cls_per_anchor == (self.num_classes - 1)
        >>> assert box_per_anchor == 4
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 octave_base_scale=4,
                 scales_per_octave=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 **kwargs):
        self.stacked_convs = stacked_convs
        self.octave_base_scale = octave_base_scale
        self.scales_per_octave = scales_per_octave
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        octave_scales = np.array(
            [2**(i / scales_per_octave) for i in range(scales_per_octave)])
        anchor_scales = octave_scales * octave_base_scale
        super(RetinaHead, self).__init__(
            num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)

    def _init_layers(self):
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        self.retina_cls = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        self.retina_reg = nn.Conv2d(
            self.feat_channels, self.num_anchors * 4, 3, padding=1)

    def init_weights(self):
        for m in self.cls_convs:
            normal_init(m.conv, std=0.01)
        for m in self.reg_convs:
            normal_init(m.conv, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
        normal_init(self.retina_reg, std=0.01)

    def forward_single(self, x):
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.retina_cls(cls_feat)
        bbox_pred = self.retina_reg(reg_feat)
        return cls_score, bbox_pred
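
# Illustrative sketch, not part of the original file: with the defaults
# octave_base_scale=4 and scales_per_octave=3, the anchor scales computed in
# __init__ are 4 * 2**(i / 3) for i in 0..2:
if __name__ == '__main__':
    print(np.array([2**(i / 3) for i in range(3)]) * 4)  # ~[4.0, 5.04, 6.35]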
mmdet/models/anchor_heads/retina_sepbn_head.py
0 → 100644
View file @
57f6da5c
import numpy as np
import torch.nn as nn
from mmcv.cnn import normal_init

from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob
from .anchor_head import AnchorHead


@HEADS.register_module
class RetinaSepBNHead(AnchorHead):
    """RetinaHead with separate BN.

    In RetinaHead, conv/norm layers are shared across different FPN levels,
    while in RetinaSepBNHead, conv layers are shared across different FPN
    levels, but BN layers are separated.
    """

    def __init__(self,
                 num_classes,
                 num_ins,
                 in_channels,
                 stacked_convs=4,
                 octave_base_scale=4,
                 scales_per_octave=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 **kwargs):
        self.stacked_convs = stacked_convs
        self.octave_base_scale = octave_base_scale
        self.scales_per_octave = scales_per_octave
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.num_ins = num_ins
        octave_scales = np.array(
            [2**(i / scales_per_octave) for i in range(scales_per_octave)])
        anchor_scales = octave_scales * octave_base_scale
        super(RetinaSepBNHead, self).__init__(
            num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)

    def _init_layers(self):
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.num_ins):
            cls_convs = nn.ModuleList()
            reg_convs = nn.ModuleList()
            for i in range(self.stacked_convs):
                chn = self.in_channels if i == 0 else self.feat_channels
                cls_convs.append(
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg))
                reg_convs.append(
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg))
            self.cls_convs.append(cls_convs)
            self.reg_convs.append(reg_convs)
        for i in range(self.stacked_convs):
            for j in range(1, self.num_ins):
                self.cls_convs[j][i].conv = self.cls_convs[0][i].conv
                self.reg_convs[j][i].conv = self.reg_convs[0][i].conv
        self.retina_cls = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        self.retina_reg = nn.Conv2d(
            self.feat_channels, self.num_anchors * 4, 3, padding=1)

    def init_weights(self):
        for m in self.cls_convs[0]:
            normal_init(m.conv, std=0.01)
        for m in self.reg_convs[0]:
            normal_init(m.conv, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
        normal_init(self.retina_reg, std=0.01)

    def forward(self, feats):
        cls_scores = []
        bbox_preds = []
        for i, x in enumerate(feats):
            cls_feat = feats[i]
            reg_feat = feats[i]
            for cls_conv in self.cls_convs[i]:
                cls_feat = cls_conv(cls_feat)
            for reg_conv in self.reg_convs[i]:
                reg_feat = reg_conv(reg_feat)
            cls_score = self.retina_cls(cls_feat)
            bbox_pred = self.retina_reg(reg_feat)
            cls_scores.append(cls_score)
            bbox_preds.append(bbox_pred)
        return cls_scores, bbox_preds
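
# Illustrative sketch, not part of the original file: the aliasing in
# _init_layers shares conv weights across FPN branches while each branch
# keeps its own norm layer. The same pattern in miniature with plain modules:
if __name__ == '__main__':
    branches = nn.ModuleList([
        nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8))
        for _ in range(2)
    ])
    branches[1][0] = branches[0][0]  # share the conv, keep BN separate
    assert branches[1][0].weight is branches[0][0].weight
    assert branches[1][1] is not branches[0][1]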
mmdet/models/anchor_heads/rpn_head.py
0 → 100644
View file @
57f6da5c
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init

from mmdet.core import delta2bbox
from mmdet.ops import nms
from ..registry import HEADS
from .anchor_head import AnchorHead


@HEADS.register_module
class RPNHead(AnchorHead):

    def __init__(self, in_channels, **kwargs):
        super(RPNHead, self).__init__(2, in_channels, **kwargs)

    def _init_layers(self):
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        self.rpn_cls = nn.Conv2d(self.feat_channels,
                                 self.num_anchors * self.cls_out_channels, 1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)

    def init_weights(self):
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        rpn_cls_score = self.rpn_cls(x)
        rpn_bbox_pred = self.rpn_reg(x)
        return rpn_cls_score, rpn_bbox_pred

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        losses = super(RPNHead, self).loss(
            cls_scores,
            bbox_preds,
            gt_bboxes,
            None,
            img_metas,
            cfg,
            gt_bboxes_ignore=gt_bboxes_ignore)
        return dict(
            loss_rpn_cls=losses['loss_cls'],
            loss_rpn_bbox=losses['loss_bbox'])

    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        mlvl_proposals = []
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                scores = rpn_cls_score.softmax(dim=1)[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            anchors = mlvl_anchors[idx]
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                _, topk_inds = scores.topk(cfg.nms_pre)
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
                scores = scores[topk_inds]
            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                   self.target_stds, img_shape)
            if cfg.min_bbox_size > 0:
                w = proposals[:, 2] - proposals[:, 0] + 1
                h = proposals[:, 3] - proposals[:, 1] + 1
                valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                           (h >= cfg.min_bbox_size)).squeeze()
                proposals = proposals[valid_inds, :]
                scores = scores[valid_inds]
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            scores = proposals[:, 4]
            num = min(cfg.max_num, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]
        return proposals
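
# Illustrative sketch, not part of the original file: the min_bbox_size
# filter in get_bboxes_single, rehearsed standalone. Width/height use the
# inclusive pixel convention (hence the +1), so a 10x10 box survives a
# threshold of 4 while a 2x10 box is dropped:
if __name__ == '__main__':
    proposals = torch.tensor([[0., 0., 9., 9.], [0., 0., 1., 9.]])
    w = proposals[:, 2] - proposals[:, 0] + 1
    h = proposals[:, 3] - proposals[:, 1] + 1
    valid_inds = torch.nonzero((w >= 4) & (h >= 4)).squeeze()
    print(proposals[valid_inds, :])  # tensor([0., 0., 9., 9.])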
mmdet/models/anchor_heads/solo_head.py
0 → 100644
View file @
57f6da5c
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init

from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, bbox2roi, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule

INF = 1e8


def center_of_mass(bitmasks):
    _, h, w = bitmasks.size()
    ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
    xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)

    m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
    m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
    m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
    center_x = m10 / m00
    center_y = m01 / m00
    return center_x, center_y


def points_nms(heat, kernel=2):
    # kernel must be 2
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=1)
    keep = (hmax[:, :, :-1, :-1] == heat).float()
    return heat * keep


def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return 1 - d
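
# Illustrative check, not part of the original file: for a perfect prediction
# the Dice coefficient d approaches 1 and the loss approaches 0. E.g. with
# input == target == torch.ones(1, 4): a = 4, b = c = 4.001, so
# dice_loss(input, target) == 1 - 8 / 8.002, roughly 2.5e-4 per sample; the
# 0.001 terms keep the ratio finite when both masks are empty.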
@HEADS.register_module
class SOLOHead(nn.Module):

    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256),
                               (128, 512)),
                 sigma=0.4,
                 num_grids=None,
                 cate_down_pos=0,
                 with_deform=False,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None):
        super(SOLOHead, self).__init__()
        self.num_classes = num_classes
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes - 1
        self.in_channels = in_channels
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.sigma = sigma
        self.cate_down_pos = cate_down_pos
        self.base_edge_list = base_edge_list
        self.scale_ranges = scale_ranges
        self.with_deform = with_deform
        self.loss_cate = build_loss(loss_cate)
        self.ins_loss_weight = loss_ins['loss_weight']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.ins_convs = nn.ModuleList()
        self.cate_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            self.ins_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))
            chn = self.in_channels if i == 0 else self.seg_feat_channels
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))
        self.solo_ins_list = nn.ModuleList()
        for seg_num_grid in self.seg_num_grids:
            self.solo_ins_list.append(
                nn.Conv2d(self.seg_feat_channels, seg_num_grid**2, 1))
        self.solo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)

    def init_weights(self):
        for m in self.ins_convs:
            normal_init(m.conv, std=0.01)
        for m in self.cate_convs:
            normal_init(m.conv, std=0.01)
        bias_ins = bias_init_with_prob(0.01)
        for m in self.solo_ins_list:
            normal_init(m, std=0.01, bias=bias_ins)
        bias_cate = bias_init_with_prob(0.01)
        normal_init(self.solo_cate, std=0.01, bias=bias_cate)

    def forward(self, feats, eval=False):
        new_feats = self.split_feats(feats)
        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
        ins_pred, cate_pred = multi_apply(
            self.forward_single,
            new_feats,
            list(range(len(self.seg_num_grids))),
            eval=eval,
            upsampled_size=upsampled_size)
        return ins_pred, cate_pred

    def split_feats(self, feats):
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
                feats[1], feats[2], feats[3],
                F.interpolate(
                    feats[4], size=feats[3].shape[-2:], mode='bilinear'))

    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        ins_feat = x
        cate_feat = x
        # ins branch
        # concat coord
        x_range = torch.linspace(
            -1, 1, ins_feat.shape[-1], device=ins_feat.device)
        y_range = torch.linspace(
            -1, 1, ins_feat.shape[-2], device=ins_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_feat.shape[0], 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        ins_feat = torch.cat([ins_feat, coord_feat], 1)

        for i, ins_layer in enumerate(self.ins_convs):
            ins_feat = ins_layer(ins_feat)

        ins_feat = F.interpolate(ins_feat, scale_factor=2, mode='bilinear')
        ins_pred = self.solo_ins_list[idx](ins_feat)

        # cate branch
        for i, cate_layer in enumerate(self.cate_convs):
            if i == self.cate_down_pos:
                seg_num_grid = self.seg_num_grids[idx]
                cate_feat = F.interpolate(
                    cate_feat, size=seg_num_grid, mode='bilinear')
            cate_feat = cate_layer(cate_feat)

        cate_pred = self.solo_cate(cate_feat)

        if eval:
            ins_pred = F.interpolate(
                ins_pred.sigmoid(), size=upsampled_size, mode='bilinear')
            cate_pred = points_nms(
                cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return ins_pred, cate_pred
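
    # Illustrative note, not part of the original file: the two extra
    # channels concatenated in forward_single are CoordConv-style x/y ramps
    # normalized to [-1, 1], giving the otherwise translation-invariant
    # instance branch access to absolute position, which SOLO's
    # per-grid-cell masks require.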
    def loss(self,
             ins_preds,
             cate_preds,
             gt_bbox_list,
             gt_label_list,
             gt_mask_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in ins_preds]
        ins_label_list, cate_label_list, ins_ind_label_list = multi_apply(
            self.solo_target_single,
            gt_bbox_list,
            gt_label_list,
            gt_mask_list,
            featmap_sizes=featmap_sizes)

        # ins
        ins_labels = [
            torch.cat([
                ins_labels_level_img[ins_ind_labels_level_img, ...]
                for ins_labels_level_img, ins_ind_labels_level_img in zip(
                    ins_labels_level, ins_ind_labels_level)
            ], 0)
            for ins_labels_level, ins_ind_labels_level in zip(
                zip(*ins_label_list), zip(*ins_ind_label_list))
        ]

        ins_preds = [
            torch.cat([
                ins_preds_level_img[ins_ind_labels_level_img, ...]
                for ins_preds_level_img, ins_ind_labels_level_img in zip(
                    ins_preds_level, ins_ind_labels_level)
            ], 0)
            for ins_preds_level, ins_ind_labels_level in zip(
                ins_preds, zip(*ins_ind_label_list))
        ]

        ins_ind_labels = [
            torch.cat([
                ins_ind_labels_level_img.flatten()
                for ins_ind_labels_level_img in ins_ind_labels_level
            ])
            for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        flatten_ins_ind_labels = torch.cat(ins_ind_labels)

        num_ins = flatten_ins_ind_labels.sum()

        # dice loss
        loss_ins = []
        for input, target in zip(ins_preds, ins_labels):
            if input.size()[0] == 0:
                continue
            input = torch.sigmoid(input)
            loss_ins.append(dice_loss(input, target))
        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.ins_loss_weight

        # cate
        cate_labels = [
            torch.cat([
                cate_labels_level_img.flatten()
                for cate_labels_level_img in cate_labels_level
            ])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)

        loss_cate = self.loss_cate(
            flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
        return dict(loss_ins=loss_ins, loss_cate=loss_cate)

    def solo_target_single(self,
                           gt_bboxes_raw,
                           gt_labels_raw,
                           gt_masks_raw,
                           featmap_sizes=None):
        device = gt_labels_raw[0].device

        # ins
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
                              (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides, featmap_sizes,
                       self.seg_num_grids):

            ins_label = torch.zeros(
                [num_grid**2, featmap_size[0], featmap_size[1]],
                dtype=torch.uint8,
                device=device)
            cate_label = torch.zeros([num_grid, num_grid],
                                     dtype=torch.int64,
                                     device=device)
            ins_ind_label = torch.zeros([num_grid**2],
                                        dtype=torch.bool,
                                        device=device)

            hit_indices = ((gt_areas >= lower_bound) &
                           (gt_areas <= upper_bound)).nonzero().flatten()
            if len(hit_indices) == 0:
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                continue
            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]

            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

            # mass center
            gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
            center_ws, center_hs = center_of_mass(gt_masks_pt)
            valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0

            output_stride = stride / 2

            for seg_mask, gt_label, half_h, half_w, center_h, center_w, \
                    valid_mask_flag in zip(gt_masks, gt_labels, half_hs,
                                           half_ws, center_hs, center_ws,
                                           valid_mask_flags):
                if not valid_mask_flag:
                    continue
                upsampled_size = (featmap_sizes[0][0] * 4,
                                  featmap_sizes[0][1] * 4)
                coord_w = int(
                    (center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int(
                    (center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down
                top_box = max(
                    0,
                    int(((center_h - half_h) / upsampled_size[0]) //
                        (1. / num_grid)))
                down_box = min(
                    num_grid - 1,
                    int(((center_h + half_h) / upsampled_size[0]) //
                        (1. / num_grid)))
                left_box = max(
                    0,
                    int(((center_w - half_w) / upsampled_size[1]) //
                        (1. / num_grid)))
                right_box = min(
                    num_grid - 1,
                    int(((center_w + half_w) / upsampled_size[1]) //
                        (1. / num_grid)))

                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                cate_label[top:(down + 1), left:(right + 1)] = gt_label
                # ins
                seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
                seg_mask = torch.from_numpy(seg_mask).to(device=device)
                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        label = int(i * num_grid + j)
                        ins_label[label, :seg_mask.shape[0], :seg_mask.
                                  shape[1]] = seg_mask
                        ins_ind_label[label] = True
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)
        return ins_label_list, cate_label_list, ins_ind_label_list

    def get_seg(self, seg_preds, cate_preds, img_metas, cfg, rescale=None):
        assert len(seg_preds) == len(cate_preds)
        num_levels = len(cate_preds)
        featmap_size = seg_preds[0].size()[-2:]

        result_list = []
        for img_id in range(len(img_metas)):
            cate_pred_list = [
                cate_preds[i][img_id].view(
                    -1, self.cate_out_channels).detach()
                for i in range(num_levels)
            ]
            seg_pred_list = [
                seg_preds[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            ori_shape = img_metas[img_id]['ori_shape']

            cate_pred_list = torch.cat(cate_pred_list, dim=0)
            seg_pred_list = torch.cat(seg_pred_list, dim=0)

            result = self.get_seg_single(cate_pred_list, seg_pred_list,
                                         featmap_size, img_shape, ori_shape,
                                         scale_factor, cfg, rescale)
            result_list.append(result)
        return result_list

    def get_seg_single(self,
                       cate_preds,
                       seg_preds,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False,
                       debug=False):
        assert len(cate_preds) == len(seg_preds)

        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        # process.
        inds = (cate_preds > cfg.score_thr)
        # category scores.
        cate_scores = cate_preds[inds]
        if len(cate_scores) == 0:
            return None
        # category labels.
        inds = inds.nonzero()
        cate_labels = inds[:, 1]

        # strides.
        size_trans = cate_labels.new_tensor(
            self.seg_num_grids).pow(2).cumsum(0)
        strides = cate_scores.new_ones(size_trans[-1])
        n_stage = len(self.seg_num_grids)
        strides[:size_trans[0]] *= self.strides[0]
        for ind_ in range(1, n_stage):
            strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.strides[
                ind_]
        strides = strides[inds[:, 0]]

        # masks.
        seg_preds = seg_preds[inds[:, 0]]
        seg_masks = seg_preds > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()

        # filter.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return None

        seg_masks = seg_masks[keep, ...]
        seg_preds = seg_preds[keep, ...]
        sum_masks = sum_masks[keep]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # maskness.
        seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_scores

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(
            seg_masks,
            cate_labels,
            cate_scores,
            kernel=cfg.kernel,
            sigma=cfg.sigma,
            sum_masks=sum_masks)

        # filter.
        keep = cate_scores >= cfg.update_thr
        if keep.sum() == 0:
            return None
        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        seg_preds = F.interpolate(
            seg_preds.unsqueeze(0),
            size=upsampled_size_out,
            mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(
            seg_preds, size=ori_shape[:2], mode='bilinear').squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores
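
# Illustrative sketch, not part of the original file: the "maskness" step in
# get_seg_single multiplies each category score by the mean soft-mask value
# inside its binarized mask, demoting low-confidence masks. Rehearsed on one
# made-up 2x2 soft mask with mask_thr = 0.5:
if __name__ == '__main__':
    seg_preds = torch.tensor([[[0.9, 0.6], [0.4, 0.1]]])  # sigmoid outputs
    seg_masks = seg_preds > 0.5
    sum_masks = seg_masks.sum((1, 2)).float()
    seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
    print(seg_scores)  # tensor([0.7500]) -- scales the category score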
mmdet/models/anchor_heads/solov2_head.py
0 → 100644
View file @
57f6da5c
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init

from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule

INF = 1e8


def center_of_mass(bitmasks):
    _, h, w = bitmasks.size()
    ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
    xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)

    m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
    m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
    m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
    center_x = m10 / m00
    center_y = m01 / m00
    return center_x, center_y


def points_nms(heat, kernel=2):
    # kernel must be 2
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=1)
    keep = (hmax[:, :, :-1, :-1] == heat).float()
    return heat * keep
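
# Illustrative note, not part of the original file: with kernel=2, stride=1
# and padding=1 the pooled map has shape (H + 1, W + 1); the [:, :, :-1, :-1]
# slice aligns it so each cell is compared with the max over the 2x2 window
# ending at that cell, zeroing every cell that is not a local peak of the
# category heatmap.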
def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return 1 - d
@
HEADS
.
register_module
class
SOLOv2Head
(
nn
.
Module
):
def
__init__
(
self
,
num_classes
,
in_channels
,
seg_feat_channels
=
256
,
stacked_convs
=
4
,
strides
=
(
4
,
8
,
16
,
32
,
64
),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.2,
                 num_grids=None,
                 ins_out_channels=64,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 use_dcn_in_tower=False,
                 type_dcn=None):
        super(SOLOv2Head, self).__init__()
        self.num_classes = num_classes
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes - 1
        self.ins_out_channels = ins_out_channels
        self.in_channels = in_channels
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.sigma = sigma
        self.kernel_out_channels = self.ins_out_channels * 1 * 1
        self.base_edge_list = base_edge_list
        self.scale_ranges = scale_ranges
        self.loss_cate = build_loss(loss_cate)
        self.ins_loss_weight = loss_ins['loss_weight']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.use_dcn_in_tower = use_dcn_in_tower
        self.type_dcn = type_dcn
        self._init_layers()

    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.cate_convs = nn.ModuleList()
        self.kernel_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            if self.use_dcn_in_tower:
                cfg_conv = dict(type=self.type_dcn)
            else:
                cfg_conv = self.conv_cfg

            # the kernel branch input carries two extra coordinate channels
            chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            self.kernel_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=cfg_conv,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

            chn = self.in_channels if i == 0 else self.seg_feat_channels
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=cfg_conv,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

        self.solo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
        self.solo_kernel = nn.Conv2d(
            self.seg_feat_channels, self.kernel_out_channels, 3, padding=1)

    def init_weights(self):
        for m in self.cate_convs:
            normal_init(m.conv, std=0.01)
        for m in self.kernel_convs:
            normal_init(m.conv, std=0.01)
        bias_cate = bias_init_with_prob(0.01)
        normal_init(self.solo_cate, std=0.01, bias=bias_cate)
        normal_init(self.solo_kernel, std=0.01)

    def forward(self, feats, eval=False):
        new_feats = self.split_feats(feats)
        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
        cate_pred, kernel_pred = multi_apply(
            self.forward_single,
            new_feats,
            list(range(len(self.seg_num_grids))),
            eval=eval,
            upsampled_size=upsampled_size)
        return cate_pred, kernel_pred

    def split_feats(self, feats):
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
                feats[1],
                feats[2],
                feats[3],
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))

    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        ins_kernel_feat = x
        # ins branch
        # concat coord
        x_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-1], device=ins_kernel_feat.device)
        y_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-2], device=ins_kernel_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_kernel_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_kernel_feat.shape[0], 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        ins_kernel_feat = torch.cat([ins_kernel_feat, coord_feat], 1)

        # kernel branch
        kernel_feat = ins_kernel_feat
        seg_num_grid = self.seg_num_grids[idx]
        kernel_feat = F.interpolate(kernel_feat, size=seg_num_grid, mode='bilinear')
        cate_feat = kernel_feat[:, :-2, :, :]

        kernel_feat = kernel_feat.contiguous()
        for i, kernel_layer in enumerate(self.kernel_convs):
            kernel_feat = kernel_layer(kernel_feat)
        kernel_pred = self.solo_kernel(kernel_feat)

        # cate branch
        cate_feat = cate_feat.contiguous()
        for i, cate_layer in enumerate(self.cate_convs):
            cate_feat = cate_layer(cate_feat)
        cate_pred = self.solo_cate(cate_feat)

        if eval:
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return cate_pred, kernel_pred
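
The kernel branch above emits, for every grid cell, a vector that is later reshaped into a 1x1 convolution filter and applied to the shared mask feature. A minimal sketch of that dynamic convolution, with toy shapes (the tensor sizes below are illustrative, not taken from any config):

# --- illustrative sketch, not part of solov2_head.py ---
import torch
import torch.nn.functional as F

E = 64                                  # ins_out_channels of the mask feature
mask_feat = torch.rand(1, E, 100, 152)  # unified mask feature for one image
kernels = torch.rand(7, E)              # 7 predicted kernels, one per positive grid cell

# Each kernel row becomes one 1x1 conv filter; a single F.conv2d call yields
# one mask logit map per instance, exactly as done in `loss` and `get_seg_single`.
masks = F.conv2d(mask_feat, kernels.view(7, E, 1, 1), stride=1)
print(masks.shape)                      # torch.Size([1, 7, 100, 152])
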
    def loss(self,
             cate_preds,
             kernel_preds,
             ins_pred,
             gt_bbox_list,
             gt_label_list,
             gt_mask_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        mask_feat_size = ins_pred.size()[-2:]
        ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = multi_apply(
            self.solov2_target_single,
            gt_bbox_list,
            gt_label_list,
            gt_mask_list,
            mask_feat_size=mask_feat_size)

        # ins
        ins_labels = [torch.cat([ins_labels_level_img
                                 for ins_labels_level_img in ins_labels_level], 0)
                      for ins_labels_level in zip(*ins_label_list)]

        kernel_preds = [[kernel_preds_level_img.view(kernel_preds_level_img.shape[0], -1)[:, grid_orders_level_img]
                         for kernel_preds_level_img, grid_orders_level_img
                         in zip(kernel_preds_level, grid_orders_level)]
                        for kernel_preds_level, grid_orders_level
                        in zip(kernel_preds, zip(*grid_order_list))]

        # generate masks: apply each predicted kernel as a 1x1 conv on the mask feature
        ins_pred_list = []
        for b_kernel_pred in kernel_preds:
            b_mask_pred = []
            for idx, kernel_pred in enumerate(b_kernel_pred):
                if kernel_pred.size()[-1] == 0:
                    continue
                cur_ins_pred = ins_pred[idx, ...]
                H, W = cur_ins_pred.shape[-2:]
                N, I = kernel_pred.shape
                cur_ins_pred = cur_ins_pred.unsqueeze(0)
                kernel_pred = kernel_pred.permute(1, 0).view(I, -1, 1, 1)
                cur_ins_pred = F.conv2d(cur_ins_pred, kernel_pred, stride=1).view(-1, H, W)
                b_mask_pred.append(cur_ins_pred)
            if len(b_mask_pred) == 0:
                b_mask_pred = None
            else:
                b_mask_pred = torch.cat(b_mask_pred, 0)
            ins_pred_list.append(b_mask_pred)

        ins_ind_labels = [
            torch.cat([ins_ind_labels_level_img.flatten()
                       for ins_ind_labels_level_img in ins_ind_labels_level])
            for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        flatten_ins_ind_labels = torch.cat(ins_ind_labels)

        num_ins = flatten_ins_ind_labels.sum()

        # dice loss
        loss_ins = []
        for input, target in zip(ins_pred_list, ins_labels):
            if input is None:
                continue
            input = torch.sigmoid(input)
            loss_ins.append(dice_loss(input, target))
        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.ins_loss_weight

        # cate
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)

        loss_cate = self.loss_cate(
            flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
        return dict(loss_ins=loss_ins, loss_cate=loss_cate)

    def solov2_target_single(self,
                             gt_bboxes_raw,
                             gt_labels_raw,
                             gt_masks_raw,
                             mask_feat_size):

        device = gt_labels_raw[0].device

        # ins
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
                              (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        grid_order_list = []
        for (lower_bound, upper_bound), stride, num_grid \
                in zip(self.scale_ranges, self.strides, self.seg_num_grids):

            hit_indices = ((gt_areas >= lower_bound) &
                           (gt_areas <= upper_bound)).nonzero().flatten()
            num_ins = len(hit_indices)

            ins_label = []
            grid_order = []
            cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
            ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)

            if num_ins == 0:
                ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]],
                                        dtype=torch.uint8, device=device)
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                grid_order_list.append([])
                continue
            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]

            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

            # mass center
            gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
            center_ws, center_hs = center_of_mass(gt_masks_pt)
            valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0

            output_stride = 4
            for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in \
                    zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
                if not valid_mask_flag:
                    continue
                upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down
                top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
                left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))

                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                cate_label[top:(down + 1), left:(right + 1)] = gt_label
                seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
                seg_mask = torch.from_numpy(seg_mask).to(device=device)
                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        label = int(i * num_grid + j)

                        cur_ins_label = torch.zeros([mask_feat_size[0], mask_feat_size[1]],
                                                    dtype=torch.uint8, device=device)
                        cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
                        ins_label.append(cur_ins_label)
                        ins_ind_label[label] = True
                        grid_order.append(label)
            if len(ins_label) == 0:
                ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]],
                                        dtype=torch.uint8, device=device)
            else:
                ins_label = torch.stack(ins_label, 0)
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)
            grid_order_list.append(grid_order)
        return ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list
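
The grid assignment above reduces to floor(center / image_extent * num_grid), clamped to a small region around the instance's center of mass. A toy check of that arithmetic (numbers invented for illustration):

# --- illustrative sketch, toy numbers, not from any config ---
num_grid, img_h, img_w = 40, 800., 1216.
center_w, center_h = 600., 200.

coord_w = int((center_w / img_w) // (1. / num_grid))  # floor(600/1216 * 40) = 19
coord_h = int((center_h / img_h) // (1. / num_grid))  # floor(200/800  * 40) = 10
print(coord_w, coord_h)
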
    def get_seg(self, cate_preds, kernel_preds, seg_pred, img_metas, cfg, rescale=None):
        num_levels = len(cate_preds)
        featmap_size = seg_pred.size()[-2:]

        result_list = []
        for img_id in range(len(img_metas)):
            cate_pred_list = [
                cate_preds[i][img_id].view(-1, self.cate_out_channels).detach()
                for i in range(num_levels)
            ]
            seg_pred_list = seg_pred[img_id, ...].unsqueeze(0)
            kernel_pred_list = [
                kernel_preds[i][img_id].permute(1, 2, 0).view(-1, self.kernel_out_channels).detach()
                for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            ori_shape = img_metas[img_id]['ori_shape']

            cate_pred_list = torch.cat(cate_pred_list, dim=0)
            kernel_pred_list = torch.cat(kernel_pred_list, dim=0)

            result = self.get_seg_single(cate_pred_list, seg_pred_list,
                                         kernel_pred_list, featmap_size,
                                         img_shape, ori_shape, scale_factor,
                                         cfg, rescale)
            result_list.append(result)
        return result_list

    def get_seg_single(self,
                       cate_preds,
                       seg_preds,
                       kernel_preds,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False,
                       debug=False):

        assert len(cate_preds) == len(kernel_preds)

        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        # process.
        inds = (cate_preds > cfg.score_thr)
        cate_scores = cate_preds[inds]
        if len(cate_scores) == 0:
            return None

        # cate_labels & kernel_preds
        inds = inds.nonzero()
        cate_labels = inds[:, 1]
        kernel_preds = kernel_preds[inds[:, 0]]

        # trans vector.
        size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
        strides = kernel_preds.new_ones(size_trans[-1])

        n_stage = len(self.seg_num_grids)
        strides[:size_trans[0]] *= self.strides[0]
        for ind_ in range(1, n_stage):
            strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.strides[ind_]
        strides = strides[inds[:, 0]]

        # mask encoding.
        I, N = kernel_preds.shape
        kernel_preds = kernel_preds.view(I, N, 1, 1)
        seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()

        # mask.
        seg_masks = seg_preds > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()

        # filter.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return None

        seg_masks = seg_masks[keep, ...]
        seg_preds = seg_preds[keep, ...]
        sum_masks = sum_masks[keep]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # maskness.
        seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_scores

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=cfg.kernel, sigma=cfg.sigma,
                                 sum_masks=sum_masks)

        # filter.
        keep = cate_scores >= cfg.update_thr
        if keep.sum() == 0:
            return None
        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                                  size=upsampled_size_out,
                                  mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_preds,
                                  size=ori_shape[:2],
                                  mode='bilinear').squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores
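
get_seg_single relies on matrix_nms imported from mmdet.core. For readability, here is a sketch of the Gaussian-decay variant as described in the SOLOv2 paper, so the call above can be followed without opening mmdet/core; this is an illustrative reimplementation under that assumption, not the repo's own code, and all names in it are mine:

# --- illustrative sketch of Gaussian Matrix NMS (assumed behavior) ---
import torch

def matrix_nms_sketch(seg_masks, cate_labels, cate_scores, sigma=2.0):
    # Assumes scores are already sorted descending, as at the call site above.
    n = len(cate_scores)
    masks = seg_masks.reshape(n, -1).float()
    inter = masks @ masks.t()                                   # pairwise mask intersections
    areas = masks.sum(1).expand(n, n)
    iou = (inter / (areas + areas.t() - inter)).triu(diagonal=1)
    same_cls = (cate_labels.expand(n, n) ==
                cate_labels.expand(n, n).t()).float().triu(diagonal=1)
    decay_iou = iou * same_cls                                  # only same-class pairs decay
    compensate = decay_iou.max(0).values.expand(n, n).t()
    decay = (torch.exp(-sigma * decay_iou ** 2) /
             torch.exp(-sigma * compensate ** 2)).min(0).values
    return cate_scores * decay                                  # decayed scores, no hard drop

Unlike hard NMS, nothing is discarded here; overlapping same-class masks only have their scores decayed, and the update_thr filter above does the actual pruning.
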
mmdet/models/anchor_heads/solov2_light_head.py
0 → 100644
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init

from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule

INF = 1e8


def center_of_mass(bitmasks):
    _, h, w = bitmasks.size()
    ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
    xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)

    m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
    m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
    m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
    center_x = m10 / m00
    center_y = m01 / m00
    return center_x, center_y


def points_nms(heat, kernel=2):
    # kernel must be 2
    hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=1)
    keep = (hmax[:, :, :-1, :-1] == heat).float()
    return heat * keep


def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return 1 - d
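
For intuition, dice_loss approaches 0 for perfectly overlapping soft masks and 1 for disjoint ones. A toy check (run with the function above in scope; tensors invented for illustration):

# --- illustrative sketch only ---
import torch

pred = torch.tensor([[0.9, 0.9, 0.1, 0.1]])      # one flattened soft mask
gt_overlap = torch.tensor([[1., 1., 0., 0.]])
gt_disjoint = torch.tensor([[0., 0., 1., 1.]])
print(dice_loss(pred, gt_overlap))               # ~0.01
print(dice_loss(pred, gt_disjoint))              # ~0.89
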
@HEADS.register_module
class SOLOv2LightHead(nn.Module):

    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.2,
                 num_grids=None,
                 ins_out_channels=64,
                 stacked_convs=4,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 use_dcn_in_tower=False,
                 type_dcn=None):
        super(SOLOv2LightHead, self).__init__()
        self.num_classes = num_classes
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes - 1
        self.ins_out_channels = ins_out_channels
        self.in_channels = in_channels
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.sigma = sigma
        self.kernel_out_channels = self.ins_out_channels * 1 * 1
        self.base_edge_list = base_edge_list
        self.scale_ranges = scale_ranges
        self.loss_cate = build_loss(loss_cate)
        self.ins_loss_weight = loss_ins['loss_weight']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.use_dcn_in_tower = use_dcn_in_tower
        self.type_dcn = type_dcn
        self._init_layers()

    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.cate_convs = nn.ModuleList()
        self.kernel_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # unlike SOLOv2Head, the light head only uses DCN on the last tower conv
            if self.use_dcn_in_tower and i == self.stacked_convs - 1:
                cfg_conv = dict(type=self.type_dcn)
            else:
                cfg_conv = self.conv_cfg

            chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            self.kernel_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=cfg_conv,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

            chn = self.in_channels if i == 0 else self.seg_feat_channels
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=cfg_conv,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

        self.solo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
        self.solo_kernel = nn.Conv2d(
            self.seg_feat_channels, self.kernel_out_channels, 3, padding=1)

    def init_weights(self):
        for m in self.cate_convs:
            normal_init(m.conv, std=0.01)
        for m in self.kernel_convs:
            normal_init(m.conv, std=0.01)
        bias_cate = bias_init_with_prob(0.01)
        normal_init(self.solo_cate, std=0.01, bias=bias_cate)
        normal_init(self.solo_kernel, std=0.01)

    def forward(self, feats, eval=False):
        new_feats = self.split_feats(feats)
        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
        cate_pred, kernel_pred = multi_apply(
            self.forward_single,
            new_feats,
            list(range(len(self.seg_num_grids))),
            eval=eval,
            upsampled_size=upsampled_size)
        return cate_pred, kernel_pred

    def split_feats(self, feats):
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
                feats[1],
                feats[2],
                feats[3],
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))

    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        ins_kernel_feat = x
        # ins branch
        # concat coord
        x_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-1], device=ins_kernel_feat.device)
        y_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-2], device=ins_kernel_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_kernel_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_kernel_feat.shape[0], 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        ins_kernel_feat = torch.cat([ins_kernel_feat, coord_feat], 1)

        # kernel branch
        kernel_feat = ins_kernel_feat
        seg_num_grid = self.seg_num_grids[idx]
        kernel_feat = F.interpolate(kernel_feat, size=seg_num_grid, mode='bilinear')
        cate_feat = kernel_feat[:, :-2, :, :]

        kernel_feat = kernel_feat.contiguous()
        for i, kernel_layer in enumerate(self.kernel_convs):
            kernel_feat = kernel_layer(kernel_feat)
        kernel_pred = self.solo_kernel(kernel_feat)

        # cate branch
        cate_feat = cate_feat.contiguous()
        for i, cate_layer in enumerate(self.cate_convs):
            cate_feat = cate_layer(cate_feat)
        cate_pred = self.solo_cate(cate_feat)

        if eval:
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return cate_pred, kernel_pred
    def loss(self,
             cate_preds,
             kernel_preds,
             ins_pred,
             gt_bbox_list,
             gt_label_list,
             gt_mask_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        mask_feat_size = ins_pred.size()[-2:]
        ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = multi_apply(
            self.solov2_target_single,
            gt_bbox_list,
            gt_label_list,
            gt_mask_list,
            mask_feat_size=mask_feat_size)

        # ins
        ins_labels = [torch.cat([ins_labels_level_img
                                 for ins_labels_level_img in ins_labels_level], 0)
                      for ins_labels_level in zip(*ins_label_list)]

        kernel_preds = [[kernel_preds_level_img.view(kernel_preds_level_img.shape[0], -1)[:, grid_orders_level_img]
                         for kernel_preds_level_img, grid_orders_level_img
                         in zip(kernel_preds_level, grid_orders_level)]
                        for kernel_preds_level, grid_orders_level
                        in zip(kernel_preds, zip(*grid_order_list))]

        # generate masks
        ins_pred_list = []
        for b_kernel_pred in kernel_preds:
            b_mask_pred = []
            for idx, kernel_pred in enumerate(b_kernel_pred):
                if kernel_pred.size()[-1] == 0:
                    continue
                cur_ins_pred = ins_pred[idx, ...]
                H, W = cur_ins_pred.shape[-2:]
                N, I = kernel_pred.shape
                cur_ins_pred = cur_ins_pred.unsqueeze(0)
                kernel_pred = kernel_pred.permute(1, 0).view(I, -1, 1, 1)
                cur_ins_pred = F.conv2d(cur_ins_pred, kernel_pred, stride=1).view(-1, H, W)
                b_mask_pred.append(cur_ins_pred)
            if len(b_mask_pred) == 0:
                b_mask_pred = None
            else:
                b_mask_pred = torch.cat(b_mask_pred, 0)
            ins_pred_list.append(b_mask_pred)

        ins_ind_labels = [
            torch.cat([ins_ind_labels_level_img.flatten()
                       for ins_ind_labels_level_img in ins_ind_labels_level])
            for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        flatten_ins_ind_labels = torch.cat(ins_ind_labels)

        num_ins = flatten_ins_ind_labels.sum()

        # dice loss
        loss_ins = []
        for input, target in zip(ins_pred_list, ins_labels):
            if input is None:
                continue
            input = torch.sigmoid(input)
            loss_ins.append(dice_loss(input, target))
        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.ins_loss_weight

        # cate
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)

        loss_cate = self.loss_cate(
            flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
        return dict(loss_ins=loss_ins, loss_cate=loss_cate)

    def solov2_target_single(self,
                             gt_bboxes_raw,
                             gt_labels_raw,
                             gt_masks_raw,
                             mask_feat_size):

        device = gt_labels_raw[0].device

        # ins
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
                              (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        grid_order_list = []
        for (lower_bound, upper_bound), stride, num_grid \
                in zip(self.scale_ranges, self.strides, self.seg_num_grids):

            hit_indices = ((gt_areas >= lower_bound) &
                           (gt_areas <= upper_bound)).nonzero().flatten()
            num_ins = len(hit_indices)

            ins_label = []
            grid_order = []
            cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
            ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)

            if num_ins == 0:
                ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]],
                                        dtype=torch.uint8, device=device)
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                grid_order_list.append([])
                continue
            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]

            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

            # mass center
            gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
            center_ws, center_hs = center_of_mass(gt_masks_pt)
            valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0

            output_stride = 4
            for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in \
                    zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
                if not valid_mask_flag:
                    continue
                upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down
                top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
                left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))

                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                cate_label[top:(down + 1), left:(right + 1)] = gt_label
                seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
                seg_mask = torch.from_numpy(seg_mask).to(device=device)
                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        label = int(i * num_grid + j)

                        cur_ins_label = torch.zeros([mask_feat_size[0], mask_feat_size[1]],
                                                    dtype=torch.uint8, device=device)
                        cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
                        ins_label.append(cur_ins_label)
                        ins_ind_label[label] = True
                        grid_order.append(label)
            if len(ins_label) == 0:
                ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]],
                                        dtype=torch.uint8, device=device)
            else:
                ins_label = torch.stack(ins_label, 0)
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)
            grid_order_list.append(grid_order)
        return ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list
    def get_seg(self, cate_preds, kernel_preds, seg_pred, img_metas, cfg, rescale=None):
        num_levels = len(cate_preds)
        featmap_size = seg_pred.size()[-2:]

        result_list = []
        for img_id in range(len(img_metas)):
            cate_pred_list = [
                cate_preds[i][img_id].view(-1, self.cate_out_channels).detach()
                for i in range(num_levels)
            ]
            seg_pred_list = seg_pred[img_id, ...].unsqueeze(0)
            kernel_pred_list = [
                kernel_preds[i][img_id].permute(1, 2, 0).view(-1, self.kernel_out_channels).detach()
                for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            ori_shape = img_metas[img_id]['ori_shape']

            cate_pred_list = torch.cat(cate_pred_list, dim=0)
            kernel_pred_list = torch.cat(kernel_pred_list, dim=0)

            result = self.get_seg_single(cate_pred_list, seg_pred_list,
                                         kernel_pred_list, featmap_size,
                                         img_shape, ori_shape, scale_factor,
                                         cfg, rescale)
            result_list.append(result)
        return result_list

    def get_seg_single(self,
                       cate_preds,
                       seg_preds,
                       kernel_preds,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False,
                       debug=False):

        assert len(cate_preds) == len(kernel_preds)

        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        # process.
        inds = (cate_preds > cfg.score_thr)
        cate_scores = cate_preds[inds]
        if len(cate_scores) == 0:
            return None

        # cate_labels & kernel_preds
        inds = inds.nonzero()
        cate_labels = inds[:, 1]
        kernel_preds = kernel_preds[inds[:, 0]]

        # trans vector.
        size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
        strides = kernel_preds.new_ones(size_trans[-1])

        n_stage = len(self.seg_num_grids)
        strides[:size_trans[0]] *= self.strides[0]
        for ind_ in range(1, n_stage):
            strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.strides[ind_]
        strides = strides[inds[:, 0]]

        # mask encoding.
        I, N = kernel_preds.shape
        kernel_preds = kernel_preds.view(I, N, 1, 1)
        seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()

        # mask.
        seg_masks = seg_preds > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()

        # filter.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return None

        seg_masks = seg_masks[keep, ...]
        seg_preds = seg_preds[keep, ...]
        sum_masks = sum_masks[keep]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # maskness.
        seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_scores

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=cfg.kernel, sigma=cfg.sigma,
                                 sum_masks=sum_masks)

        # filter.
        keep = cate_scores >= cfg.update_thr
        if keep.sum() == 0:
            return None
        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                                  size=upsampled_size_out,
                                  mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_preds,
                                  size=ori_shape[:2],
                                  mode='bilinear').squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores
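
For orientation, a minimal instantiation of this head might look as follows. Every value below is a hypothetical placeholder (the real settings live in this repo's config files), and loss_ins only needs a 'loss_weight' key here, since the dice loss itself is hard-coded above:

# --- hypothetical example values, not taken from this repo's configs ---
head = SOLOv2LightHead(
    num_classes=81,                      # e.g. 80 COCO classes + background
    in_channels=256,
    num_grids=[40, 36, 24, 16, 12],      # one grid size per FPN level (placeholder)
    ins_out_channels=128,
    loss_ins=dict(type='DiceLoss', loss_weight=3.0),
    loss_cate=dict(type='FocalLoss', use_sigmoid=True,
                   gamma=2.0, alpha=0.25, loss_weight=1.0))
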
mmdet/models/anchor_heads/ssd_head.py
0 → 100644
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init

from mmdet.core import AnchorGenerator, anchor_target, multi_apply
from ..losses import smooth_l1_loss
from ..registry import HEADS
from .anchor_head import AnchorHead


# TODO: add loss evaluator for SSD
@HEADS.register_module
class SSDHead(AnchorHead):

    def __init__(self,
                 input_size=300,
                 num_classes=81,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 anchor_strides=(8, 16, 32, 64, 100, 300),
                 basesize_ratio_range=(0.1, 0.9),
                 anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0)):
        super(AnchorHead, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.cls_out_channels = num_classes
        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
        reg_convs = []
        cls_convs = []
        for i in range(len(in_channels)):
            reg_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * 4,
                    kernel_size=3,
                    padding=1))
            cls_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * num_classes,
                    kernel_size=3,
                    padding=1))
        self.reg_convs = nn.ModuleList(reg_convs)
        self.cls_convs = nn.ModuleList(cls_convs)

        min_ratio, max_ratio = basesize_ratio_range
        min_ratio = int(min_ratio * 100)
        max_ratio = int(max_ratio * 100)
        step = int(np.floor(max_ratio - min_ratio) / (len(in_channels) - 2))
        min_sizes = []
        max_sizes = []
        for r in range(int(min_ratio), int(max_ratio) + 1, step):
            min_sizes.append(int(input_size * r / 100))
            max_sizes.append(int(input_size * (r + step) / 100))
        if input_size == 300:
            if basesize_ratio_range[0] == 0.15:  # SSD300 COCO
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
            elif basesize_ratio_range[0] == 0.2:  # SSD300 VOC
                min_sizes.insert(0, int(input_size * 10 / 100))
                max_sizes.insert(0, int(input_size * 20 / 100))
        elif input_size == 512:
            if basesize_ratio_range[0] == 0.1:  # SSD512 COCO
                min_sizes.insert(0, int(input_size * 4 / 100))
                max_sizes.insert(0, int(input_size * 10 / 100))
            elif basesize_ratio_range[0] == 0.15:  # SSD512 VOC
                min_sizes.insert(0, int(input_size * 7 / 100))
                max_sizes.insert(0, int(input_size * 15 / 100))
        self.anchor_generators = []
        self.anchor_strides = anchor_strides
        for k in range(len(anchor_strides)):
            base_size = min_sizes[k]
            stride = anchor_strides[k]
            ctr = ((stride - 1) / 2., (stride - 1) / 2.)
            scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
            ratios = [1.]
            for r in anchor_ratios[k]:
                ratios += [1 / r, r]  # 4 or 6 ratio
            anchor_generator = AnchorGenerator(
                base_size, scales, ratios, scale_major=False, ctr=ctr)
            indices = list(range(len(ratios)))
            indices.insert(1, len(indices))
            anchor_generator.base_anchors = torch.index_select(
                anchor_generator.base_anchors, 0, torch.LongTensor(indices))
            self.anchor_generators.append(anchor_generator)

        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = False
        self.cls_focal_loss = False
        self.fp16_enabled = False

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform', bias=0)

    def forward(self, feats):
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs, self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds

    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
        pos_inds = (labels > 0).nonzero().view(-1)
        neg_inds = (labels == 0).nonzero().view(-1)

        num_pos_samples = pos_inds.size(0)
        num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
        if num_neg_samples > neg_inds.size(0):
            num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

        loss_bbox = smooth_l1_loss(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_bbox
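
loss_single implements SSD's hard negative mining: all positive anchors contribute, but only the neg_pos_ratio hardest negatives (ranked by classification loss) are kept. A toy walk-through of that selection (numbers invented for illustration):

# --- illustrative sketch only, toy numbers ---
import torch

loss_cls_all = torch.tensor([0.2, 3.0, 0.5, 1.1, 0.05, 2.4])  # per-anchor CE loss
labels = torch.tensor([0, 0, 1, 0, 0, 0])                     # 0 = background
neg_pos_ratio = 3

pos_inds = (labels > 0).nonzero().view(-1)                    # -> [2]
neg_inds = (labels == 0).nonzero().view(-1)
num_neg = min(neg_pos_ratio * pos_inds.numel(), neg_inds.numel())
hard_negs, _ = loss_cls_all[neg_inds].topk(num_neg)           # -> [3.0, 2.4, 1.1]
print(pos_inds.tolist(), hard_negs.tolist())
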
    def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
             cfg, gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            sampling=False,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        num_images = len(img_metas)
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(num_images, -1, self.cls_out_channels)
            for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list, -1).view(num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list, -2).view(num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list, -2).view(num_images, -1, 4)

        # check NaN and Inf
        assert torch.isfinite(all_cls_scores).all().item(), \
            'classification scores become infinite or NaN!'
        assert torch.isfinite(all_bbox_preds).all().item(), \
            'bbox predictions become infinite or NaN!'

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
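
The per-level anchor counts follow directly from anchor_ratios: each extra ratio r contributes the pair (r, 1/r) on top of the two square anchors (scale 1 and sqrt(max_size/min_size)). A quick check of the `len(ratios) * 2 + 2` rule used in SSDHead.__init__:

# --- illustrative check ---
anchor_ratios = ([2], [2, 3], [2, 3], [2, 3], [2], [2])
num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
print(num_anchors)  # [4, 6, 6, 6, 4, 4] -- the classic SSD300 layout
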
mmdet/models/backbones/__init__.py
0 → 100644
from .hrnet import HRNet
from .resnet import ResNet, make_res_layer
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG

__all__ = ['ResNet', 'make_res_layer', 'ResNeXt', 'SSDVGG', 'HRNet']
mmdet/models/backbones/hrnet.py
0 → 100644
import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.utils import get_root_logger
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
from .resnet import BasicBlock, Bottleneck


class HRModule(nn.Module):
    """High-Resolution Module for HRNet. In this module, every branch
    has 4 BasicBlocks/Bottlenecks. Fusion/Exchange is in this module.
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=True,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN')):
        super(HRModule, self).__init__()
        self._check_branches(num_branches, num_blocks, in_channels, num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, num_blocks, in_channels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            raise ValueError(error_msg)

        if num_branches != len(in_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(in_channels))
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.in_channels[branch_index] != \
                num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    self.in_channels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg,
                                 num_channels[branch_index] * block.expansion)[1])

        layers = []
        layers.append(
            block(
                self.in_channels[branch_index],
                num_channels[branch_index],
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg))
        self.in_channels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.in_channels[branch_index],
                    num_channels[branch_index],
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))
        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            nn.Upsample(scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = 0
            for j in range(self.num_branches):
                if i == j:
                    y += x[j]
                else:
                    y += self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse
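
In _make_fuse_layers, branch j feeds branch i either through a 1x1 conv plus 2**(j - i) nearest-neighbor upsampling (j > i) or through a chain of stride-2 3x3 convs (j < i). A shape check with toy channel counts (32 and 128 here are illustrative, not values pulled from a config):

# --- illustrative sketch only, toy channel counts ---
import torch
import torch.nn as nn

# Branch 2 (1/16 resolution, 128 ch) fused into branch 0 (1/4 resolution, 32 ch):
up = nn.Sequential(nn.Conv2d(128, 32, 1, bias=False),
                   nn.BatchNorm2d(32),
                   nn.Upsample(scale_factor=2 ** 2, mode='nearest'))
# Branch 0 fused into branch 2: two stride-2 3x3 convs, the last changing channels.
down = nn.Sequential(nn.Conv2d(32, 32, 3, stride=2, padding=1, bias=False),
                     nn.BatchNorm2d(32), nn.ReLU(),
                     nn.Conv2d(32, 128, 3, stride=2, padding=1, bias=False),
                     nn.BatchNorm2d(128))
print(up(torch.rand(1, 128, 8, 8)).shape)      # torch.Size([1, 32, 32, 32])
print(down(torch.rand(1, 32, 32, 32)).shape)   # torch.Size([1, 128, 8, 8])
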
@BACKBONES.register_module
class HRNet(nn.Module):
    """HRNet backbone.

    High-Resolution Representations for Labeling Pixels and Regions
    arXiv: https://arxiv.org/abs/1904.04514

    Args:
        extra (dict): detailed configuration for each stage of HRNet.
        in_channels (int): Number of input image channels. Normally 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from mmdet.models import HRNet
        >>> import torch
        >>> extra = dict(
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(4, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=3,
        >>>         num_branches=4,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4, 4),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
        (1, 64, 4, 4)
        (1, 128, 2, 2)
        (1, 256, 1, 1)
    """

    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=True,
                 with_cp=False,
                 zero_init_residual=False):
        super(HRNet, self).__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        # stem net
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            64,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)
        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # stage 1
        self.stage1_cfg = self.extra['stage1']
        num_channels = self.stage1_cfg['num_channels'][0]
        block_type = self.stage1_cfg['block']
        num_blocks = self.stage1_cfg['num_blocks'][0]

        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * block.expansion
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        # stage 2
        self.stage2_cfg = self.extra['stage2']
        num_channels = self.stage2_cfg['num_channels']
        block_type = self.stage2_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition1 = self._make_transition_layer([stage1_out_channels],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        # stage 3
        self.stage3_cfg = self.extra['stage3']
        num_channels = self.stage3_cfg['num_channels']
        block_type = self.stage3_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        # stage 4
        self.stage4_cfg = self.extra['stage4']
        num_channels = self.stage4_cfg['num_channels']
        block_type = self.stage4_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        layers = []
        layers.append(
            block(
                inplanes,
                planes,
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        hr_modules = []
        for i in range(num_modules):
            # multi_scale_output is only used for the last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            hr_modules.append(
                HRModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    reset_multiscale_output,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*hr_modules), in_channels

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')
    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['num_branches']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['num_branches']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['num_branches']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        return y_list

    def train(self, mode=True):
        super(HRNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
mmdet/models/backbones/resnet.py
0 → 100644
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
constant_init
,
kaiming_init
from
mmcv.runner
import
load_checkpoint
from
torch.nn.modules.batchnorm
import
_BatchNorm
from
mmdet.models.plugins
import
GeneralizedAttention
from
mmdet.ops
import
ContextBlock
from
mmdet.utils
import
get_root_logger
from
..registry
import
BACKBONES
from
..utils
import
build_conv_layer
,
build_norm_layer
class
BasicBlock
(
nn
.
Module
):
expansion
=
1
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
dilation
=
1
,
downsample
=
None
,
style
=
'pytorch'
,
with_cp
=
False
,
conv_cfg
=
None
,
norm_cfg
=
dict
(
type
=
'BN'
),
dcn
=
None
,
gcb
=
None
,
gen_attention
=
None
):
super
(
BasicBlock
,
self
).
__init__
()
assert
dcn
is
None
,
"Not implemented yet."
assert
gen_attention
is
None
,
"Not implemented yet."
assert
gcb
is
None
,
"Not implemented yet."
self
.
norm1_name
,
norm1
=
build_norm_layer
(
norm_cfg
,
planes
,
postfix
=
1
)
self
.
norm2_name
,
norm2
=
build_norm_layer
(
norm_cfg
,
planes
,
postfix
=
2
)
self
.
conv1
=
build_conv_layer
(
conv_cfg
,
inplanes
,
planes
,
3
,
stride
=
stride
,
padding
=
dilation
,
dilation
=
dilation
,
bias
=
False
)
self
.
add_module
(
self
.
norm1_name
,
norm1
)
self
.
conv2
=
build_conv_layer
(
conv_cfg
,
planes
,
planes
,
3
,
padding
=
1
,
bias
=
False
)
self
.
add_module
(
self
.
norm2_name
,
norm2
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
downsample
=
downsample
self
.
stride
=
stride
self
.
dilation
=
dilation
assert
not
with_cp
@
property
def
norm1
(
self
):
return
getattr
(
self
,
self
.
norm1_name
)
@
property
def
norm2
(
self
):
return
getattr
(
self
,
self
.
norm2_name
)
def
forward
(
self
,
x
):
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
norm1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
norm2
(
out
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
+=
identity
out
=
self
.
relu
(
out
)
return
out
class
Bottleneck
(
nn
.
Module
):
expansion
=
4
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
dilation
=
1
,
downsample
=
None
,
style
=
'pytorch'
,
with_cp
=
False
,
conv_cfg
=
None
,
norm_cfg
=
dict
(
type
=
'BN'
),
dcn
=
None
,
gcb
=
None
,
gen_attention
=
None
):
"""Bottleneck block for ResNet.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super
(
Bottleneck
,
self
).
__init__
()
assert
style
in
[
'pytorch'
,
'caffe'
]
assert
dcn
is
None
or
isinstance
(
dcn
,
dict
)
assert
gcb
is
None
or
isinstance
(
gcb
,
dict
)
assert
gen_attention
is
None
or
isinstance
(
gen_attention
,
dict
)
self
.
inplanes
=
inplanes
self
.
planes
=
planes
self
.
stride
=
stride
self
.
dilation
=
dilation
self
.
style
=
style
self
.
with_cp
=
with_cp
self
.
conv_cfg
=
conv_cfg
self
.
norm_cfg
=
norm_cfg
self
.
dcn
=
dcn
self
.
with_dcn
=
dcn
is
not
None
self
.
gcb
=
gcb
self
.
with_gcb
=
gcb
is
not
None
self
.
gen_attention
=
gen_attention
self
.
with_gen_attention
=
gen_attention
is
not
None
if
self
.
style
==
'pytorch'
:
self
.
conv1_stride
=
1
self
.
conv2_stride
=
stride
else
:
self
.
conv1_stride
=
stride
self
.
conv2_stride
=
1
self
.
norm1_name
,
norm1
=
build_norm_layer
(
norm_cfg
,
planes
,
postfix
=
1
)
self
.
norm2_name
,
norm2
=
build_norm_layer
(
norm_cfg
,
planes
,
postfix
=
2
)
self
.
norm3_name
,
norm3
=
build_norm_layer
(
norm_cfg
,
planes
*
self
.
expansion
,
postfix
=
3
)
self
.
conv1
=
build_conv_layer
(
conv_cfg
,
inplanes
,
planes
,
kernel_size
=
1
,
stride
=
self
.
conv1_stride
,
bias
=
False
)
self
.
add_module
(
self
.
norm1_name
,
norm1
)
fallback_on_stride
=
False
if
self
.
with_dcn
:
fallback_on_stride
=
dcn
.
pop
(
'fallback_on_stride'
,
False
)
if
not
self
.
with_dcn
or
fallback_on_stride
:
self
.
conv2
=
build_conv_layer
(
conv_cfg
,
planes
,
planes
,
kernel_size
=
3
,
stride
=
self
.
conv2_stride
,
padding
=
dilation
,
dilation
=
dilation
,
bias
=
False
)
else
:
assert
self
.
conv_cfg
is
None
,
'conv_cfg cannot be None for DCN'
self
.
conv2
=
build_conv_layer
(
dcn
,
planes
,
planes
,
kernel_size
=
3
,
stride
=
self
.
conv2_stride
,
padding
=
dilation
,
dilation
=
dilation
,
bias
=
False
)
self
.
add_module
(
self
.
norm2_name
,
norm2
)
self
.
conv3
=
build_conv_layer
(
conv_cfg
,
planes
,
planes
*
self
.
expansion
,
kernel_size
=
1
,
bias
=
False
)
self
.
add_module
(
self
.
norm3_name
,
norm3
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
downsample
=
downsample
if
self
.
with_gcb
:
gcb_inplanes
=
planes
*
self
.
expansion
self
.
context_block
=
ContextBlock
(
inplanes
=
gcb_inplanes
,
**
gcb
)
# gen_attention
if
self
.
with_gen_attention
:
self
.
gen_attention_block
=
GeneralizedAttention
(
planes
,
**
gen_attention
)
@
property
def
norm1
(
self
):
return
getattr
(
self
,
self
.
norm1_name
)
@
property
def
norm2
(
self
):
return
getattr
(
self
,
self
.
norm2_name
)
@
property
def
norm3
(
self
):
return
getattr
(
self
,
self
.
norm3_name
)
def
forward
(
self
,
x
):
def
_inner_forward
(
x
):
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
norm1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
norm2
(
out
)
out
=
self
.
relu
(
out
)
if
self
.
with_gen_attention
:
out
=
self
.
gen_attention_block
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
norm3
(
out
)
if
self
.
with_gcb
:
out
=
self
.
context_block
(
out
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
+=
identity
return
out
if
self
.
with_cp
and
x
.
requires_grad
:
out
=
cp
.
checkpoint
(
_inner_forward
,
x
)
else
:
out
=
_inner_forward
(
x
)
out
=
self
.
relu
(
out
)
return
out
def
make_res_layer
(
block
,
inplanes
,
planes
,
blocks
,
stride
=
1
,
dilation
=
1
,
style
=
'pytorch'
,
with_cp
=
False
,
conv_cfg
=
None
,
norm_cfg
=
dict
(
type
=
'BN'
),
dcn
=
None
,
gcb
=
None
,
gen_attention
=
None
,
gen_attention_blocks
=
[]):
downsample
=
None
if
stride
!=
1
or
inplanes
!=
planes
*
block
.
expansion
:
downsample
=
nn
.
Sequential
(
build_conv_layer
(
conv_cfg
,
inplanes
,
planes
*
block
.
expansion
,
kernel_size
=
1
,
stride
=
stride
,
bias
=
False
),
build_norm_layer
(
norm_cfg
,
planes
*
block
.
expansion
)[
1
],
)
layers
=
[]
layers
.
append
(
block
(
inplanes
=
inplanes
,
planes
=
planes
,
stride
=
stride
,
dilation
=
dilation
,
downsample
=
downsample
,
style
=
style
,
with_cp
=
with_cp
,
conv_cfg
=
conv_cfg
,
norm_cfg
=
norm_cfg
,
dcn
=
dcn
,
gcb
=
gcb
,
gen_attention
=
gen_attention
if
(
0
in
gen_attention_blocks
)
else
None
))
inplanes
=
planes
*
block
.
expansion
for
i
in
range
(
1
,
blocks
):
layers
.
append
(
block
(
inplanes
=
inplanes
,
planes
=
planes
,
stride
=
1
,
dilation
=
dilation
,
style
=
style
,
with_cp
=
with_cp
,
conv_cfg
=
conv_cfg
,
norm_cfg
=
norm_cfg
,
dcn
=
dcn
,
gcb
=
gcb
,
gen_attention
=
gen_attention
if
(
i
in
gen_attention_blocks
)
else
None
))
return
nn
.
Sequential
(
*
layers
)
@
BACKBONES
.
register_module
class
ResNet
(
nn
.
Module
):
"""ResNet backbone.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import ResNet
>>> import torch
>>> self = ResNet(depth=18)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 64, 8, 8)
(1, 128, 4, 4)
(1, 256, 2, 2)
(1, 512, 1, 1)
"""
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }
    def __init__(self,
                 depth,
                 in_channels=3,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 gcb=None,
                 stage_with_gcb=(False, False, False, False),
                 gen_attention=None,
                 stage_with_gen_attention=((), (), (), ()),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.gen_attention = gen_attention
        self.gcb = gcb
        self.stage_with_gcb = stage_with_gcb
        if gcb is not None:
            assert len(stage_with_gcb) == num_stages
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer(in_channels)

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            gcb = self.gcb if self.stage_with_gcb[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb,
                gen_attention=gen_attention,
                gen_attention_blocks=stage_with_gen_attention[i])
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)
    @property
    def norm1(self):
        return getattr(self, self.norm1_name)
    def _make_stem_layer(self, in_channels):
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')
    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
    def train(self, mode=True):
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
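A short sketch (editor's addition) of how `frozen_stages` and `norm_eval` interact with the `train()` override above: with `frozen_stages=1`, the stem and `layer1` are put in eval mode with gradients disabled, and `_freeze_stages()` re-applies this on every switch back to training mode.

from mmdet.models import ResNet

model = ResNet(depth=50, frozen_stages=1, norm_eval=True)
model.init_weights(pretrained=None)
model.train()  # re-runs _freeze_stages() and sets all BatchNorm layers to eval
assert not any(p.requires_grad for p in model.layer1.parameters())
assert all(p.requires_grad for p in model.layer2.parameters())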
mmdet/models/backbones/resnext.py
0 → 100644
import math

import torch.nn as nn

from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
class Bottleneck(_Bottleneck):

    def __init__(self, inplanes, planes, groups=1, base_width=4, **kwargs):
        """Bottleneck block for ResNeXt.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.
        """
        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
        if groups == 1:
            width = self.planes
        else:
            width = math.floor(self.planes * (base_width / 64)) * groups

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.norm_cfg, width, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                self.conv_cfg,
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=groups,
                bias=False)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                self.dcn,
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)

        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)
def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   groups=1,
                   base_width=4,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
                   norm_cfg=dict(type='BN'),
                   dcn=None,
                   gcb=None):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )

    layers = []
    layers.append(
        block(
            inplanes=inplanes,
            planes=planes,
            stride=stride,
            dilation=dilation,
            downsample=downsample,
            groups=groups,
            base_width=base_width,
            style=style,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            dcn=dcn,
            gcb=gcb))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=1,
                dilation=dilation,
                groups=groups,
                base_width=base_width,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb))

    return nn.Sequential(*layers)
@BACKBONES.register_module
class ResNeXt(ResNet):
"""ResNeXt backbone.
Args:
depth (int): Depth of ResNeXt, from {50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
num_stages (int): Resnet stages, normally 4.
groups (int): Group of resnext.
base_width (int): Base width of resnext.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import ResNeXt
>>> import torch
>>> self = ResNeXt(depth=50)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 256, 8, 8)
(1, 512, 4, 4)
(1, 1024, 2, 2)
(1, 2048, 1, 1)
"""
    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }
    def __init__(self, groups=1, base_width=4, **kwargs):
        super(ResNeXt, self).__init__(**kwargs)
        self.groups = groups
        self.base_width = base_width

        self.inplanes = 64
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = self.strides[i]
            dilation = self.dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            gcb = self.gcb if self.stage_with_gcb[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                groups=self.groups,
                base_width=self.base_width,
                style=self.style,
                with_cp=self.with_cp,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                dcn=dcn,
                gcb=gcb)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()
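The grouped width computed in `Bottleneck.__init__` above can be checked by hand. For the common 32x4d configuration (`groups=32`, `base_width=4`) at the first stage (`planes=64`): width = floor(64 * 4 / 64) * 32 = 128, so `conv1`/`conv2` operate on 128 channels rather than 64, and `conv2` splits them into 32 groups of 4. A one-liner to verify (editor's sketch):

import math

planes, groups, base_width = 64, 32, 4
width = math.floor(planes * (base_width / 64)) * groups
print(width)  # 128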
mmdet/models/backbones/ssd_vgg.py
0 → 100644
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import VGG, constant_init, kaiming_init, normal_init, xavier_init
from mmcv.runner import load_checkpoint

from mmdet.utils import get_root_logger
from ..registry import BACKBONES
@BACKBONES.register_module
class SSDVGG(VGG):
"""VGG Backbone network for single-shot-detection
Args:
input_size (int): width and height of input, from {300, 512}.
depth (int): Depth of vgg, from {11, 13, 16, 19}.
out_indices (Sequence[int]): Output from which stages.
Example:
>>> self = SSDVGG(input_size=300, depth=11)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 300, 300)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 1024, 19, 19)
(1, 512, 10, 10)
(1, 256, 5, 5)
(1, 256, 3, 3)
(1, 256, 1, 1)
"""
    extra_setting = {
        300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
        512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256,
              128),
    }
    def __init__(self,
                 input_size,
                 depth,
                 with_last_pool=False,
                 ceil_mode=True,
                 out_indices=(3, 4),
                 out_feature_indices=(22, 34),
                 l2_norm_scale=20.):
        # TODO: in_channels for mmcv.VGG
        super(SSDVGG, self).__init__(
            depth,
            with_last_pool=with_last_pool,
            ceil_mode=ceil_mode,
            out_indices=out_indices)
        assert input_size in (300, 512)
        self.input_size = input_size

        self.features.add_module(
            str(len(self.features)),
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
        self.features.add_module(
            str(len(self.features)),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6))
        self.features.add_module(
            str(len(self.features)), nn.ReLU(inplace=True))
        self.features.add_module(
            str(len(self.features)), nn.Conv2d(1024, 1024, kernel_size=1))
        self.features.add_module(
            str(len(self.features)), nn.ReLU(inplace=True))
        self.out_feature_indices = out_feature_indices

        self.inplanes = 1024
        self.extra = self._make_extra_layers(self.extra_setting[input_size])
        self.l2_norm = L2Norm(
            self.features[out_feature_indices[0] - 1].out_channels,
            l2_norm_scale)
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.features.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)
        else:
            raise TypeError('pretrained must be a str or None')

        for m in self.extra.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

        constant_init(self.l2_norm, self.l2_norm.scale)
    def forward(self, x):
        outs = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.out_feature_indices:
                outs.append(x)
        for i, layer in enumerate(self.extra):
            x = F.relu(layer(x), inplace=True)
            if i % 2 == 1:
                outs.append(x)
        outs[0] = self.l2_norm(outs[0])
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)
    def _make_extra_layers(self, outplanes):
        layers = []
        kernel_sizes = (1, 3)
        num_layers = 0
        outplane = None
        for i in range(len(outplanes)):
            if self.inplanes == 'S':
                self.inplanes = outplane
                continue
            k = kernel_sizes[num_layers % 2]
            if outplanes[i] == 'S':
                outplane = outplanes[i + 1]
                conv = nn.Conv2d(
                    self.inplanes, outplane, k, stride=2, padding=1)
            else:
                outplane = outplanes[i]
                conv = nn.Conv2d(
                    self.inplanes, outplane, k, stride=1, padding=0)
            layers.append(conv)
            self.inplanes = outplanes[i]
            num_layers += 1
        if self.input_size == 512:
            layers.append(nn.Conv2d(self.inplanes, 256, 4, padding=1))

        return nn.Sequential(*layers)
class L2Norm(nn.Module):

    def __init__(self, n_dims, scale=20., eps=1e-10):
        super(L2Norm, self).__init__()
        self.n_dims = n_dims
        self.weight = nn.Parameter(torch.Tensor(self.n_dims))
        self.eps = eps
        self.scale = scale

    def forward(self, x):
        # compute in FP32 during FP16 training so the normalization stays
        # numerically stable
        x_float = x.float()
        norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
        return (self.weight[None, :, None, None].float().expand_as(x_float) *
                x_float / norm).type_as(x)
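A quick numerical check of `L2Norm` above (editor's sketch; `constant_init(self.l2_norm, self.l2_norm.scale)` in `init_weights` fills the weight with the scale, which `torch.nn.init.constant_` reproduces here): after the layer, the channel vector at every spatial position has L2 norm equal to the scale.

import torch

l2n = L2Norm(n_dims=4, scale=20.)
torch.nn.init.constant_(l2n.weight, l2n.scale)
x = torch.randn(1, 4, 3, 3)
y = l2n(x)
print(y.pow(2).sum(1).sqrt())  # ~20.0 at every spatial position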
mmdet/models/bbox_heads/__init__.py
0 → 100644
from .bbox_head import BBoxHead
from .convfc_bbox_head import ConvFCBBoxHead, SharedFCBBoxHead
from .double_bbox_head import DoubleConvFCBBoxHead

__all__ = [
    'BBoxHead', 'ConvFCBBoxHead', 'SharedFCBBoxHead', 'DoubleConvFCBBoxHead'
]