Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
d0489cf6
Commit
d0489cf6
authored
Jul 06, 2020
by
Shaoshuai Shi
Browse files
support WeightedL1Loss in anchor_head_template
parent
04a73eda
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
1 deletion
+45
-1
pcdet/models/dense_heads/anchor_head_template.py
pcdet/models/dense_heads/anchor_head_template.py
+3
-1
pcdet/utils/loss_utils.py
pcdet/utils/loss_utils.py
+42
-0
No files found.
pcdet/models/dense_heads/anchor_head_template.py
View file @
d0489cf6
...
...
@@ -74,9 +74,11 @@ class AnchorHeadTemplate(nn.Module):
'cls_loss_func'
,
loss_utils
.
SigmoidFocalClassificationLoss
(
alpha
=
0.25
,
gamma
=
2.0
)
)
reg_loss_name
=
'WeightedSmoothL1Loss'
if
losses_cfg
.
get
(
'REG_LOSS_TYPE'
,
None
)
is
None
\
else
losses_cfg
.
REG_LOSS_TYPE
self
.
add_module
(
'reg_loss_func'
,
loss_utils
.
WeightedSmoothL1Loss
(
code_weights
=
losses_cfg
.
LOSS_WEIGHTS
[
'code_weights'
])
getattr
(
loss_utils
,
reg_loss_name
)
(
code_weights
=
losses_cfg
.
LOSS_WEIGHTS
[
'code_weights'
])
)
self
.
add_module
(
'dir_loss_func'
,
...
...
pcdet/utils/loss_utils.py
View file @
d0489cf6
...
...
@@ -135,6 +135,48 @@ class WeightedSmoothL1Loss(nn.Module):
return
loss
class
WeightedL1Loss
(
nn
.
Module
):
def
__init__
(
self
,
code_weights
:
list
=
None
):
"""
Args:
code_weights: (#codes) float list if not None.
Code-wise weights.
"""
super
(
WeightedL1Loss
,
self
).
__init__
()
if
code_weights
is
not
None
:
self
.
code_weights
=
np
.
array
(
code_weights
,
dtype
=
np
.
float32
)
self
.
code_weights
=
torch
.
from_numpy
(
self
.
code_weights
).
cuda
()
def
forward
(
self
,
input
:
torch
.
Tensor
,
target
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
=
None
):
"""
Args:
input: (B, #anchors, #codes) float tensor.
Ecoded predicted locations of objects.
target: (B, #anchors, #codes) float tensor.
Regression targets.
weights: (B, #anchors) float tensor if not None.
Returns:
loss: (B, #anchors) float tensor.
Weighted smooth l1 loss without reduction.
"""
target
=
torch
.
where
(
torch
.
isnan
(
target
),
input
,
target
)
# ignore nan targets
diff
=
input
-
target
# code-wise weighting
if
self
.
code_weights
is
not
None
:
diff
=
diff
*
self
.
code_weights
.
view
(
1
,
1
,
-
1
)
loss
=
torch
.
abs
(
diff
)
# anchor-wise weighting
if
weights
is
not
None
:
assert
weights
.
shape
[
0
]
==
loss
.
shape
[
0
]
and
weights
.
shape
[
1
]
==
loss
.
shape
[
1
]
loss
=
loss
*
weights
.
unsqueeze
(
-
1
)
return
loss
class
WeightedCrossEntropyLoss
(
nn
.
Module
):
"""
Transform input to fit the fomation of PyTorch offical cross entropy loss
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment