Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
8a64de5d
Commit
8a64de5d
authored
May 08, 2023
by
chenshi3
Browse files
Add support for BEVFusion
parent
c5dfdd71
Changes
20
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2444 additions
and
5 deletions
+2444
-5
pcdet/models/backbones_2d/fuser/__init__.py
pcdet/models/backbones_2d/fuser/__init__.py
+4
-0
pcdet/models/backbones_2d/fuser/convfuser.py
pcdet/models/backbones_2d/fuser/convfuser.py
+33
-0
pcdet/models/backbones_image/__init__.py
pcdet/models/backbones_image/__init__.py
+4
-0
pcdet/models/backbones_image/img_neck/__init__.py
pcdet/models/backbones_image/img_neck/__init__.py
+4
-0
pcdet/models/backbones_image/img_neck/generalized_lss.py
pcdet/models/backbones_image/img_neck/generalized_lss.py
+76
-0
pcdet/models/backbones_image/swin.py
pcdet/models/backbones_image/swin.py
+736
-0
pcdet/models/detectors/__init__.py
pcdet/models/detectors/__init__.py
+2
-0
pcdet/models/detectors/bevfusion.py
pcdet/models/detectors/bevfusion.py
+101
-0
pcdet/models/model_utils/swin_utils.py
pcdet/models/model_utils/swin_utils.py
+659
-0
pcdet/models/view_transforms/__init__.py
pcdet/models/view_transforms/__init__.py
+4
-0
pcdet/models/view_transforms/depth_lss.py
pcdet/models/view_transforms/depth_lss.py
+258
-0
pcdet/ops/bev_pool/__init__.py
pcdet/ops/bev_pool/__init__.py
+1
-0
pcdet/ops/bev_pool/bev_pool.py
pcdet/ops/bev_pool/bev_pool.py
+97
-0
pcdet/ops/bev_pool/src/bev_pool.cpp
pcdet/ops/bev_pool/src/bev_pool.cpp
+94
-0
pcdet/ops/bev_pool/src/bev_pool_cuda.cu
pcdet/ops/bev_pool/src/bev_pool_cuda.cu
+98
-0
setup.py
setup.py
+8
-0
tools/cfgs/nuscenes_models/cbgs_bevfusion.yaml
tools/cfgs/nuscenes_models/cbgs_bevfusion.yaml
+208
-0
tools/train_utils/optimization/__init__.py
tools/train_utils/optimization/__init__.py
+6
-2
tools/train_utils/optimization/learning_schedules_fastai.py
tools/train_utils/optimization/learning_schedules_fastai.py
+50
-2
tools/train_utils/train_utils.py
tools/train_utils/train_utils.py
+1
-1
No files found.
pcdet/models/backbones_2d/fuser/__init__.py
0 → 100644
View file @
8a64de5d
# Registry of BEV multi-modal fusers; detector configs pick an entry by FUSER.NAME.
from .convfuser import ConvFuser

# NOTE: a dict (name -> class), not the usual list form of __all__ — this is the
# OpenPCDet registry convention used by the model builders.
__all__ = {
    'ConvFuser': ConvFuser
}
\ No newline at end of file
pcdet/models/backbones_2d/fuser/convfuser.py
0 → 100644
View file @
8a64de5d
import
torch
from
torch
import
nn
class ConvFuser(nn.Module):
    """Fuse camera and lidar BEV feature maps with one 3x3 conv block.

    The two BEV tensors are concatenated along the channel axis and projected
    to ``OUT_CHANNEL`` channels via Conv2d -> BatchNorm2d -> ReLU.
    """

    def __init__(self, model_cfg) -> None:
        super().__init__()
        self.model_cfg = model_cfg
        # IN_CHANNEL must equal the channel sum of both BEV inputs.
        self.conv = nn.Sequential(
            nn.Conv2d(
                self.model_cfg.IN_CHANNEL,
                self.model_cfg.OUT_CHANNEL,
                3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(self.model_cfg.OUT_CHANNEL),
            nn.ReLU(True),
        )

    def forward(self, batch_dict):
        """
        Args:
            batch_dict:
                spatial_features_img (tensor): BEV features from the image modality
                spatial_features (tensor): BEV features from the lidar modality
        Returns:
            batch_dict:
                spatial_features (tensor): BEV features after multi-modal fusion
        """
        stacked = torch.cat(
            [batch_dict['spatial_features_img'], batch_dict['spatial_features']],
            dim=1,
        )
        batch_dict['spatial_features'] = self.conv(stacked)
        return batch_dict
\ No newline at end of file
pcdet/models/backbones_image/__init__.py
0 → 100644
View file @
8a64de5d
# Registry of image backbones; configs pick an entry by IMAGE_BACKBONE.NAME.
from .swin import SwinTransformer

# Dict registry (name -> class), matching the OpenPCDet builder convention.
__all__ = {
    'SwinTransformer': SwinTransformer,
}
\ No newline at end of file
pcdet/models/backbones_image/img_neck/__init__.py
0 → 100644
View file @
8a64de5d
# Registry of image necks; configs pick an entry by NECK.NAME.
from .generalized_lss import GeneralizedLSSFPN

# Dict registry (name -> class), matching the OpenPCDet builder convention.
__all__ = {
    'GeneralizedLSSFPN': GeneralizedLSSFPN,
}
\ No newline at end of file
pcdet/models/backbones_image/img_neck/generalized_lss.py
0 → 100644
View file @
8a64de5d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
...model_utils.basic_block_2d
import
BasicBlock2D
class GeneralizedLSSFPN(nn.Module):
    """
    This module implements FPN, which creates pyramid features built on top of some input feature maps.
    This code is adapted from https://github.com/open-mmlab/mmdetection/blob/main/mmdet/models/necks/fpn.py with minimal modifications.
    """

    def __init__(self, model_cfg):
        """
        Args:
            model_cfg: config with IN_CHANNELS (list[int], per backbone stage),
                OUT_CHANNELS (int), NUM_OUTS (int), START_LEVEL (int), END_LEVEL (int).
        """
        super().__init__()
        self.model_cfg = model_cfg
        in_channels = self.model_cfg.IN_CHANNELS
        out_channels = self.model_cfg.OUT_CHANNELS
        num_ins = len(in_channels)
        num_outs = self.model_cfg.NUM_OUTS
        start_level = self.model_cfg.START_LEVEL
        end_level = self.model_cfg.END_LEVEL
        self.in_channels = in_channels

        if end_level == -1:
            # -1 means: stop one below the last stage (the last stage is only
            # used as the top-down upsampling source, never given its own conv).
            self.backbone_end_level = num_ins - 1
        else:
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        for i in range(self.start_level, self.backbone_end_level):
            # Lateral 1x1 conv consumes the current level concatenated with the
            # upsampled level above it: the topmost level concatenates the raw
            # backbone stage (in_channels[i + 1]); lower levels concatenate an
            # already-projected map (out_channels).
            l_conv = BasicBlock2D(
                in_channels[i] + (in_channels[i + 1] if i == self.backbone_end_level - 1 else out_channels),
                out_channels,
                kernel_size=1,
                bias=False
            )
            # 3x3 smoothing conv applied after the lateral fusion.
            fpn_conv = BasicBlock2D(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, batch_dict):
        """
        Args:
            batch_dict:
                image_features (list[tensor]): Multi-stage features from image backbone.
        Returns:
            batch_dict:
                image_fpn (list(tensor)): FPN features.
        """
        # upsample -> cat -> conv1x1 -> conv3x3
        inputs = batch_dict['image_features']
        assert len(inputs) == len(self.in_channels)

        # build laterals
        # NOTE(review): `inputs[i + self.start_level]` for i over the full input
        # range would index out of bounds when start_level > 0 — this appears to
        # assume START_LEVEL == 0; confirm against the configs that use this neck.
        laterals = [inputs[i + self.start_level] for i in range(len(inputs))]

        # build top-down path: walk from the second-highest level down to 0,
        # each time upsampling the level above to the current resolution.
        used_backbone_levels = len(laterals) - 1
        for i in range(used_backbone_levels - 1, -1, -1):
            x = F.interpolate(
                laterals[i + 1],
                size=laterals[i].shape[2:],
                mode='bilinear',
                align_corners=False,
            )
            # concat instead of the classic FPN elementwise add ("generalized" LSS-FPN)
            laterals[i] = torch.cat([laterals[i], x], dim=1)
            laterals[i] = self.lateral_convs[i](laterals[i])
            laterals[i] = self.fpn_convs[i](laterals[i])

        # build outputs (the raw top level itself is not emitted)
        outs = [laterals[i] for i in range(used_backbone_levels)]
        batch_dict['image_fpn'] = tuple(outs)
        return batch_dict
pcdet/models/backbones_image/swin.py
0 → 100644
View file @
8a64de5d
This diff is collapsed.
Click to expand it.
pcdet/models/detectors/__init__.py
View file @
8a64de5d
...
...
@@ -14,6 +14,7 @@ from .mppnet_e2e import MPPNetE2E
from
.pillarnet
import
PillarNet
from
.voxelnext
import
VoxelNeXt
from
.transfusion
import
TransFusion
from
.bevfusion
import
BevFusion
__all__
=
{
'Detector3DTemplate'
:
Detector3DTemplate
,
...
...
@@ -33,6 +34,7 @@ __all__ = {
'PillarNet'
:
PillarNet
,
'VoxelNeXt'
:
VoxelNeXt
,
'TransFusion'
:
TransFusion
,
'BevFusion'
:
BevFusion
,
}
...
...
pcdet/models/detectors/bevfusion.py
0 → 100644
View file @
8a64de5d
from
.detector3d_template
import
Detector3DTemplate
from
..
import
backbones_image
,
view_transforms
from
..backbones_image
import
img_neck
from
..backbones_2d
import
fuser
class BevFusion(Detector3DTemplate):
    """BEVFusion detector: fuses camera BEV features (Swin -> FPN -> LSS view
    transform) with lidar BEV features, then runs the standard 2D backbone and
    heads. Module construction is driven by ``module_topology``: the template's
    ``build_networks`` looks up a ``build_<name>`` method for each entry, so the
    method names below must match the topology strings exactly.
    """

    def __init__(self, model_cfg, num_class, dataset):
        super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
        # Order matters: camera branch (image_backbone -> neck -> vtransform)
        # must run before 'fuser', which consumes both modalities' BEV maps.
        self.module_topology = [
            'vfe', 'backbone_3d', 'map_to_bev_module', 'pfe',
            'image_backbone', 'neck', 'vtransform', 'fuser',
            'backbone_2d', 'dense_head', 'point_head', 'roi_head'
        ]
        self.module_list = self.build_networks()

    def build_neck(self, model_info_dict):
        """Build the image FPN neck from cfg.NECK, if configured."""
        if self.model_cfg.get('NECK', None) is None:
            return None, model_info_dict
        neck_module = img_neck.__all__[self.model_cfg.NECK.NAME](
            model_cfg=self.model_cfg.NECK
        )
        model_info_dict['module_list'].append(neck_module)
        return neck_module, model_info_dict

    def build_vtransform(self, model_info_dict):
        """Build the camera-to-BEV view transform from cfg.VTRANSFORM, if configured."""
        if self.model_cfg.get('VTRANSFORM', None) is None:
            return None, model_info_dict
        vtransform_module = view_transforms.__all__[self.model_cfg.VTRANSFORM.NAME](
            model_cfg=self.model_cfg.VTRANSFORM
        )
        model_info_dict['module_list'].append(vtransform_module)
        return vtransform_module, model_info_dict

    def build_image_backbone(self, model_info_dict):
        """Build the image backbone from cfg.IMAGE_BACKBONE, if configured."""
        if self.model_cfg.get('IMAGE_BACKBONE', None) is None:
            return None, model_info_dict
        image_backbone_module = backbones_image.__all__[self.model_cfg.IMAGE_BACKBONE.NAME](
            model_cfg=self.model_cfg.IMAGE_BACKBONE
        )
        # Backbone-specific weight init (e.g. loading pretrained weights).
        image_backbone_module.init_weights()
        model_info_dict['module_list'].append(image_backbone_module)
        return image_backbone_module, model_info_dict

    def build_fuser(self, model_info_dict):
        """Build the multi-modal BEV fuser from cfg.FUSER, if configured."""
        if self.model_cfg.get('FUSER', None) is None:
            return None, model_info_dict
        fuser_module = fuser.__all__[self.model_cfg.FUSER.NAME](
            model_cfg=self.model_cfg.FUSER
        )
        model_info_dict['module_list'].append(fuser_module)
        # Downstream 2D backbone reads its input width from here.
        model_info_dict['num_bev_features'] = self.model_cfg.FUSER.OUT_CHANNEL
        return fuser_module, model_info_dict

    def forward(self, batch_dict):
        """Run all modules in topology order; return losses when training,
        predictions + recall stats otherwise."""
        for i, cur_module in enumerate(self.module_list):
            batch_dict = cur_module(batch_dict)

        if self.training:
            loss, tb_dict, disp_dict = self.get_training_loss(batch_dict)

            ret_dict = {
                'loss': loss
            }
            return ret_dict, tb_dict, disp_dict
        else:
            pred_dicts, recall_dicts = self.post_processing(batch_dict)
            return pred_dicts, recall_dicts

    def get_training_loss(self, batch_dict):
        """Collect the loss already computed by the dense head (stored in
        batch_dict by the head's forward pass)."""
        disp_dict = {}

        loss_trans, tb_dict = batch_dict['loss'], batch_dict['tb_dict']
        tb_dict = {
            'loss_trans': loss_trans.item(),
            **tb_dict
        }

        loss = loss_trans
        return loss, tb_dict, disp_dict

    def post_processing(self, batch_dict):
        """Pass through the head's final boxes and accumulate recall records
        per sample; no extra NMS is applied here."""
        post_process_cfg = self.model_cfg.POST_PROCESSING
        batch_size = batch_dict['batch_size']
        final_pred_dict = batch_dict['final_box_dicts']
        recall_dict = {}
        for index in range(batch_size):
            pred_boxes = final_pred_dict[index]['pred_boxes']

            recall_dict = self.generate_recall_record(
                box_preds=pred_boxes,
                recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
                thresh_list=post_process_cfg.RECALL_THRESH_LIST
            )

        return final_pred_dict, recall_dict
pcdet/models/model_utils/swin_utils.py
0 → 100644
View file @
8a64de5d
This diff is collapsed.
Click to expand it.
pcdet/models/view_transforms/__init__.py
0 → 100644
View file @
8a64de5d
# Registry of camera-to-BEV view transforms; configs pick an entry by VTRANSFORM.NAME.
from .depth_lss import DepthLSSTransform

# Dict registry (name -> class), matching the OpenPCDet builder convention.
__all__ = {
    'DepthLSSTransform': DepthLSSTransform,
}
\ No newline at end of file
pcdet/models/view_transforms/depth_lss.py
0 → 100644
View file @
8a64de5d
import
torch
from
torch
import
nn
from
pcdet.ops.bev_pool
import
bev_pool
def gen_dx_bx(xbound, ybound, zbound):
    """Derive voxel-grid descriptors from per-axis (lower, upper, step) bounds.

    Args:
        xbound, ybound, zbound: three-element sequences (lower, upper, step).
    Returns:
        dx: step size per axis, float tensor of shape (3,).
        bx: center of the first cell per axis, float tensor of shape (3,).
        nx: number of cells per axis, long tensor of shape (3,)
            (float quotient truncated by the LongTensor conversion).
    """
    bounds = [xbound, ybound, zbound]
    dx = torch.Tensor([step for _, _, step in bounds])
    bx = torch.Tensor([lo + step / 2.0 for lo, _, step in bounds])
    nx = torch.LongTensor([(hi - lo) / step for lo, hi, step in bounds])
    return dx, bx, nx
class DepthLSSTransform(nn.Module):
    """
    This module implements LSS, which lifts images into 3D and then splats onto bev features.
    This code is adapted from https://github.com/mit-han-lab/bevfusion/ with minimal modifications.
    """

    def __init__(self, model_cfg):
        """
        Args:
            model_cfg: config with IN_CHANNEL / OUT_CHANNEL (int), IMAGE_SIZE /
                FEATURE_SIZE ((H, W)), XBOUND / YBOUND / ZBOUND / DBOUND
                ((lower, upper, step)), DOWNSAMPLE (int, 1 or 2).
        """
        super().__init__()
        self.model_cfg = model_cfg
        in_channel = self.model_cfg.IN_CHANNEL
        out_channel = self.model_cfg.OUT_CHANNEL
        self.image_size = self.model_cfg.IMAGE_SIZE
        self.feature_size = self.model_cfg.FEATURE_SIZE
        xbound = self.model_cfg.XBOUND
        ybound = self.model_cfg.YBOUND
        zbound = self.model_cfg.ZBOUND
        self.dbound = self.model_cfg.DBOUND
        downsample = self.model_cfg.DOWNSAMPLE

        # BEV grid descriptors, stored as frozen parameters so they follow the
        # module across devices / state_dict round-trips without being trained.
        dx, bx, nx = gen_dx_bx(xbound, ybound, zbound)
        self.dx = nn.Parameter(dx, requires_grad=False)
        self.bx = nn.Parameter(bx, requires_grad=False)
        self.nx = nn.Parameter(nx, requires_grad=False)

        self.C = out_channel                 # per-point feature channels
        self.frustum = self.create_frustum()
        self.D = self.frustum.shape[0]       # number of depth bins

        # Encodes the sparse lidar-projected depth map (1 channel) down to the
        # feature resolution (stride 4 * 2 = 8 total).
        self.dtransform = nn.Sequential(
            nn.Conv2d(1, 8, 1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.Conv2d(8, 32, 5, stride=4, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        # Predicts per-pixel depth distribution (D logits) and context
        # features (C channels) from image features + encoded depth.
        self.depthnet = nn.Sequential(
            nn.Conv2d(in_channel + 64, in_channel, 3, padding=1),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(True),
            nn.Conv2d(in_channel, in_channel, 3, padding=1),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(True),
            nn.Conv2d(in_channel, self.D + self.C, 1),
        )
        if downsample > 1:
            # Only a factor-2 BEV downsample is supported.
            assert downsample == 2, downsample
            self.downsample = nn.Sequential(
                nn.Conv2d(out_channel, out_channel, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(True),
                nn.Conv2d(out_channel, out_channel, 3, stride=downsample, padding=1, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(True),
                nn.Conv2d(out_channel, out_channel, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(True),
            )
        else:
            self.downsample = nn.Identity()

    def create_frustum(self):
        """Build the fixed camera frustum: for every depth bin and every
        feature-map cell, the (x_img, y_img, depth) coordinate it samples.
        Returns a frozen parameter of shape (D, fH, fW, 3)."""
        iH, iW = self.image_size
        fH, fW = self.feature_size
        ds = torch.arange(*self.dbound, dtype=torch.float).view(-1, 1, 1).expand(-1, fH, fW)
        D, _, _ = ds.shape
        # Pixel coordinates are spread over the full image, then broadcast
        # across depth bins.
        xs = torch.linspace(0, iW - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW)
        ys = torch.linspace(0, iH - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW)
        frustum = torch.stack((xs, ys, ds), -1)
        return nn.Parameter(frustum, requires_grad=False)

    def get_geometry(self, camera2lidar_rots, camera2lidar_trans, intrins, post_rots, post_trans, **kwargs):
        """Map frustum points into the (augmented) lidar frame.

        Args:
            camera2lidar_rots / camera2lidar_trans: camera-to-lidar extrinsics.
            intrins: camera intrinsics (3x3 per camera).
            post_rots / post_trans: image augmentation applied to the images.
            kwargs: optional 'extra_rots' / 'extra_trans' — lidar-frame
                augmentation to replay on the geometry.
        Returns:
            points: lidar-frame coordinates, (B, N, D, fH, fW, 3).
        """
        camera2lidar_rots = camera2lidar_rots.to(torch.float)
        camera2lidar_trans = camera2lidar_trans.to(torch.float)
        intrins = intrins.to(torch.float)
        post_rots = post_rots.to(torch.float)
        post_trans = post_trans.to(torch.float)

        B, N, _ = camera2lidar_trans.shape

        # undo post-transformation
        # B x N x D x H x W x 3
        points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)
        points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3).matmul(points.unsqueeze(-1))
        # cam_to_lidar: convert (u, v, depth) to homogeneous (u*d, v*d, d)
        # before applying inverse intrinsics.
        points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3], points[:, :, :, :, :, 2:3]), 5)
        combine = camera2lidar_rots.matmul(torch.inverse(intrins))
        points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
        points += camera2lidar_trans.view(B, N, 1, 1, 1, 3)

        # Replay lidar-frame augmentation so geometry matches augmented points.
        if "extra_rots" in kwargs:
            extra_rots = kwargs["extra_rots"]
            points = extra_rots.view(B, 1, 1, 1, 1, 3, 3).repeat(1, N, 1, 1, 1, 1, 1) \
                .matmul(points.unsqueeze(-1)).squeeze(-1)
        if "extra_trans" in kwargs:
            extra_trans = kwargs["extra_trans"]
            points += extra_trans.view(B, 1, 1, 1, 1, 3).repeat(1, N, 1, 1, 1, 1)

        return points

    def bev_pool(self, geom_feats, x):
        """Splat per-frustum-point features onto the BEV grid.

        Args:
            geom_feats: lidar-frame coordinates, (B, N, D, H, W, 3).
            x: per-point features, (B, N, D, H, W, C).
        Returns:
            final: BEV features with the Z axis folded into channels.
        """
        geom_feats = geom_feats.to(torch.float)
        x = x.to(torch.float)

        B, N, D, H, W, C = x.shape
        Nprime = B * N * D * H * W

        # flatten x
        x = x.reshape(Nprime, C)

        # flatten indices: metric coordinates -> integer voxel indices
        geom_feats = ((geom_feats - (self.bx - self.dx / 2.0)) / self.dx).long()
        geom_feats = geom_feats.view(Nprime, 3)
        # Append the batch index as a 4th coordinate column.
        batch_ix = torch.cat([torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long) for ix in range(B)])
        geom_feats = torch.cat((geom_feats, batch_ix), 1)

        # filter out points that are outside box
        kept = (
            (geom_feats[:, 0] >= 0)
            & (geom_feats[:, 0] < self.nx[0])
            & (geom_feats[:, 1] >= 0)
            & (geom_feats[:, 1] < self.nx[1])
            & (geom_feats[:, 2] >= 0)
            & (geom_feats[:, 2] < self.nx[2])
        )
        x = x[kept]
        geom_feats = geom_feats[kept]

        # CUDA sum-pooling into the (B, Z, X, Y) voxel grid.
        x = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0], self.nx[1])

        # collapse Z: concatenate the Z slices along the channel dim
        final = torch.cat(x.unbind(dim=2), 1)

        return final

    def get_cam_feats(self, x, d):
        """Predict depth distributions and form the depth-weighted outer
        product of depth bins and context features.

        Args:
            x: image features, (B, N, C_in, fH, fW).
            d: sparse depth maps, (B, N, 1, iH, iW).
        Returns:
            (B, N, D, fH, fW, C) feature volume.
        """
        B, N, C, fH, fW = x.shape

        d = d.view(B * N, *d.shape[2:])
        x = x.view(B * N, C, fH, fW)

        d = self.dtransform(d)
        x = torch.cat([d, x], dim=1)
        x = self.depthnet(x)

        # First D channels -> softmax depth distribution; remaining C channels
        # -> context features, combined via outer product over the depth axis.
        depth = x[:, : self.D].softmax(dim=1)
        x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)

        x = x.view(B, N, self.C, self.D, fH, fW)
        x = x.permute(0, 1, 3, 4, 5, 2)
        return x

    def forward(self, batch_dict):
        """
        Args:
            batch_dict:
                image_fpn (list[tensor]): image features after image neck
        Returns:
            batch_dict:
                spatial_features_img (tensor): bev features from image modality
        """
        x = batch_dict['image_fpn']
        x = x[0]
        BN, C, H, W = x.size()
        # NOTE(review): the camera count is hard-coded to 6 (nuScenes rig);
        # confirm before reusing with another sensor setup.
        img = x.view(int(BN / 6), 6, C, H, W)

        camera_intrinsics = batch_dict['camera_intrinsics']
        camera2lidar = batch_dict['camera2lidar']
        img_aug_matrix = batch_dict['img_aug_matrix']
        lidar_aug_matrix = batch_dict['lidar_aug_matrix']
        lidar2image = batch_dict['lidar2image']

        intrins = camera_intrinsics[..., :3, :3]
        post_rots = img_aug_matrix[..., :3, :3]
        post_trans = img_aug_matrix[..., :3, 3]
        camera2lidar_rots = camera2lidar[..., :3, :3]
        camera2lidar_trans = camera2lidar[..., :3, 3]

        points = batch_dict['points']

        batch_size = BN // 6
        # Sparse per-camera depth maps, filled below by projecting lidar points.
        depth = torch.zeros(batch_size, img.shape[1], 1, *self.image_size).to(points[0].device)

        for b in range(batch_size):
            # points column 0 is the batch index; columns 1:4 are xyz.
            batch_mask = points[:, 0] == b
            cur_coords = points[batch_mask][:, 1:4]
            cur_img_aug_matrix = img_aug_matrix[b]
            cur_lidar_aug_matrix = lidar_aug_matrix[b]
            cur_lidar2image = lidar2image[b]

            # inverse aug: undo the lidar-frame augmentation before projecting
            cur_coords -= cur_lidar_aug_matrix[:3, 3]
            cur_coords = torch.inverse(cur_lidar_aug_matrix[:3, :3]).matmul(
                cur_coords.transpose(1, 0)
            )
            # lidar2image
            cur_coords = cur_lidar2image[:, :3, :3].matmul(cur_coords)
            cur_coords += cur_lidar2image[:, :3, 3].reshape(-1, 3, 1)
            # get 2d coords: perspective divide with a clamped denominator
            dist = cur_coords[:, 2, :]
            cur_coords[:, 2, :] = torch.clamp(cur_coords[:, 2, :], 1e-5, 1e5)
            cur_coords[:, :2, :] /= cur_coords[:, 2:3, :]

            # do image aug
            cur_coords = cur_img_aug_matrix[:, :3, :3].matmul(cur_coords)
            cur_coords += cur_img_aug_matrix[:, :3, 3].reshape(-1, 3, 1)
            cur_coords = cur_coords[:, :2, :].transpose(1, 2)

            # normalize coords for grid sample: swap to (row, col) order
            cur_coords = cur_coords[..., [1, 0]]

            # filter points outside of images
            on_img = (
                (cur_coords[..., 0] < self.image_size[0])
                & (cur_coords[..., 0] >= 0)
                & (cur_coords[..., 1] < self.image_size[1])
                & (cur_coords[..., 1] >= 0)
            )
            for c in range(on_img.shape[0]):
                masked_coords = cur_coords[c, on_img[c]].long()
                masked_dist = dist[c, on_img[c]]
                depth[b, c, 0, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist

        extra_rots = lidar_aug_matrix[..., :3, :3]
        extra_trans = lidar_aug_matrix[..., :3, 3]
        geom = self.get_geometry(
            camera2lidar_rots, camera2lidar_trans, intrins, post_rots, post_trans,
            extra_rots=extra_rots, extra_trans=extra_trans,
        )
        # use points depth to assist the depth prediction in images
        x = self.get_cam_feats(img, depth)
        x = self.bev_pool(geom, x)
        x = self.downsample(x)
        # convert bev features from (b, c, x, y) to (b, c, y, x)
        x = x.permute(0, 1, 3, 2)
        batch_dict['spatial_features_img'] = x
        return batch_dict
\ No newline at end of file
pcdet/ops/bev_pool/__init__.py
0 → 100644
View file @
8a64de5d
# Re-export the bev_pool wrapper so callers can `from pcdet.ops.bev_pool import bev_pool`.
from .bev_pool import bev_pool
\ No newline at end of file
pcdet/ops/bev_pool/bev_pool.py
0 → 100644
View file @
8a64de5d
import
torch
from
.
import
bev_pool_ext
# Public API of this module: only the high-level bev_pool wrapper.
__all__ = ["bev_pool"]
class QuickCumsum(torch.autograd.Function):
    """Sum-pool features over runs of equal `ranks` via the cumsum trick.

    `ranks` must be sorted so equal values are contiguous; the forward pass
    keeps one row per run (the last one) and returns the per-run feature sums
    together with the surviving coordinate rows.
    """

    @staticmethod
    def forward(ctx, x, geom_feats, ranks):
        prefix = x.cumsum(0)

        # Mark the last row of every run of equal ranks (the final row always).
        run_ends = torch.ones(prefix.shape[0], device=prefix.device, dtype=torch.bool)
        run_ends[:-1] = ranks[1:] != ranks[:-1]

        prefix, geom_feats = prefix[run_ends], geom_feats[run_ends]
        # Difference of consecutive kept prefix sums = sum of each run.
        pooled = torch.cat((prefix[:1], prefix[1:] - prefix[:-1]))

        # save kept for backward
        ctx.save_for_backward(run_ends)
        # no gradient for geom_feats
        ctx.mark_non_differentiable(geom_feats)
        return pooled, geom_feats

    @staticmethod
    def backward(ctx, gradx, gradgeom):
        (run_ends,) = ctx.saved_tensors
        # Each input row receives the gradient of the pooled row it belongs to.
        owner = torch.cumsum(run_ends, 0)
        owner[run_ends] -= 1
        return gradx[owner], None, None
class QuickCumsumCuda(torch.autograd.Function):
    """Interval-based sum pooling backed by the compiled bev_pool_ext CUDA op.

    `ranks` must be sorted so that points falling in the same voxel are
    contiguous; each run of equal ranks becomes one interval that the kernel
    sums into its output cell.
    """

    @staticmethod
    def forward(ctx, x, geom_feats, ranks, B, D, H, W):
        # Mark the FIRST row of every run of equal ranks (row 0 always).
        kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
        kept[1:] = ranks[1:] != ranks[:-1]

        interval_starts = torch.where(kept)[0].int()
        interval_lengths = torch.zeros_like(interval_starts)
        interval_lengths[:-1] = interval_starts[1:] - interval_starts[:-1]
        interval_lengths[-1] = x.shape[0] - interval_starts[-1]
        geom_feats = geom_feats.int()  # kernel expects int32 coordinates

        out = bev_pool_ext.bev_pool_forward(
            x,
            geom_feats,
            interval_lengths,
            interval_starts,
            B,
            D,
            H,
            W,
        )

        ctx.save_for_backward(interval_starts, interval_lengths, geom_feats)
        # Plain ints can't go through save_for_backward; stash them on ctx.
        ctx.saved_shapes = B, D, H, W
        return out

    @staticmethod
    def backward(ctx, out_grad):
        interval_starts, interval_lengths, geom_feats = ctx.saved_tensors
        B, D, H, W = ctx.saved_shapes

        out_grad = out_grad.contiguous()
        x_grad = bev_pool_ext.bev_pool_backward(
            out_grad,
            geom_feats,
            interval_lengths,
            interval_starts,
            B,
            D,
            H,
            W,
        )

        # Gradients only for x; the remaining six inputs are non-differentiable.
        return x_grad, None, None, None, None, None, None
def bev_pool(feats, coords, B, D, H, W):
    """Sum-pool point features into a dense BEV grid.

    Args:
        feats: per-point features, FloatTensor[n, c].
        coords: integer voxel coordinates + batch index, Tensor[n, 4].
        B, D, H, W: output grid dimensions.
    Returns:
        Dense grid of shape (B, c, D, H, W).
    """
    assert feats.shape[0] == coords.shape[0]

    # Linearize coordinates so that points sharing a voxel get equal ranks.
    ranks = (
        coords[:, 0] * (W * D * B)
        + coords[:, 1] * (D * B)
        + coords[:, 2] * B
        + coords[:, 3]
    )
    order = ranks.argsort()
    feats = feats[order]
    coords = coords[order]
    ranks = ranks[order]

    pooled = QuickCumsumCuda.apply(feats, coords, ranks, B, D, H, W)
    # (B, D, H, W, c) -> (B, c, D, H, W)
    return pooled.permute(0, 4, 1, 2, 3).contiguous()
pcdet/ops/bev_pool/src/bev_pool.cpp
0 → 100644
View file @
8a64de5d
#include <torch/torch.h>
#include <c10/cuda/CUDAGuard.h>
// CUDA function declarations
void
bev_pool
(
int
b
,
int
d
,
int
h
,
int
w
,
int
n
,
int
c
,
int
n_intervals
,
const
float
*
x
,
const
int
*
geom_feats
,
const
int
*
interval_starts
,
const
int
*
interval_lengths
,
float
*
out
);
void
bev_pool_grad
(
int
b
,
int
d
,
int
h
,
int
w
,
int
n
,
int
c
,
int
n_intervals
,
const
float
*
out_grad
,
const
int
*
geom_feats
,
const
int
*
interval_starts
,
const
int
*
interval_lengths
,
float
*
x_grad
);
/*
  Function: pillar pooling (forward, cuda)
  Args:
    x                : input features, FloatTensor[n, c]
    geom_feats       : input coordinates, IntTensor[n, 4]
    interval_lengths : how many points in each pooled point, IntTensor[n_intervals]
    interval_starts  : starting position for each pooled point, IntTensor[n_intervals]
  Return:
    out              : output features, FloatTensor[b, d, h, w, c]
*/
at::Tensor bev_pool_forward(
  const at::Tensor _x,
  const at::Tensor _geom_feats,
  const at::Tensor _interval_lengths,
  const at::Tensor _interval_starts,
  int b, int d, int h, int w
) {
  int n = _x.size(0);
  int c = _x.size(1);
  int n_intervals = _interval_lengths.size(0);
  // Run the kernel on the device that owns the input tensors.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_x));
  const float* x = _x.data_ptr<float>();
  const int* geom_feats = _geom_feats.data_ptr<int>();
  const int* interval_lengths = _interval_lengths.data_ptr<int>();
  const int* interval_starts = _interval_starts.data_ptr<int>();

  auto options = torch::TensorOptions().dtype(_x.dtype()).device(_x.device());
  // Zero-init: cells with no interval stay 0.
  at::Tensor _out = torch::zeros({b, d, h, w, c}, options);
  float* out = _out.data_ptr<float>();
  bev_pool(
    b, d, h, w, n, c, n_intervals, x, geom_feats,
    interval_starts, interval_lengths, out
  );
  return _out;
}
/*
  Function: pillar pooling (backward, cuda)
  Args:
    out_grad         : gradient of output features, FloatTensor[b, d, h, w, c]
    geom_feats       : input coordinates, IntTensor[n, 4]
    interval_lengths : how many points in each pooled point, IntTensor[n_intervals]
    interval_starts  : starting position for each pooled point, IntTensor[n_intervals]
  Return:
    x_grad           : gradient of input features, FloatTensor[n, c]
*/
at::Tensor bev_pool_backward(
  const at::Tensor _out_grad,
  const at::Tensor _geom_feats,
  const at::Tensor _interval_lengths,
  const at::Tensor _interval_starts,
  int b, int d, int h, int w
) {
  int n = _geom_feats.size(0);
  int c = _out_grad.size(4);
  int n_intervals = _interval_lengths.size(0);
  // Run the kernel on the device that owns the gradient tensor.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_out_grad));
  const float* out_grad = _out_grad.data_ptr<float>();
  const int* geom_feats = _geom_feats.data_ptr<int>();
  const int* interval_lengths = _interval_lengths.data_ptr<int>();
  const int* interval_starts = _interval_starts.data_ptr<int>();

  auto options = torch::TensorOptions().dtype(_out_grad.dtype()).device(_out_grad.device());
  at::Tensor _x_grad = torch::zeros({n, c}, options);
  float* x_grad = _x_grad.data_ptr<float>();

  bev_pool_grad(
    b, d, h, w, n, c, n_intervals, out_grad, geom_feats,
    interval_starts, interval_lengths, x_grad
  );

  return _x_grad;
}
// Expose the forward/backward entry points to Python as `bev_pool_ext`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("bev_pool_forward", &bev_pool_forward,
        "bev_pool_forward");
  m.def("bev_pool_backward", &bev_pool_backward,
        "bev_pool_backward");
}
pcdet/ops/bev_pool/src/bev_pool_cuda.cu
0 → 100644
View file @
8a64de5d
#include <stdio.h>
#include <stdlib.h>
/*
  Function: pillar pooling
  Args:
    b                : batch size
    d                : depth of the feature map
    h                : height of pooled feature map
    w                : width of pooled feature map
    n                : number of input points
    c                : number of channels
    n_intervals      : number of unique points
    x                : input features, FloatTensor[n, c]
    geom_feats       : input coordinates, IntTensor[n, 4]
    interval_lengths : how many points in each pooled point, IntTensor[n_intervals]
    interval_starts  : starting position for each pooled point, IntTensor[n_intervals]
    out              : output features, FloatTensor[b, d, h, w, c]
*/
// One thread per (interval, channel) pair: sums that channel over the
// interval's contiguous rows of x and writes a single output cell.
__global__ void bev_pool_kernel(int b, int d, int h, int w, int n, int c, int n_intervals,
                                const float *__restrict__ x,
                                const int *__restrict__ geom_feats,
                                const int *__restrict__ interval_starts,
                                const int *__restrict__ interval_lengths,
                                float* __restrict__ out) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int index = idx / c;       // interval id
  int cur_c = idx % c;       // channel id
  if (index >= n_intervals) return;
  int interval_start = interval_starts[index];
  int interval_length = interval_lengths[index];
  // geom_feats row layout: [0]=h index, [1]=w index, [2]=d index, [3]=batch.
  const int* cur_geom_feats = geom_feats + interval_start * 4;
  const float* cur_x = x + interval_start * c + cur_c;
  float* cur_out = out + cur_geom_feats[3] * d * h * w * c +
    cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c +
    cur_geom_feats[1] * c + cur_c;
  float psum = 0;
  for(int i = 0; i < interval_length; i++){
    psum += cur_x[i * c];
  }
  *cur_out = psum;
}
/*
  Function: pillar pooling backward
  Args:
    b                : batch size
    d                : depth of the feature map
    h                : height of pooled feature map
    w                : width of pooled feature map
    n                : number of input points
    c                : number of channels
    n_intervals      : number of unique points
    out_grad         : gradient of the BEV fmap from top, FloatTensor[b, d, h, w, c]
    geom_feats       : input coordinates, IntTensor[n, 4]
    interval_lengths : how many points in each pooled point, IntTensor[n_intervals]
    interval_starts  : starting position for each pooled point, IntTensor[n_intervals]
    x_grad           : gradient of the image fmap, FloatTensor
*/
// Backward of sum pooling: every input row in an interval receives the
// gradient of the single output cell the interval contributed to.
__global__ void bev_pool_grad_kernel(int b, int d, int h, int w, int n, int c, int n_intervals,
                                     const float *__restrict__ out_grad,
                                     const int *__restrict__ geom_feats,
                                     const int *__restrict__ interval_starts,
                                     const int *__restrict__ interval_lengths,
                                     float* __restrict__ x_grad) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int index = idx / c;       // interval id
  int cur_c = idx % c;       // channel id
  if (index >= n_intervals) return;
  int interval_start = interval_starts[index];
  int interval_length = interval_lengths[index];
  // geom_feats row layout: [0]=h index, [1]=w index, [2]=d index, [3]=batch.
  const int* cur_geom_feats = geom_feats + interval_start * 4;
  float* cur_x_grad = x_grad + interval_start * c + cur_c;
  const float* cur_out_grad = out_grad + cur_geom_feats[3] * d * h * w * c +
    cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c +
    cur_geom_feats[1] * c + cur_c;
  for(int i = 0; i < interval_length; i++){
    cur_x_grad[i * c] = *cur_out_grad;
  }
}
// Host-side launcher for the forward pooling kernel: one thread per
// (interval, channel) pair, 256 threads per block, grid rounded up.
void bev_pool(int b, int d, int h, int w, int n, int c, int n_intervals, const float* x,
  const int* geom_feats, const int* interval_starts, const int* interval_lengths, float* out) {
  bev_pool_kernel<<<(int)ceil(((double)n_intervals * c / 256)), 256>>>(
    b, d, h, w, n, c, n_intervals, x, geom_feats, interval_starts, interval_lengths, out
  );
}
// Host-side launcher for the backward kernel; same launch geometry as the
// forward pass (one thread per (interval, channel) pair).
void bev_pool_grad(int b, int d, int h, int w, int n, int c, int n_intervals, const float* out_grad,
  const int* geom_feats, const int* interval_starts, const int* interval_lengths, float* x_grad) {
  bev_pool_grad_kernel<<<(int)ceil(((double)n_intervals * c / 256)), 256>>>(
    b, d, h, w, n, c, n_intervals, out_grad, geom_feats, interval_starts, interval_lengths, x_grad
  );
}
setup.py
View file @
8a64de5d
...
...
@@ -117,5 +117,13 @@ if __name__ == '__main__':
],
),
make_cuda_ext
(
name
=
"bev_pool_ext"
,
module
=
"pcdet.ops.bev_pool"
,
sources
=
[
"src/bev_pool.cpp"
,
"src/bev_pool_cuda.cu"
,
],
),
],
)
tools/cfgs/nuscenes_models/cbgs_bevfusion.yaml
0 → 100644
View file @
8a64de5d
CLASS_NAMES
:
[
'
car'
,
'
truck'
,
'
construction_vehicle'
,
'
bus'
,
'
trailer'
,
'
barrier'
,
'
motorcycle'
,
'
bicycle'
,
'
pedestrian'
,
'
traffic_cone'
]
DATA_CONFIG
:
_BASE_CONFIG_
:
cfgs/dataset_configs/nuscenes_dataset.yaml
POINT_CLOUD_RANGE
:
[
-54.0
,
-54.0
,
-5.0
,
54.0
,
54.0
,
3.0
]
CAMERA_CONFIG
:
USE_CAMERA
:
True
IMAGE
:
FINAL_DIM
:
[
256
,
704
]
RESIZE_LIM_TRAIN
:
[
0.38
,
0.55
]
RESIZE_LIM_TEST
:
[
0.48
,
0.48
]
DATA_AUGMENTOR
:
DISABLE_AUG_LIST
:
[
'
placeholder'
]
AUG_CONFIG_LIST
:
-
NAME
:
random_world_flip
ALONG_AXIS_LIST
:
[
'
x'
,
'
y'
]
-
NAME
:
random_world_rotation
WORLD_ROT_ANGLE
:
[
-0.78539816
,
0.78539816
]
-
NAME
:
random_world_scaling
WORLD_SCALE_RANGE
:
[
0.9
,
1.1
]
-
NAME
:
random_world_translation
NOISE_TRANSLATE_STD
:
[
0.5
,
0.5
,
0.5
]
-
NAME
:
imgaug
ROT_LIM
:
[
-5.4
,
5.4
]
RAND_FLIP
:
true
DATA_PROCESSOR
:
-
NAME
:
mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES
:
True
-
NAME
:
shuffle_points
SHUFFLE_ENABLED
:
{
'
train'
:
True
,
'
test'
:
True
}
-
NAME
:
transform_points_to_voxels
VOXEL_SIZE
:
[
0.075
,
0.075
,
0.2
]
MAX_POINTS_PER_VOXEL
:
10
MAX_NUMBER_OF_VOXELS
:
{
'
train'
:
120000
,
'
test'
:
160000
}
-
NAME
:
image_calibrate
-
NAME
:
image_normalize
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
MODEL
:
NAME
:
BevFusion
VFE
:
NAME
:
MeanVFE
BACKBONE_3D
:
NAME
:
VoxelResBackBone8x
USE_BIAS
:
False
MAP_TO_BEV
:
NAME
:
HeightCompression
NUM_BEV_FEATURES
:
256
IMAGE_BACKBONE
:
NAME
:
SwinTransformer
EMBED_DIMS
:
96
DEPTHS
:
[
2
,
2
,
6
,
2
]
NUM_HEADS
:
[
3
,
6
,
12
,
24
]
WINDOW_SIZE
:
7
MLP_RATIO
:
4
DROP_RATE
:
0.
ATTN_DROP_RATE
:
0.
DROP_PATH_RATE
:
0.2
PATCH_NORM
:
True
OUT_INDICES
:
[
1
,
2
,
3
]
WITH_CP
:
False
CONVERT_WEIGHTS
:
True
INIT_CFG
:
type
:
Pretrained
checkpoint
:
swint-nuimages-pretrained.pth
NECK
:
NAME
:
GeneralizedLSSFPN
IN_CHANNELS
:
[
192
,
384
,
768
]
OUT_CHANNELS
:
256
START_LEVEL
:
0
END_LEVEL
:
-1
NUM_OUTS
:
3
VTRANSFORM
:
NAME
:
DepthLSSTransform
IMAGE_SIZE
:
[
256
,
704
]
IN_CHANNEL
:
256
OUT_CHANNEL
:
80
FEATURE_SIZE
:
[
32
,
88
]
XBOUND
:
[
-54.0
,
54.0
,
0.3
]
YBOUND
:
[
-54.0
,
54.0
,
0.3
]
ZBOUND
:
[
-10.0
,
10.0
,
20.0
]
DBOUND
:
[
1.0
,
60.0
,
0.5
]
DOWNSAMPLE
:
2
FUSER
:
NAME
:
'
ConvFuser'
IN_CHANNEL
:
336
OUT_CHANNEL
:
256
BACKBONE_2D
:
NAME
:
BaseBEVBackbone
LAYER_NUMS
:
[
5
,
5
]
LAYER_STRIDES
:
[
1
,
2
]
NUM_FILTERS
:
[
128
,
256
]
UPSAMPLE_STRIDES
:
[
1
,
2
]
NUM_UPSAMPLE_FILTERS
:
[
256
,
256
]
USE_CONV_FOR_NO_STRIDE
:
true
DENSE_HEAD
:
CLASS_AGNOSTIC
:
False
NAME
:
TransFusionHead
USE_BIAS_BEFORE_NORM
:
False
NUM_PROPOSALS
:
200
HIDDEN_CHANNEL
:
128
NUM_CLASSES
:
10
NUM_HEADS
:
8
NMS_KERNEL_SIZE
:
3
FFN_CHANNEL
:
256
DROPOUT
:
0.1
BN_MOMENTUM
:
0.1
ACTIVATION
:
relu
NUM_HM_CONV
:
2
SEPARATE_HEAD_CFG
:
HEAD_ORDER
:
[
'
center'
,
'
height'
,
'
dim'
,
'
rot'
,
'
vel'
]
HEAD_DICT
:
{
'
center'
:
{
'
out_channels'
:
2
,
'
num_conv'
:
2
},
'
height'
:
{
'
out_channels'
:
1
,
'
num_conv'
:
2
},
'
dim'
:
{
'
out_channels'
:
3
,
'
num_conv'
:
2
},
'
rot'
:
{
'
out_channels'
:
2
,
'
num_conv'
:
2
},
'
vel'
:
{
'
out_channels'
:
2
,
'
num_conv'
:
2
},
}
TARGET_ASSIGNER_CONFIG
:
FEATURE_MAP_STRIDE
:
8
DATASET
:
nuScenes
GAUSSIAN_OVERLAP
:
0.1
MIN_RADIUS
:
2
HUNGARIAN_ASSIGNER
:
cls_cost
:
{
'
gamma'
:
2.0
,
'
alpha'
:
0.25
,
'
weight'
:
0.15
}
reg_cost
:
{
'
weight'
:
0.25
}
iou_cost
:
{
'
weight'
:
0.25
}
LOSS_CONFIG
:
LOSS_WEIGHTS
:
{
'
cls_weight'
:
1.0
,
'
bbox_weight'
:
0.25
,
'
hm_weight'
:
1.0
,
'
code_weights'
:
[
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
0.2
,
0.2
]
}
LOSS_CLS
:
use_sigmoid
:
true
gamma
:
2.0
alpha
:
0.25
POST_PROCESSING
:
SCORE_THRESH
:
0.0
POST_CENTER_RANGE
:
[
-61.2
,
-61.2
,
-10.0
,
61.2
,
61.2
,
10.0
]
POST_PROCESSING
:
RECALL_THRESH_LIST
:
[
0.3
,
0.5
,
0.7
]
SCORE_THRESH
:
0.1
OUTPUT_RAW_SCORE
:
False
EVAL_METRIC
:
kitti
OPTIMIZATION
:
BATCH_SIZE_PER_GPU
:
3
NUM_EPOCHS
:
6
OPTIMIZER
:
adam_cosineanneal
LR
:
0.0001
WEIGHT_DECAY
:
0.01
MOMENTUM
:
0.9
BETAS
:
[
0.9
,
0.999
]
MOMS
:
[
0.9
,
0.8052631
]
PCT_START
:
0.4
WARMUP_ITER
:
500
DECAY_STEP_LIST
:
[
35
,
45
]
LR_WARMUP
:
False
WARMUP_EPOCH
:
1
GRAD_NORM_CLIP
:
35
LOSS_SCALE_FP16
:
32
\ No newline at end of file
tools/train_utils/optimization/__init__.py
View file @
8a64de5d
...
...
@@ -5,7 +5,7 @@ import torch.optim as optim
import
torch.optim.lr_scheduler
as
lr_sched
from
.fastai_optim
import
OptimWrapper
from
.learning_schedules_fastai
import
CosineWarmupLR
,
OneCycle
from
.learning_schedules_fastai
import
CosineWarmupLR
,
OneCycle
,
CosineAnnealing
def
build_optimizer
(
model
,
optim_cfg
):
...
...
@@ -16,7 +16,7 @@ def build_optimizer(model, optim_cfg):
model
.
parameters
(),
lr
=
optim_cfg
.
LR
,
weight_decay
=
optim_cfg
.
WEIGHT_DECAY
,
momentum
=
optim_cfg
.
MOMENTUM
)
elif
optim_cfg
.
OPTIMIZER
==
'adam_onecycle'
:
elif
optim_cfg
.
OPTIMIZER
in
[
'adam_onecycle'
,
'adam_cosineanneal'
]
:
def
children
(
m
:
nn
.
Module
):
return
list
(
m
.
children
())
...
...
@@ -52,6 +52,10 @@ def build_scheduler(optimizer, total_iters_each_epoch, total_epochs, last_epoch,
lr_scheduler
=
OneCycle
(
optimizer
,
total_steps
,
optim_cfg
.
LR
,
list
(
optim_cfg
.
MOMS
),
optim_cfg
.
DIV_FACTOR
,
optim_cfg
.
PCT_START
)
elif
optim_cfg
.
OPTIMIZER
==
'adam_cosineanneal'
:
lr_scheduler
=
CosineAnnealing
(
optimizer
,
total_steps
,
total_epochs
,
optim_cfg
.
LR
,
list
(
optim_cfg
.
MOMS
),
optim_cfg
.
PCT_START
,
optim_cfg
.
WARMUP_ITER
)
else
:
lr_scheduler
=
lr_sched
.
LambdaLR
(
optimizer
,
lr_lbmd
,
last_epoch
=
last_epoch
)
...
...
tools/train_utils/optimization/learning_schedules_fastai.py
View file @
8a64de5d
...
...
@@ -41,7 +41,7 @@ class LRSchedulerStep(object):
self
.
mom_phases
.
append
((
int
(
start
*
total_step
),
total_step
,
lambda_func
))
assert
self
.
mom_phases
[
0
][
0
]
==
0
def
step
(
self
,
step
):
def
step
(
self
,
step
,
epoch
=
None
):
for
start
,
end
,
func
in
self
.
lr_phases
:
if
step
>=
start
:
self
.
optimizer
.
lr
=
func
((
step
-
start
)
/
(
end
-
start
))
...
...
@@ -83,12 +83,60 @@ class CosineWarmupLR(lr_sched._LRScheduler):
self
.
eta_min
=
eta_min
super
(
CosineWarmupLR
,
self
).
__init__
(
optimizer
,
last_epoch
)
def
get_lr
(
self
):
def
get_lr
(
self
,
epoch
=
None
):
return
[
self
.
eta_min
+
(
base_lr
-
self
.
eta_min
)
*
(
1
-
math
.
cos
(
math
.
pi
*
self
.
last_epoch
/
self
.
T_max
))
/
2
for
base_lr
in
self
.
base_lrs
]
def
linear_warmup
(
end
,
lr_max
,
pct
):
k
=
(
1
-
pct
/
end
)
*
(
1
-
0.33333333
)
warmup_lr
=
lr_max
*
(
1
-
k
)
return
warmup_lr
class
CosineAnnealing
(
LRSchedulerStep
):
def
__init__
(
self
,
fai_optimizer
,
total_step
,
total_epoch
,
lr_max
,
moms
,
pct_start
,
warmup_iter
):
self
.
lr_max
=
lr_max
self
.
moms
=
moms
self
.
pct_start
=
pct_start
mom_phases
=
((
0
,
partial
(
annealing_cos
,
*
self
.
moms
)),
(
self
.
pct_start
,
partial
(
annealing_cos
,
*
self
.
moms
[::
-
1
])))
fai_optimizer
.
lr
,
fai_optimizer
.
mom
=
lr_max
,
self
.
moms
[
0
]
self
.
optimizer
=
fai_optimizer
self
.
total_step
=
total_step
self
.
warmup_iter
=
warmup_iter
self
.
total_epoch
=
total_epoch
self
.
mom_phases
=
[]
for
i
,
(
start
,
lambda_func
)
in
enumerate
(
mom_phases
):
if
len
(
self
.
mom_phases
)
!=
0
:
assert
self
.
mom_phases
[
-
1
][
0
]
<
start
if
isinstance
(
lambda_func
,
str
):
lambda_func
=
eval
(
lambda_func
)
if
i
<
len
(
mom_phases
)
-
1
:
self
.
mom_phases
.
append
((
int
(
start
*
total_step
),
int
(
mom_phases
[
i
+
1
][
0
]
*
total_step
),
lambda_func
))
else
:
self
.
mom_phases
.
append
((
int
(
start
*
total_step
),
total_step
,
lambda_func
))
assert
self
.
mom_phases
[
0
][
0
]
==
0
def
step
(
self
,
step
,
epoch
):
# update lr
if
step
<
self
.
warmup_iter
:
self
.
optimizer
.
lr
=
linear_warmup
(
self
.
warmup_iter
,
self
.
lr_max
,
step
)
else
:
target_lr
=
self
.
lr_max
*
0.001
cos_lr
=
annealing_cos
(
self
.
lr_max
,
target_lr
,
epoch
/
self
.
total_epoch
)
self
.
optimizer
.
lr
=
cos_lr
# update mom
for
start
,
end
,
func
in
self
.
mom_phases
:
if
step
>=
start
:
self
.
optimizer
.
mom
=
func
((
step
-
start
)
/
(
end
-
start
))
class
FakeOptim
:
def
__init__
(
self
):
self
.
lr
=
0
...
...
tools/train_utils/train_utils.py
View file @
8a64de5d
...
...
@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
data_timer
=
time
.
time
()
cur_data_time
=
data_timer
-
end
lr_scheduler
.
step
(
accumulated_iter
)
lr_scheduler
.
step
(
accumulated_iter
,
cur_epoch
)
try
:
cur_lr
=
float
(
optimizer
.
lr
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment