Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
2df1e0a0
Commit
2df1e0a0
authored
Jan 13, 2019
by
Kai Chen
Browse files
bug fix for using softmax
parent
70700512
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
28 deletions
+31
-28
mmdet/core/anchor/anchor_target.py
mmdet/core/anchor/anchor_target.py
+9
-9
mmdet/models/anchor_heads/anchor_head.py
mmdet/models/anchor_heads/anchor_head.py
+18
-17
mmdet/models/anchor_heads/rpn_head.py
mmdet/models/anchor_heads/rpn_head.py
+4
-2
No files found.
mmdet/core/anchor/anchor_target.py
View file @
2df1e0a0
...
...
@@ -12,7 +12,7 @@ def anchor_target(anchor_list,
target_stds
,
cfg
,
gt_labels_list
=
None
,
cls_out
_channels
=
1
,
label
_channels
=
1
,
sampling
=
True
,
unmap_outputs
=
True
):
"""Compute regression and classification targets for anchors.
...
...
@@ -54,7 +54,7 @@ def anchor_target(anchor_list,
target_means
=
target_means
,
target_stds
=
target_stds
,
cfg
=
cfg
,
cls_out
_channels
=
cls_out
_channels
,
label
_channels
=
label
_channels
,
sampling
=
sampling
,
unmap_outputs
=
unmap_outputs
)
# no valid anchors
...
...
@@ -95,7 +95,7 @@ def anchor_target_single(flat_anchors,
target_means
,
target_stds
,
cfg
,
cls_out
_channels
=
1
,
label
_channels
=
1
,
sampling
=
True
,
unmap_outputs
=
True
):
inside_flags
=
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
...
...
@@ -147,9 +147,9 @@ def anchor_target_single(flat_anchors,
num_total_anchors
=
flat_anchors
.
size
(
0
)
labels
=
unmap
(
labels
,
num_total_anchors
,
inside_flags
)
label_weights
=
unmap
(
label_weights
,
num_total_anchors
,
inside_flags
)
if
cls_out
_channels
>
1
:
labels
,
label_weights
=
expand_binary_labels
(
labels
,
label_weights
,
cls_out
_channels
)
if
label
_channels
>
1
:
labels
,
label_weights
=
expand_binary_labels
(
labels
,
label_weights
,
label
_channels
)
bbox_targets
=
unmap
(
bbox_targets
,
num_total_anchors
,
inside_flags
)
bbox_weights
=
unmap
(
bbox_weights
,
num_total_anchors
,
inside_flags
)
...
...
@@ -157,14 +157,14 @@ def anchor_target_single(flat_anchors,
neg_inds
)
def
expand_binary_labels
(
labels
,
label_weights
,
cls_out
_channels
):
def
expand_binary_labels
(
labels
,
label_weights
,
label
_channels
):
bin_labels
=
labels
.
new_full
(
(
labels
.
size
(
0
),
cls_out
_channels
),
0
,
dtype
=
torch
.
float32
)
(
labels
.
size
(
0
),
label
_channels
),
0
,
dtype
=
torch
.
float32
)
inds
=
torch
.
nonzero
(
labels
>=
1
).
squeeze
()
if
inds
.
numel
()
>
0
:
bin_labels
[
inds
,
labels
[
inds
]
-
1
]
=
1
bin_label_weights
=
label_weights
.
view
(
-
1
,
1
).
expand
(
label_weights
.
size
(
0
),
cls_out
_channels
)
label_weights
.
size
(
0
),
label
_channels
)
return
bin_labels
,
bin_label_weights
...
...
mmdet/models/anchor_heads/anchor_head.py
View file @
2df1e0a0
...
...
@@ -14,13 +14,9 @@ from ..utils import normal_init
class
AnchorHead
(
nn
.
Module
):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
/ - conv_cls (1x1 conv)
input - rpn_conv (3x3 conv) -
\ - conv_reg (1x1 conv)
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels f
or
the
RPN
feature map.
feat_channels (int): Number of channels
o
f the feature map.
anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
...
...
@@ -29,6 +25,7 @@ class AnchorHead(nn.Module):
target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
(softmax by default)
use_focal_loss (bool): Whether to use focal loss for classification.
"""
# noqa: W605
def
__init__
(
self
,
...
...
@@ -80,9 +77,9 @@ class AnchorHead(nn.Module):
normal_init
(
self
.
conv_reg
,
std
=
0.01
)
def
forward_single
(
self
,
x
):
rpn_
cls_score
=
self
.
conv_cls
(
x
)
rpn_
bbox_pred
=
self
.
conv_reg
(
x
)
return
rpn_
cls_score
,
rpn_
bbox_pred
cls_score
=
self
.
conv_cls
(
x
)
bbox_pred
=
self
.
conv_reg
(
x
)
return
cls_score
,
bbox_pred
def
forward
(
self
,
feats
):
return
multi_apply
(
self
.
forward_single
,
feats
)
...
...
@@ -129,10 +126,13 @@ class AnchorHead(nn.Module):
def
loss_single
(
self
,
cls_score
,
bbox_pred
,
labels
,
label_weights
,
bbox_targets
,
bbox_weights
,
num_total_samples
,
cfg
):
# classification loss
labels
=
labels
.
contiguous
().
view
(
-
1
,
self
.
cls_out_channels
)
label_weights
=
label_weights
.
contiguous
().
view
(
-
1
,
self
.
cls_out_channels
)
cls_score
=
cls_score
.
permute
(
0
,
2
,
3
,
1
).
contiguous
().
view
(
if
self
.
use_sigmoid_cls
:
labels
=
labels
.
reshape
(
-
1
,
self
.
cls_out_channels
)
label_weights
=
label_weights
.
reshape
(
-
1
,
self
.
cls_out_channels
)
else
:
labels
=
labels
.
reshape
(
-
1
)
label_weights
=
label_weights
.
reshape
(
-
1
)
cls_score
=
cls_score
.
permute
(
0
,
2
,
3
,
1
).
reshape
(
-
1
,
self
.
cls_out_channels
)
if
self
.
use_sigmoid_cls
:
if
self
.
use_focal_loss
:
...
...
@@ -156,9 +156,9 @@ class AnchorHead(nn.Module):
loss_cls
=
cls_criterion
(
cls_score
,
labels
,
label_weights
,
avg_factor
=
num_total_samples
)
# regression loss
bbox_targets
=
bbox_targets
.
contiguous
().
view
(
-
1
,
4
)
bbox_weights
=
bbox_weights
.
contiguous
().
view
(
-
1
,
4
)
bbox_pred
=
bbox_pred
.
permute
(
0
,
2
,
3
,
1
).
contiguous
().
view
(
-
1
,
4
)
bbox_targets
=
bbox_targets
.
reshape
(
-
1
,
4
)
bbox_weights
=
bbox_weights
.
reshape
(
-
1
,
4
)
bbox_pred
=
bbox_pred
.
permute
(
0
,
2
,
3
,
1
).
reshape
(
-
1
,
4
)
loss_reg
=
weighted_smoothl1
(
bbox_pred
,
bbox_targets
,
...
...
@@ -175,6 +175,7 @@ class AnchorHead(nn.Module):
anchor_list
,
valid_flag_list
=
self
.
get_anchors
(
featmap_sizes
,
img_metas
)
sampling
=
False
if
self
.
use_focal_loss
else
True
label_channels
=
self
.
cls_out_channels
if
self
.
use_sigmoid_cls
else
1
cls_reg_targets
=
anchor_target
(
anchor_list
,
valid_flag_list
,
...
...
@@ -184,7 +185,7 @@ class AnchorHead(nn.Module):
self
.
target_stds
,
cfg
,
gt_labels_list
=
gt_labels
,
cls_out
_channels
=
self
.
cls_out
_channels
,
label
_channels
=
label
_channels
,
sampling
=
sampling
)
if
cls_reg_targets
is
None
:
return
None
...
...
@@ -202,7 +203,7 @@ class AnchorHead(nn.Module):
bbox_weights_list
,
num_total_samples
=
num_total_samples
,
cfg
=
cfg
)
return
dict
(
loss_
rpn_
cls
=
losses_cls
,
loss_
rpn_
reg
=
losses_reg
)
return
dict
(
loss_cls
=
losses_cls
,
loss_reg
=
losses_reg
)
def
get_bboxes
(
self
,
cls_scores
,
bbox_preds
,
img_metas
,
cfg
,
rescale
=
False
):
...
...
mmdet/models/anchor_heads/rpn_head.py
View file @
2df1e0a0
...
...
@@ -33,8 +33,10 @@ class RPNHead(AnchorHead):
return
rpn_cls_score
,
rpn_bbox_pred
def
loss
(
self
,
cls_scores
,
bbox_preds
,
gt_bboxes
,
img_metas
,
cfg
):
return
super
(
RPNHead
,
self
).
loss
(
cls_scores
,
bbox_preds
,
gt_bboxes
,
None
,
img_metas
,
cfg
)
losses
=
super
(
RPNHead
,
self
).
loss
(
cls_scores
,
bbox_preds
,
gt_bboxes
,
None
,
img_metas
,
cfg
)
return
dict
(
loss_rpn_cls
=
losses
[
'loss_cls'
],
loss_rpn_reg
=
losses
[
'loss_reg'
])
def
get_bboxes_single
(
self
,
cls_scores
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment