Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
52310ad9
Commit
52310ad9
authored
Nov 24, 2021
by
Shaoshuai Shi
Browse files
support CenterHead / CenterPoint (1stage), add its WOD config
parent
0ccbbaae
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
132 additions
and
2 deletions
+132
-2
pcdet/models/dense_heads/__init__.py
pcdet/models/dense_heads/__init__.py
+2
-0
pcdet/models/detectors/__init__.py
pcdet/models/detectors/__init__.py
+3
-1
pcdet/models/detectors/detector3d_template.py
pcdet/models/detectors/detector3d_template.py
+2
-1
pcdet/utils/loss_utils.py
pcdet/utils/loss_utils.py
+125
-0
No files found.
pcdet/models/dense_heads/__init__.py
View file @
52310ad9
...
@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate
...
@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate
from
.point_head_box
import
PointHeadBox
from
.point_head_box
import
PointHeadBox
from
.point_head_simple
import
PointHeadSimple
from
.point_head_simple
import
PointHeadSimple
from
.point_intra_part_head
import
PointIntraPartOffsetHead
from
.point_intra_part_head
import
PointIntraPartOffsetHead
from
.center_head
import
CenterHead
__all__
=
{
__all__
=
{
'AnchorHeadTemplate'
:
AnchorHeadTemplate
,
'AnchorHeadTemplate'
:
AnchorHeadTemplate
,
...
@@ -12,4 +13,5 @@ __all__ = {
...
@@ -12,4 +13,5 @@ __all__ = {
'PointHeadSimple'
:
PointHeadSimple
,
'PointHeadSimple'
:
PointHeadSimple
,
'PointHeadBox'
:
PointHeadBox
,
'PointHeadBox'
:
PointHeadBox
,
'AnchorHeadMulti'
:
AnchorHeadMulti
,
'AnchorHeadMulti'
:
AnchorHeadMulti
,
'CenterHead'
:
CenterHead
}
}
pcdet/models/detectors/__init__.py
View file @
52310ad9
...
@@ -7,6 +7,7 @@ from .second_net import SECONDNet
...
@@ -7,6 +7,7 @@ from .second_net import SECONDNet
from
.second_net_iou
import
SECONDNetIoU
from
.second_net_iou
import
SECONDNetIoU
from
.caddn
import
CaDDN
from
.caddn
import
CaDDN
from
.voxel_rcnn
import
VoxelRCNN
from
.voxel_rcnn
import
VoxelRCNN
from
.centerpoint
import
CenterPoint
__all__
=
{
__all__
=
{
'Detector3DTemplate'
:
Detector3DTemplate
,
'Detector3DTemplate'
:
Detector3DTemplate
,
...
@@ -17,7 +18,8 @@ __all__ = {
...
@@ -17,7 +18,8 @@ __all__ = {
'PointRCNN'
:
PointRCNN
,
'PointRCNN'
:
PointRCNN
,
'SECONDNetIoU'
:
SECONDNetIoU
,
'SECONDNetIoU'
:
SECONDNetIoU
,
'CaDDN'
:
CaDDN
,
'CaDDN'
:
CaDDN
,
'VoxelRCNN'
:
VoxelRCNN
'VoxelRCNN'
:
VoxelRCNN
,
'CenterPoint'
:
CenterPoint
}
}
...
...
pcdet/models/detectors/detector3d_template.py
View file @
52310ad9
...
@@ -132,7 +132,8 @@ class Detector3DTemplate(nn.Module):
...
@@ -132,7 +132,8 @@ class Detector3DTemplate(nn.Module):
class_names
=
self
.
class_names
,
class_names
=
self
.
class_names
,
grid_size
=
model_info_dict
[
'grid_size'
],
grid_size
=
model_info_dict
[
'grid_size'
],
point_cloud_range
=
model_info_dict
[
'point_cloud_range'
],
point_cloud_range
=
model_info_dict
[
'point_cloud_range'
],
predict_boxes_when_training
=
self
.
model_cfg
.
get
(
'ROI_HEAD'
,
False
)
predict_boxes_when_training
=
self
.
model_cfg
.
get
(
'ROI_HEAD'
,
False
),
voxel_size
=
model_info_dict
.
get
(
'voxel_size'
,
False
)
)
)
model_info_dict
[
'module_list'
].
append
(
dense_head_module
)
model_info_dict
[
'module_list'
].
append
(
dense_head_module
)
return
dense_head_module
,
model_info_dict
return
dense_head_module
,
model_info_dict
...
...
pcdet/utils/loss_utils.py
View file @
52310ad9
...
@@ -259,3 +259,128 @@ def compute_fg_mask(gt_boxes2d, shape, downsample_factor=1, device=torch.device(
...
@@ -259,3 +259,128 @@ def compute_fg_mask(gt_boxes2d, shape, downsample_factor=1, device=torch.device(
fg_mask
[
b
,
v1
:
v2
,
u1
:
u2
]
=
True
fg_mask
[
b
,
v1
:
v2
,
u1
:
u2
]
=
True
return
fg_mask
return
fg_mask
def
neg_loss_cornernet
(
pred
,
gt
,
mask
=
None
):
"""
Refer to https://github.com/tianweiy/CenterPoint.
Modified focal loss. Exactly the same as CornerNet. Runs faster and costs a little bit more memory
Args:
pred: (batch x c x h x w)
gt: (batch x c x h x w)
mask: (batch x h x w)
Returns:
"""
pos_inds
=
gt
.
eq
(
1
).
float
()
neg_inds
=
gt
.
lt
(
1
).
float
()
neg_weights
=
torch
.
pow
(
1
-
gt
,
4
)
loss
=
0
pos_loss
=
torch
.
log
(
pred
)
*
torch
.
pow
(
1
-
pred
,
2
)
*
pos_inds
neg_loss
=
torch
.
log
(
1
-
pred
)
*
torch
.
pow
(
pred
,
2
)
*
neg_weights
*
neg_inds
if
mask
is
not
None
:
mask
=
mask
[:,
None
,
:,
:].
float
()
pos_loss
=
pos_loss
*
mask
neg_loss
=
neg_loss
*
mask
num_pos
=
(
pos_inds
.
float
()
*
mask
).
sum
()
else
:
num_pos
=
pos_inds
.
float
().
sum
()
pos_loss
=
pos_loss
.
sum
()
neg_loss
=
neg_loss
.
sum
()
if
num_pos
==
0
:
loss
=
loss
-
neg_loss
else
:
loss
=
loss
-
(
pos_loss
+
neg_loss
)
/
num_pos
return
loss
class
FocalLossCenterNet
(
nn
.
Module
):
"""
Refer to https://github.com/tianweiy/CenterPoint
"""
def
__init__
(
self
):
super
(
FocalLossCenterNet
,
self
).
__init__
()
self
.
neg_loss
=
neg_loss_cornernet
def
forward
(
self
,
out
,
target
,
mask
=
None
):
return
self
.
neg_loss
(
out
,
target
,
mask
=
mask
)
def
_reg_loss
(
regr
,
gt_regr
,
mask
):
"""
Refer to https://github.com/tianweiy/CenterPoint
L1 regression loss
Args:
regr (batch x max_objects x dim)
gt_regr (batch x max_objects x dim)
mask (batch x max_objects)
Returns:
"""
num
=
mask
.
float
().
sum
()
mask
=
mask
.
unsqueeze
(
2
).
expand_as
(
gt_regr
).
float
()
isnotnan
=
(
~
torch
.
isnan
(
gt_regr
)).
float
()
mask
*=
isnotnan
regr
=
regr
*
mask
gt_regr
=
gt_regr
*
mask
loss
=
torch
.
abs
(
regr
-
gt_regr
)
loss
=
loss
.
transpose
(
2
,
0
)
loss
=
torch
.
sum
(
loss
,
dim
=
2
)
loss
=
torch
.
sum
(
loss
,
dim
=
1
)
# else:
# # D x M x B
# loss = loss.reshape(loss.shape[0], -1)
# loss = loss / (num + 1e-4)
loss
=
loss
/
torch
.
clamp_min
(
num
,
min
=
1.0
)
# import pdb; pdb.set_trace()
return
loss
def
_gather_feat
(
feat
,
ind
,
mask
=
None
):
dim
=
feat
.
size
(
2
)
ind
=
ind
.
unsqueeze
(
2
).
expand
(
ind
.
size
(
0
),
ind
.
size
(
1
),
dim
)
feat
=
feat
.
gather
(
1
,
ind
)
if
mask
is
not
None
:
mask
=
mask
.
unsqueeze
(
2
).
expand_as
(
feat
)
feat
=
feat
[
mask
]
feat
=
feat
.
view
(
-
1
,
dim
)
return
feat
def
_transpose_and_gather_feat
(
feat
,
ind
):
feat
=
feat
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
feat
=
feat
.
view
(
feat
.
size
(
0
),
-
1
,
feat
.
size
(
3
))
feat
=
_gather_feat
(
feat
,
ind
)
return
feat
class
RegLossCenterNet
(
nn
.
Module
):
"""
Refer to https://github.com/tianweiy/CenterPoint
"""
def
__init__
(
self
):
super
(
RegLossCenterNet
,
self
).
__init__
()
def
forward
(
self
,
output
,
mask
,
ind
=
None
,
target
=
None
):
"""
Args:
output: (batch x dim x h x w) or (batch x max_objects)
mask: (batch x max_objects)
ind: (batch x max_objects)
target: (batch x max_objects x dim)
Returns:
"""
if
ind
is
None
:
pred
=
output
else
:
pred
=
_transpose_and_gather_feat
(
output
,
ind
)
loss
=
_reg_loss
(
pred
,
target
,
mask
)
return
loss
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment