Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
20e75c22
Commit
20e75c22
authored
Oct 07, 2018
by
Kai Chen
Browse files
use anchor_target in RetinaNet
parent
45af4242
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
87 additions
and
205 deletions
+87
-205
mmdet/core/__init__.py
mmdet/core/__init__.py
+0
-1
mmdet/core/anchor/anchor_target.py
mmdet/core/anchor/anchor_target.py
+64
-21
mmdet/core/targets/__init__.py
mmdet/core/targets/__init__.py
+0
-1
mmdet/core/targets/retina_target.py
mmdet/core/targets/retina_target.py
+0
-168
mmdet/models/rpn_heads/rpn_head.py
mmdet/models/rpn_heads/rpn_head.py
+2
-2
mmdet/models/single_stage_heads/retina_head.py
mmdet/models/single_stage_heads/retina_head.py
+18
-10
mmdet/models/utils/__init__.py
mmdet/models/utils/__init__.py
+3
-2
No files found.
mmdet/core/__init__.py
View file @
20e75c22
from
.anchor
import
*
# noqa: F401, F403
from
.anchor
import
*
# noqa: F401, F403
from
.bbox_ops
import
*
# noqa: F401, F403
from
.bbox_ops
import
*
# noqa: F401, F403
from
.mask_ops
import
*
# noqa: F401, F403
from
.mask_ops
import
*
# noqa: F401, F403
from
.targets
import
*
# noqa: F401, F403
from
.losses
import
*
# noqa: F401, F403
from
.losses
import
*
# noqa: F401, F403
from
.eval
import
*
# noqa: F401, F403
from
.eval
import
*
# noqa: F401, F403
from
.parallel
import
*
# noqa: F401, F403
from
.parallel
import
*
# noqa: F401, F403
...
...
mmdet/core/anchor/anchor_target.py
View file @
20e75c22
...
@@ -4,8 +4,16 @@ from ..bbox_ops import bbox_assign, bbox2delta, bbox_sampling
...
@@ -4,8 +4,16 @@ from ..bbox_ops import bbox_assign, bbox2delta, bbox_sampling
from
..utils
import
multi_apply
from
..utils
import
multi_apply
def
anchor_target
(
anchor_list
,
valid_flag_list
,
gt_bboxes_list
,
img_metas
,
def
anchor_target
(
anchor_list
,
target_means
,
target_stds
,
cfg
):
valid_flag_list
,
gt_bboxes_list
,
img_metas
,
target_means
,
target_stds
,
cfg
,
gt_labels_list
=
None
,
cls_out_channels
=
1
,
sampling
=
True
):
"""Compute regression and classification targets for anchors.
"""Compute regression and classification targets for anchors.
Args:
Args:
...
@@ -32,28 +40,34 @@ def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
...
@@ -32,28 +40,34 @@ def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
valid_flag_list
[
i
]
=
torch
.
cat
(
valid_flag_list
[
i
])
valid_flag_list
[
i
]
=
torch
.
cat
(
valid_flag_list
[
i
])
# compute targets for each image
# compute targets for each image
means_replicas
=
[
target_means
for
_
in
range
(
num_imgs
)]
if
gt_labels_list
is
None
:
stds_replicas
=
[
target_stds
for
_
in
range
(
num_imgs
)]
gt_labels_list
=
[
None
for
_
in
range
(
num_imgs
)]
cfg_replicas
=
[
cfg
for
_
in
range
(
num_imgs
)]
(
all_labels
,
all_label_weights
,
all_bbox_targets
,
all_bbox_weights
,
(
all_labels
,
all_label_weights
,
all_bbox_targets
,
pos_inds_list
,
neg_inds_list
)
=
multi_apply
(
all_bbox_weights
,
pos_inds_list
,
neg_inds_list
)
=
multi_apply
(
anchor_target_single
,
anchor_target_single
,
anchor_list
,
valid_flag_list
,
gt_bboxes_list
,
anchor_list
,
img_metas
,
means_replicas
,
stds_replicas
,
cfg_replicas
)
valid_flag_list
,
gt_bboxes_list
,
gt_labels_list
,
img_metas
,
target_means
=
target_means
,
target_stds
=
target_stds
,
cfg
=
cfg
,
cls_out_channels
=
cls_out_channels
,
sampling
=
sampling
)
# no valid anchors
# no valid anchors
if
any
([
labels
is
None
for
labels
in
all_labels
]):
if
any
([
labels
is
None
for
labels
in
all_labels
]):
return
None
return
None
# sampled anchors of all images
# sampled anchors of all images
num_total_samples
=
sum
([
num_total_pos
=
sum
([
max
(
inds
.
numel
(),
1
)
for
inds
in
pos_inds_list
])
max
(
pos_inds
.
numel
()
+
neg_inds
.
numel
(),
1
)
num_total_neg
=
sum
([
max
(
inds
.
numel
(),
1
)
for
inds
in
neg_inds_list
])
for
pos_inds
,
neg_inds
in
zip
(
pos_inds_list
,
neg_inds_list
)
])
# split targets to a list w.r.t. multiple levels
# split targets to a list w.r.t. multiple levels
labels_list
=
images_to_levels
(
all_labels
,
num_level_anchors
)
labels_list
=
images_to_levels
(
all_labels
,
num_level_anchors
)
label_weights_list
=
images_to_levels
(
all_label_weights
,
num_level_anchors
)
label_weights_list
=
images_to_levels
(
all_label_weights
,
num_level_anchors
)
bbox_targets_list
=
images_to_levels
(
all_bbox_targets
,
num_level_anchors
)
bbox_targets_list
=
images_to_levels
(
all_bbox_targets
,
num_level_anchors
)
bbox_weights_list
=
images_to_levels
(
all_bbox_weights
,
num_level_anchors
)
bbox_weights_list
=
images_to_levels
(
all_bbox_weights
,
num_level_anchors
)
return
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
return
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
num_total_
samples
)
bbox_weights_list
,
num_total_
pos
,
num_total_neg
)
def
images_to_levels
(
target
,
num_level_anchors
):
def
images_to_levels
(
target
,
num_level_anchors
):
...
@@ -71,8 +85,16 @@ def images_to_levels(target, num_level_anchors):
...
@@ -71,8 +85,16 @@ def images_to_levels(target, num_level_anchors):
return
level_targets
return
level_targets
def
anchor_target_single
(
flat_anchors
,
valid_flags
,
gt_bboxes
,
img_meta
,
def
anchor_target_single
(
flat_anchors
,
target_means
,
target_stds
,
cfg
):
valid_flags
,
gt_bboxes
,
gt_labels
,
img_meta
,
target_means
,
target_stds
,
cfg
,
cls_out_channels
=
1
,
sampling
=
True
):
inside_flags
=
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
inside_flags
=
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
img_meta
[
'img_shape'
][:
2
],
img_meta
[
'img_shape'
][:
2
],
cfg
.
allowed_border
)
cfg
.
allowed_border
)
...
@@ -86,10 +108,14 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
...
@@ -86,10 +108,14 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
pos_iou_thr
=
cfg
.
pos_iou_thr
,
pos_iou_thr
=
cfg
.
pos_iou_thr
,
neg_iou_thr
=
cfg
.
neg_iou_thr
,
neg_iou_thr
=
cfg
.
neg_iou_thr
,
min_pos_iou
=
cfg
.
min_pos_iou
)
min_pos_iou
=
cfg
.
min_pos_iou
)
pos_inds
,
neg_inds
=
bbox_sampling
(
assigned_gt_inds
,
cfg
.
anchor_batch_size
,
if
sampling
:
cfg
.
pos_fraction
,
cfg
.
neg_pos_ub
,
pos_inds
,
neg_inds
=
bbox_sampling
(
cfg
.
pos_balance_sampling
,
max_overlaps
,
assigned_gt_inds
,
cfg
.
anchor_batch_size
,
cfg
.
pos_fraction
,
cfg
.
neg_balance_thr
)
cfg
.
neg_pos_ub
,
cfg
.
pos_balance_sampling
,
max_overlaps
,
cfg
.
neg_balance_thr
)
else
:
pos_inds
=
torch
.
nonzero
(
assigned_gt_inds
>
0
).
squeeze
(
-
1
).
unique
()
neg_inds
=
torch
.
nonzero
(
assigned_gt_inds
==
0
).
squeeze
(
-
1
).
unique
()
bbox_targets
=
torch
.
zeros_like
(
anchors
)
bbox_targets
=
torch
.
zeros_like
(
anchors
)
bbox_weights
=
torch
.
zeros_like
(
anchors
)
bbox_weights
=
torch
.
zeros_like
(
anchors
)
...
@@ -103,7 +129,10 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
...
@@ -103,7 +129,10 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
target_stds
)
target_stds
)
bbox_targets
[
pos_inds
,
:]
=
pos_bbox_targets
bbox_targets
[
pos_inds
,
:]
=
pos_bbox_targets
bbox_weights
[
pos_inds
,
:]
=
1.0
bbox_weights
[
pos_inds
,
:]
=
1.0
labels
[
pos_inds
]
=
1
if
gt_labels
is
None
:
labels
[
pos_inds
]
=
1
else
:
labels
[
pos_inds
]
=
gt_labels
[
assigned_gt_inds
[
pos_inds
]
-
1
]
if
cfg
.
pos_weight
<=
0
:
if
cfg
.
pos_weight
<=
0
:
label_weights
[
pos_inds
]
=
1.0
label_weights
[
pos_inds
]
=
1.0
else
:
else
:
...
@@ -115,6 +144,9 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
...
@@ -115,6 +144,9 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
num_total_anchors
=
flat_anchors
.
size
(
0
)
num_total_anchors
=
flat_anchors
.
size
(
0
)
labels
=
unmap
(
labels
,
num_total_anchors
,
inside_flags
)
labels
=
unmap
(
labels
,
num_total_anchors
,
inside_flags
)
label_weights
=
unmap
(
label_weights
,
num_total_anchors
,
inside_flags
)
label_weights
=
unmap
(
label_weights
,
num_total_anchors
,
inside_flags
)
if
cls_out_channels
>
1
:
labels
,
label_weights
=
expand_binary_labels
(
labels
,
label_weights
,
cls_out_channels
)
bbox_targets
=
unmap
(
bbox_targets
,
num_total_anchors
,
inside_flags
)
bbox_targets
=
unmap
(
bbox_targets
,
num_total_anchors
,
inside_flags
)
bbox_weights
=
unmap
(
bbox_weights
,
num_total_anchors
,
inside_flags
)
bbox_weights
=
unmap
(
bbox_weights
,
num_total_anchors
,
inside_flags
)
...
@@ -122,6 +154,17 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
...
@@ -122,6 +154,17 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
neg_inds
)
neg_inds
)
def
expand_binary_labels
(
labels
,
label_weights
,
cls_out_channels
):
bin_labels
=
labels
.
new_full
(
(
labels
.
size
(
0
),
cls_out_channels
),
0
,
dtype
=
torch
.
float32
)
inds
=
torch
.
nonzero
(
labels
>=
1
).
squeeze
()
if
inds
.
numel
()
>
0
:
bin_labels
[
inds
,
labels
[
inds
]
-
1
]
=
1
bin_label_weights
=
label_weights
.
view
(
-
1
,
1
).
expand
(
label_weights
.
size
(
0
),
cls_out_channels
)
return
bin_labels
,
bin_label_weights
def
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
img_shape
,
def
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
img_shape
,
allowed_border
=
0
):
allowed_border
=
0
):
img_h
,
img_w
=
img_shape
[:
2
]
img_h
,
img_w
=
img_shape
[:
2
]
...
...
mmdet/core/targets/__init__.py
deleted
100644 → 0
View file @
45af4242
from
.retina_target
import
retina_target
mmdet/core/targets/retina_target.py
deleted
100644 → 0
View file @
45af4242
import
torch
from
..bbox_ops
import
bbox_assign
,
bbox2delta
from
..utils
import
multi_apply
def
retina_target
(
anchor_list
,
valid_flag_list
,
gt_bboxes_list
,
gt_labels_list
,
img_metas
,
target_means
,
target_stds
,
cls_out_channels
,
cfg
):
"""Compute regression and classification targets for anchors.
Args:
anchor_list (list[list]): Multi level anchors of each image.
valid_flag_list (list[list]): Multi level valid flags of each image.
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
img_metas (list[dict]): Meta info of each image.
target_means (Iterable): Mean value of regression targets.
target_stds (Iterable): Std value of regression targets.
cfg (dict): RPN train configs.
Returns:
tuple
"""
num_imgs
=
len
(
img_metas
)
assert
len
(
anchor_list
)
==
len
(
valid_flag_list
)
==
num_imgs
# anchor number of multi levels
num_level_anchors
=
[
anchors
.
size
(
0
)
for
anchors
in
anchor_list
[
0
]]
# concat all level anchors and flags to a single tensor
for
i
in
range
(
num_imgs
):
assert
len
(
anchor_list
[
i
])
==
len
(
valid_flag_list
[
i
])
anchor_list
[
i
]
=
torch
.
cat
(
anchor_list
[
i
])
valid_flag_list
[
i
]
=
torch
.
cat
(
valid_flag_list
[
i
])
# compute targets for each image
(
all_labels
,
all_label_weights
,
all_bbox_targets
,
all_bbox_weights
,
pos_inds_list
,
neg_inds_list
)
=
multi_apply
(
retina_target_single
,
anchor_list
,
valid_flag_list
,
gt_bboxes_list
,
gt_labels_list
,
img_metas
,
target_means
=
target_means
,
target_stds
=
target_stds
,
cls_out_channels
=
cls_out_channels
,
cfg
=
cfg
)
# no valid anchors
if
any
([
labels
is
None
for
labels
in
all_labels
]):
return
None
# sampled anchors of all images
num_pos_samples
=
sum
([
max
(
pos_inds
.
numel
(),
1
)
for
pos_inds
,
neg_inds
in
zip
(
pos_inds_list
,
neg_inds_list
)
])
# split targets to a list w.r.t. multiple levels
labels_list
=
images_to_levels
(
all_labels
,
num_level_anchors
)
label_weights_list
=
images_to_levels
(
all_label_weights
,
num_level_anchors
)
bbox_targets_list
=
images_to_levels
(
all_bbox_targets
,
num_level_anchors
)
bbox_weights_list
=
images_to_levels
(
all_bbox_weights
,
num_level_anchors
)
return
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
num_pos_samples
)
def
images_to_levels
(
target
,
num_level_anchors
):
"""Convert targets by image to targets by feature level.
[target_img0, target_img1] -> [target_level0, target_level1, ...]
"""
target
=
torch
.
stack
(
target
,
0
)
level_targets
=
[]
start
=
0
for
n
in
num_level_anchors
:
end
=
start
+
n
level_targets
.
append
(
target
[:,
start
:
end
].
squeeze
(
0
))
start
=
end
return
level_targets
def
retina_target_single
(
flat_anchors
,
valid_flags
,
gt_bboxes
,
gt_labels
,
img_meta
,
target_means
,
target_stds
,
cls_out_channels
,
cfg
):
inside_flags
=
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
img_meta
[
'img_shape'
][:
2
],
cfg
.
allowed_border
)
if
not
inside_flags
.
any
():
return
(
None
,
)
*
6
# assign gt and sample anchors
anchors
=
flat_anchors
[
inside_flags
,
:]
assigned_gt_inds
,
argmax_overlaps
,
max_overlaps
=
bbox_assign
(
anchors
,
gt_bboxes
,
pos_iou_thr
=
cfg
.
pos_iou_thr
,
neg_iou_thr
=
cfg
.
neg_iou_thr
,
min_pos_iou
=
cfg
.
min_pos_iou
)
pos_inds
=
torch
.
nonzero
(
assigned_gt_inds
>
0
)
neg_inds
=
torch
.
nonzero
(
assigned_gt_inds
==
0
)
bbox_targets
=
torch
.
zeros_like
(
anchors
)
bbox_weights
=
torch
.
zeros_like
(
anchors
)
labels
=
torch
.
zeros_like
(
assigned_gt_inds
)
label_weights
=
torch
.
zeros_like
(
assigned_gt_inds
,
dtype
=
anchors
.
dtype
)
if
len
(
pos_inds
)
>
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
).
unique
()
pos_anchors
=
anchors
[
pos_inds
,
:]
pos_gt_bbox
=
gt_bboxes
[
assigned_gt_inds
[
pos_inds
]
-
1
,
:]
pos_bbox_targets
=
bbox2delta
(
pos_anchors
,
pos_gt_bbox
,
target_means
,
target_stds
)
bbox_targets
[
pos_inds
,
:]
=
pos_bbox_targets
bbox_weights
[
pos_inds
,
:]
=
1.0
labels
[
pos_inds
]
=
gt_labels
[
assigned_gt_inds
[
pos_inds
]
-
1
]
if
cfg
.
pos_weight
<=
0
:
label_weights
[
pos_inds
]
=
1.0
else
:
label_weights
[
pos_inds
]
=
cfg
.
pos_weight
if
len
(
neg_inds
)
>
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
).
unique
()
label_weights
[
neg_inds
]
=
1.0
# map up to original set of anchors
num_total_anchors
=
flat_anchors
.
size
(
0
)
labels
=
unmap
(
labels
,
num_total_anchors
,
inside_flags
)
label_weights
=
unmap
(
label_weights
,
num_total_anchors
,
inside_flags
)
labels
,
label_weights
=
expand_binary_labels
(
labels
,
label_weights
,
cls_out_channels
)
bbox_targets
=
unmap
(
bbox_targets
,
num_total_anchors
,
inside_flags
)
bbox_weights
=
unmap
(
bbox_weights
,
num_total_anchors
,
inside_flags
)
return
(
labels
,
label_weights
,
bbox_targets
,
bbox_weights
,
pos_inds
,
neg_inds
)
def
expand_binary_labels
(
labels
,
label_weights
,
cls_out_channels
):
bin_labels
=
labels
.
new_full
(
(
labels
.
size
(
0
),
cls_out_channels
),
0
,
dtype
=
torch
.
float32
)
inds
=
torch
.
nonzero
(
labels
>=
1
).
squeeze
()
if
inds
.
numel
()
>
0
:
bin_labels
[
inds
,
labels
[
inds
]
-
1
]
=
1
bin_label_weights
=
label_weights
.
view
(
-
1
,
1
).
expand
(
label_weights
.
size
(
0
),
cls_out_channels
)
return
bin_labels
,
bin_label_weights
def
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
img_shape
,
allowed_border
=
0
):
img_h
,
img_w
=
img_shape
[:
2
]
if
allowed_border
>=
0
:
inside_flags
=
valid_flags
&
\
(
flat_anchors
[:,
0
]
>=
-
allowed_border
)
&
\
(
flat_anchors
[:,
1
]
>=
-
allowed_border
)
&
\
(
flat_anchors
[:,
2
]
<
img_w
+
allowed_border
)
&
\
(
flat_anchors
[:,
3
]
<
img_h
+
allowed_border
)
else
:
inside_flags
=
valid_flags
return
inside_flags
def
unmap
(
data
,
count
,
inds
,
fill
=
0
):
""" Unmap a subset of item (data) back to the original set of items (of
size count) """
if
data
.
dim
()
==
1
:
ret
=
data
.
new_full
((
count
,
),
fill
)
ret
[
inds
]
=
data
else
:
new_size
=
(
count
,
)
+
data
.
size
()[
1
:]
ret
=
data
.
new_full
(
new_size
,
fill
)
ret
[
inds
,
:]
=
data
return
ret
mmdet/models/rpn_heads/rpn_head.py
View file @
20e75c22
...
@@ -160,7 +160,7 @@ class RPNHead(nn.Module):
...
@@ -160,7 +160,7 @@ class RPNHead(nn.Module):
if
cls_reg_targets
is
None
:
if
cls_reg_targets
is
None
:
return
None
return
None
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
num_total_
samples
)
=
cls_reg_targets
num_total_
pos
,
num_total_neg
)
=
cls_reg_targets
losses_cls
,
losses_reg
=
multi_apply
(
losses_cls
,
losses_reg
=
multi_apply
(
self
.
loss_single
,
self
.
loss_single
,
rpn_cls_scores
,
rpn_cls_scores
,
...
@@ -169,7 +169,7 @@ class RPNHead(nn.Module):
...
@@ -169,7 +169,7 @@ class RPNHead(nn.Module):
label_weights_list
,
label_weights_list
,
bbox_targets_list
,
bbox_targets_list
,
bbox_weights_list
,
bbox_weights_list
,
num_total_samples
=
num_total_
samples
,
num_total_samples
=
num_total_
pos
+
num_total_neg
,
cfg
=
cfg
)
cfg
=
cfg
)
return
dict
(
loss_rpn_cls
=
losses_cls
,
loss_rpn_reg
=
losses_reg
)
return
dict
(
loss_rpn_cls
=
losses_cls
,
loss_rpn_reg
=
losses_reg
)
...
...
mmdet/models/single_stage_heads/retina_head.py
View file @
20e75c22
...
@@ -4,9 +4,9 @@ import numpy as np
...
@@ -4,9 +4,9 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmdet.core
import
(
AnchorGenerator
,
multi_apply
,
delta2bbox
,
from
mmdet.core
import
(
AnchorGenerator
,
anchor_target
,
multi_apply
,
weighted_smoothl1
,
weighted_sigmoid_focal_loss
,
delta2bbox
,
weighted_smoothl1
,
multiclass_nms
,
retina_target
)
weighted_sigmoid_focal_loss
,
multiclass_nms
)
from
..utils
import
normal_init
,
bias_init_with_prob
from
..utils
import
normal_init
,
bias_init_with_prob
...
@@ -172,20 +172,28 @@ class RetinaHead(nn.Module):
...
@@ -172,20 +172,28 @@ class RetinaHead(nn.Module):
avg_factor
=
num_pos_samples
)
avg_factor
=
num_pos_samples
)
return
loss_cls
,
loss_reg
return
loss_cls
,
loss_reg
def
loss
(
self
,
cls_scores
,
bbox_preds
,
gt_bboxes
,
gt_labels
,
img_
shape
s
,
def
loss
(
self
,
cls_scores
,
bbox_preds
,
gt_bboxes
,
gt_labels
,
img_
meta
s
,
cfg
):
cfg
):
featmap_sizes
=
[
featmap
.
size
()[
-
2
:]
for
featmap
in
cls_scores
]
featmap_sizes
=
[
featmap
.
size
()[
-
2
:]
for
featmap
in
cls_scores
]
assert
len
(
featmap_sizes
)
==
len
(
self
.
anchor_generators
)
assert
len
(
featmap_sizes
)
==
len
(
self
.
anchor_generators
)
anchor_list
,
valid_flag_list
=
self
.
get_anchors
(
anchor_list
,
valid_flag_list
=
self
.
get_anchors
(
featmap_sizes
,
img_shapes
)
featmap_sizes
,
img_metas
)
cls_reg_targets
=
retina_target
(
cls_reg_targets
=
anchor_target
(
anchor_list
,
valid_flag_list
,
gt_bboxes
,
gt_labels
,
img_shapes
,
anchor_list
,
self
.
target_means
,
self
.
target_stds
,
self
.
cls_out_channels
,
cfg
)
valid_flag_list
,
gt_bboxes
,
img_metas
,
self
.
target_means
,
self
.
target_stds
,
cfg
,
gt_labels_list
=
gt_labels
,
cls_out_channels
=
self
.
cls_out_channels
,
sampling
=
False
)
if
cls_reg_targets
is
None
:
if
cls_reg_targets
is
None
:
return
None
return
None
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
num_
pos_samples
)
=
cls_reg_targets
num_
total_pos
,
num_total_neg
)
=
cls_reg_targets
losses_cls
,
losses_reg
=
multi_apply
(
losses_cls
,
losses_reg
=
multi_apply
(
self
.
loss_single
,
self
.
loss_single
,
...
@@ -195,7 +203,7 @@ class RetinaHead(nn.Module):
...
@@ -195,7 +203,7 @@ class RetinaHead(nn.Module):
label_weights_list
,
label_weights_list
,
bbox_targets_list
,
bbox_targets_list
,
bbox_weights_list
,
bbox_weights_list
,
num_pos_samples
=
num_
pos_sample
s
,
num_pos_samples
=
num_
total_po
s
,
cfg
=
cfg
)
cfg
=
cfg
)
return
dict
(
loss_cls
=
losses_cls
,
loss_reg
=
losses_reg
)
return
dict
(
loss_cls
=
losses_cls
,
loss_reg
=
losses_reg
)
...
...
mmdet/models/utils/__init__.py
View file @
20e75c22
from
.conv_module
import
ConvModule
from
.conv_module
import
ConvModule
from
.norm
import
build_norm_layer
from
.norm
import
build_norm_layer
from
.weight_init
import
xavier_init
,
normal_init
,
uniform_init
,
kaiming_init
from
.weight_init
import
(
xavier_init
,
normal_init
,
uniform_init
,
kaiming_init
,
bias_init_with_prob
)
__all__
=
[
__all__
=
[
'ConvModule'
,
'build_norm_layer'
,
'xavier_init'
,
'normal_init'
,
'ConvModule'
,
'build_norm_layer'
,
'xavier_init'
,
'normal_init'
,
'uniform_init'
,
'kaiming_init'
'uniform_init'
,
'kaiming_init'
,
'bias_init_with_prob'
]
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment