Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
f2cfa86b
Unverified
Commit
f2cfa86b
authored
Feb 09, 2019
by
Kai Chen
Committed by
GitHub
Feb 09, 2019
Browse files
Merge pull request #310 from hellock/master
Bug fix for retinanet with 2 classes (fg/bg)
parents
ba73bcc5
d1cf5e59
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
4 deletions
+3
-4
mmdet/core/anchor/anchor_target.py
mmdet/core/anchor/anchor_target.py
+1
-2
mmdet/core/loss/losses.py
mmdet/core/loss/losses.py
+2
-2
No files found.
mmdet/core/anchor/anchor_target.py
View file @
f2cfa86b
...
@@ -158,8 +158,7 @@ def anchor_target_single(flat_anchors,
...
@@ -158,8 +158,7 @@ def anchor_target_single(flat_anchors,
def
expand_binary_labels
(
labels
,
label_weights
,
label_channels
):
def
expand_binary_labels
(
labels
,
label_weights
,
label_channels
):
bin_labels
=
labels
.
new_full
(
bin_labels
=
labels
.
new_full
((
labels
.
size
(
0
),
label_channels
),
0
)
(
labels
.
size
(
0
),
label_channels
),
0
,
dtype
=
torch
.
float32
)
inds
=
torch
.
nonzero
(
labels
>=
1
).
squeeze
()
inds
=
torch
.
nonzero
(
labels
>=
1
).
squeeze
()
if
inds
.
numel
()
>
0
:
if
inds
.
numel
()
>
0
:
bin_labels
[
inds
,
labels
[
inds
]
-
1
]
=
1
bin_labels
[
inds
,
labels
[
inds
]
-
1
]
=
1
...
...
mmdet/core/loss/losses.py
View file @
f2cfa86b
...
@@ -10,8 +10,7 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
...
@@ -10,8 +10,7 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
def
weighted_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
,
def
weighted_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
,
reduce
=
True
):
reduce
=
True
):
if
avg_factor
is
None
:
if
avg_factor
is
None
:
avg_factor
=
max
(
torch
.
sum
(
weight
>
0
).
float
().
item
(),
1.
)
avg_factor
=
max
(
torch
.
sum
(
weight
>
0
).
float
().
item
(),
1.
)
raw
=
F
.
cross_entropy
(
pred
,
label
,
reduction
=
'none'
)
raw
=
F
.
cross_entropy
(
pred
,
label
,
reduction
=
'none'
)
...
@@ -36,6 +35,7 @@ def sigmoid_focal_loss(pred,
...
@@ -36,6 +35,7 @@ def sigmoid_focal_loss(pred,
alpha
=
0.25
,
alpha
=
0.25
,
reduction
=
'mean'
):
reduction
=
'mean'
):
pred_sigmoid
=
pred
.
sigmoid
()
pred_sigmoid
=
pred
.
sigmoid
()
target
=
target
.
type_as
(
pred
)
pt
=
(
1
-
pred_sigmoid
)
*
target
+
pred_sigmoid
*
(
1
-
target
)
pt
=
(
1
-
pred_sigmoid
)
*
target
+
pred_sigmoid
*
(
1
-
target
)
weight
=
(
alpha
*
target
+
(
1
-
alpha
)
*
(
1
-
target
))
*
weight
weight
=
(
alpha
*
target
+
(
1
-
alpha
)
*
(
1
-
target
))
*
weight
weight
=
weight
*
pt
.
pow
(
gamma
)
weight
=
weight
*
pt
.
pow
(
gamma
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment