Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
d5d25be2
Commit
d5d25be2
authored
Apr 15, 2019
by
Cao Yuhang
Browse files
fix bug in two stage and weighted_binary_cross_entropy_loss
parent
d67a2e16
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
10 deletions
+12
-10
mmdet/core/anchor/anchor_target.py
mmdet/core/anchor/anchor_target.py
+0
-10
mmdet/core/loss/losses.py
mmdet/core/loss/losses.py
+12
-0
No files found.
mmdet/core/anchor/anchor_target.py
View file @
d5d25be2
...
...
@@ -159,16 +159,6 @@ def anchor_target_single(flat_anchors,
neg_inds
)
def
expand_binary_labels
(
labels
,
label_weights
,
label_channels
):
bin_labels
=
labels
.
new_full
((
labels
.
size
(
0
),
label_channels
),
0
)
inds
=
torch
.
nonzero
(
labels
>=
1
).
squeeze
()
if
inds
.
numel
()
>
0
:
bin_labels
[
inds
,
labels
[
inds
]
-
1
]
=
1
bin_label_weights
=
label_weights
.
view
(
-
1
,
1
).
expand
(
label_weights
.
size
(
0
),
label_channels
)
return
bin_labels
,
bin_label_weights
def
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
img_shape
,
allowed_border
=
0
):
img_h
,
img_w
=
img_shape
[:
2
]
...
...
mmdet/core/loss/losses.py
View file @
d5d25be2
...
...
@@ -23,6 +23,8 @@ def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
def
weighted_binary_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
):
if
pred
.
dim
()
!=
label
.
dim
():
label
,
weight
=
_expand_binary_labels
(
label
,
weight
,
pred
.
size
(
-
1
))
if
avg_factor
is
None
:
avg_factor
=
max
(
torch
.
sum
(
weight
>
0
).
float
().
item
(),
1.
)
return
F
.
binary_cross_entropy_with_logits
(
...
...
@@ -115,3 +117,13 @@ def accuracy(pred, target, topk=1):
correct_k
=
correct
[:
k
].
view
(
-
1
).
float
().
sum
(
0
,
keepdim
=
True
)
res
.
append
(
correct_k
.
mul_
(
100.0
/
pred
.
size
(
0
)))
return
res
[
0
]
if
return_single
else
res
def
_expand_binary_labels
(
labels
,
label_weights
,
label_channels
):
bin_labels
=
labels
.
new_full
((
labels
.
size
(
0
),
label_channels
),
0
)
inds
=
torch
.
nonzero
(
labels
>=
1
).
squeeze
()
if
inds
.
numel
()
>
0
:
bin_labels
[
inds
,
labels
[
inds
]
-
1
]
=
1
bin_label_weights
=
label_weights
.
view
(
-
1
,
1
).
expand
(
label_weights
.
size
(
0
),
label_channels
)
return
bin_labels
,
bin_label_weights
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment