Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
4dc18496
Commit
4dc18496
authored
May 07, 2023
by
chenshi3
Browse files
Add support for TransFusion-Lidar Head
parent
ad9c25c0
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1032 additions
and
17 deletions
+1032
-17
pcdet/models/backbones_2d/base_bev_backbone.py
pcdet/models/backbones_2d/base_bev_backbone.py
+1
-1
pcdet/models/backbones_3d/spconv_backbone.py
pcdet/models/backbones_3d/spconv_backbone.py
+12
-10
pcdet/models/dense_heads/target_assigner/hungarian_assigner.py
.../models/dense_heads/target_assigner/hungarian_assigner.py
+131
-0
pcdet/models/dense_heads/transfusion_head.py
pcdet/models/dense_heads/transfusion_head.py
+479
-0
pcdet/models/detectors/__init__.py
pcdet/models/detectors/__init__.py
+3
-1
pcdet/models/detectors/transfusion.py
pcdet/models/detectors/transfusion.py
+50
-0
pcdet/models/model_utils/transfusion_utils.py
pcdet/models/model_utils/transfusion_utils.py
+102
-0
pcdet/utils/loss_utils.py
pcdet/utils/loss_utils.py
+46
-1
tools/cfgs/nuscenes_models/cbgs_transfusion_lidar.yaml
tools/cfgs/nuscenes_models/cbgs_transfusion_lidar.yaml
+177
-0
tools/train.py
tools/train.py
+2
-1
tools/train_utils/optimization/__init__.py
tools/train_utils/optimization/__init__.py
+3
-2
tools/train_utils/train_utils.py
tools/train_utils/train_utils.py
+26
-1
No files found.
pcdet/models/backbones_2d/base_bev_backbone.py
View file @
4dc18496
...
...
@@ -46,7 +46,7 @@ class BaseBEVBackbone(nn.Module):
self
.
blocks
.
append
(
nn
.
Sequential
(
*
cur_layers
))
if
len
(
upsample_strides
)
>
0
:
stride
=
upsample_strides
[
idx
]
if
stride
>
=
1
:
if
stride
>
1
or
(
stride
==
1
and
not
self
.
model_cfg
.
get
(
'USE_CONV_FOR_NO_STRIDE'
,
False
))
:
self
.
deblocks
.
append
(
nn
.
Sequential
(
nn
.
ConvTranspose2d
(
num_filters
[
idx
],
num_upsample_filters
[
idx
],
...
...
pcdet/models/backbones_3d/spconv_backbone.py
View file @
4dc18496
...
...
@@ -30,11 +30,12 @@ def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stri
class
SparseBasicBlock
(
spconv
.
SparseModule
):
expansion
=
1
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
norm_fn
=
None
,
downsample
=
None
,
indice_key
=
None
):
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
bias
=
None
,
norm_fn
=
None
,
downsample
=
None
,
indice_key
=
None
):
super
(
SparseBasicBlock
,
self
).
__init__
()
assert
norm_fn
is
not
None
bias
=
norm_fn
is
not
None
if
bias
is
None
:
bias
=
norm_fn
is
not
None
self
.
conv1
=
spconv
.
SubMConv3d
(
inplanes
,
planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
bias
=
bias
,
indice_key
=
indice_key
)
...
...
@@ -184,6 +185,7 @@ class VoxelResBackBone8x(nn.Module):
def
__init__
(
self
,
model_cfg
,
input_channels
,
grid_size
,
**
kwargs
):
super
().
__init__
()
self
.
model_cfg
=
model_cfg
use_bias
=
self
.
model_cfg
.
get
(
'USE_BIAS'
,
None
)
norm_fn
=
partial
(
nn
.
BatchNorm1d
,
eps
=
1e-3
,
momentum
=
0.01
)
self
.
sparse_shape
=
grid_size
[::
-
1
]
+
[
1
,
0
,
0
]
...
...
@@ -196,29 +198,29 @@ class VoxelResBackBone8x(nn.Module):
block
=
post_act_block
self
.
conv1
=
spconv
.
SparseSequential
(
SparseBasicBlock
(
16
,
16
,
norm_fn
=
norm_fn
,
indice_key
=
'res1'
),
SparseBasicBlock
(
16
,
16
,
norm_fn
=
norm_fn
,
indice_key
=
'res1'
),
SparseBasicBlock
(
16
,
16
,
bias
=
use_bias
,
norm_fn
=
norm_fn
,
indice_key
=
'res1'
),
SparseBasicBlock
(
16
,
16
,
bias
=
use_bias
,
norm_fn
=
norm_fn
,
indice_key
=
'res1'
),
)
self
.
conv2
=
spconv
.
SparseSequential
(
# [1600, 1408, 41] <- [800, 704, 21]
block
(
16
,
32
,
3
,
norm_fn
=
norm_fn
,
stride
=
2
,
padding
=
1
,
indice_key
=
'spconv2'
,
conv_type
=
'spconv'
),
SparseBasicBlock
(
32
,
32
,
norm_fn
=
norm_fn
,
indice_key
=
'res2'
),
SparseBasicBlock
(
32
,
32
,
norm_fn
=
norm_fn
,
indice_key
=
'res2'
),
SparseBasicBlock
(
32
,
32
,
bias
=
use_bias
,
norm_fn
=
norm_fn
,
indice_key
=
'res2'
),
SparseBasicBlock
(
32
,
32
,
bias
=
use_bias
,
norm_fn
=
norm_fn
,
indice_key
=
'res2'
),
)
self
.
conv3
=
spconv
.
SparseSequential
(
# [800, 704, 21] <- [400, 352, 11]
block
(
32
,
64
,
3
,
norm_fn
=
norm_fn
,
stride
=
2
,
padding
=
1
,
indice_key
=
'spconv3'
,
conv_type
=
'spconv'
),
SparseBasicBlock
(
64
,
64
,
norm_fn
=
norm_fn
,
indice_key
=
'res3'
),
SparseBasicBlock
(
64
,
64
,
norm_fn
=
norm_fn
,
indice_key
=
'res3'
),
SparseBasicBlock
(
64
,
64
,
bias
=
use_bias
,
norm_fn
=
norm_fn
,
indice_key
=
'res3'
),
SparseBasicBlock
(
64
,
64
,
bias
=
use_bias
,
norm_fn
=
norm_fn
,
indice_key
=
'res3'
),
)
self
.
conv4
=
spconv
.
SparseSequential
(
# [400, 352, 11] <- [200, 176, 5]
block
(
64
,
128
,
3
,
norm_fn
=
norm_fn
,
stride
=
2
,
padding
=
(
0
,
1
,
1
),
indice_key
=
'spconv4'
,
conv_type
=
'spconv'
),
SparseBasicBlock
(
128
,
128
,
norm_fn
=
norm_fn
,
indice_key
=
'res4'
),
SparseBasicBlock
(
128
,
128
,
norm_fn
=
norm_fn
,
indice_key
=
'res4'
),
SparseBasicBlock
(
128
,
128
,
bias
=
use_bias
,
norm_fn
=
norm_fn
,
indice_key
=
'res4'
),
SparseBasicBlock
(
128
,
128
,
bias
=
use_bias
,
norm_fn
=
norm_fn
,
indice_key
=
'res4'
),
)
last_pad
=
0
...
...
pcdet/models/dense_heads/target_assigner/hungarian_assigner.py
0 → 100644
View file @
4dc18496
import
torch
from
scipy.optimize
import
linear_sum_assignment
from
pcdet.ops.iou3d_nms
import
iou3d_nms_cuda
def height_overlaps(boxes1, boxes2):
    """Pairwise vertical (z-axis) overlap between two sets of 3D boxes.

    Treats boxes[:, 2] as the bottom z and boxes[:, 5] as the height, so the
    top face sits at z + dz.

    Args:
        boxes1: (N, >=6) tensor of boxes.
        boxes2: (M, >=6) tensor of boxes.

    Returns:
        (N, M) tensor of non-negative vertical overlaps.
    """
    top1 = (boxes1[:, 2] + boxes1[:, 5]).unsqueeze(1)  # (N, 1)
    bot1 = boxes1[:, 2].unsqueeze(1)                   # (N, 1)
    top2 = (boxes2[:, 2] + boxes2[:, 5]).unsqueeze(0)  # (1, M)
    bot2 = boxes2[:, 2].unsqueeze(0)                   # (1, M)
    # Overlap is the gap between the lower top face and the higher bottom
    # face, clamped at zero when the intervals are disjoint.
    return (torch.min(top1, top2) - torch.max(bot1, bot2)).clamp(min=0)
def overlaps(boxes1, boxes2):
    """Pairwise 3D IoU between two sets of boxes.

    BEV intersection areas are computed on the GPU by the
    ``iou3d_nms_cuda.boxes_overlap_bev_gpu`` kernel and combined with the
    vertical overlap from :func:`height_overlaps`.

    Args:
        boxes1: (N, >=7) tensor of boxes.
        boxes2: (M, >=7) tensor of boxes.

    Returns:
        (N, M) tensor of 3D IoU values on ``boxes1.device``.
    """
    n, m = len(boxes1), len(boxes2)
    if n * m == 0:
        # At least one side is empty: return an (n, m) placeholder tensor.
        return boxes1.new(n, m)

    # Vertical (z) intersection, (n, m).
    inter_h = height_overlaps(boxes1, boxes2)

    bev1 = boxes1[:, :7]
    bev2 = boxes2[:, :7]
    # BEV intersection area filled in-place by the CUDA kernel, (n, m).
    inter_bev = bev1.new_zeros((bev1.shape[0], bev2.shape[0])).cuda()
    iou3d_nms_cuda.boxes_overlap_bev_gpu(
        bev1.contiguous().cuda(), bev2.contiguous().cuda(), inter_bev
    )

    # Full 3D intersection volume, moved back to the input device.
    inter_3d = inter_bev.to(boxes1.device) * inter_h

    vol1 = (boxes1[:, 3] * boxes1[:, 4] * boxes1[:, 5]).view(-1, 1)
    vol2 = (boxes2[:, 3] * boxes2[:, 4] * boxes2[:, 5]).view(1, -1)
    # IoU = intersection / union, guarded against a zero denominator.
    return inter_3d / torch.clamp(vol1 + vol2 - inter_3d, min=1e-8)
class HungarianAssigner3D:
    """One-to-one matcher between predicted and ground-truth 3D boxes.

    Combines a focal-style classification cost, an L1 BEV-center cost and a
    negated 3D-IoU cost, then solves the global assignment with the
    Hungarian algorithm (``scipy.optimize.linear_sum_assignment``).
    """

    def __init__(self, cls_cost, reg_cost, iou_cost):
        # Each *_cost is a dict-like config carrying at least a 'weight' key.
        self.cls_cost = cls_cost
        self.reg_cost = reg_cost
        self.iou_cost = iou_cost

    def focal_loss_cost(self, cls_pred, gt_labels):
        """Focal-loss-shaped classification cost.

        Args:
            cls_pred: (num_query, num_class) raw logits.
            gt_labels: (num_gt,) ground-truth class indices.

        Returns:
            (num_query, num_gt) weighted cost matrix.
        """
        weight = self.cls_cost.get('weight', 0.15)
        alpha = self.cls_cost.get('alpha', 0.25)
        gamma = self.cls_cost.get('gamma', 2.0)
        eps = self.cls_cost.get('eps', 1e-12)

        cls_pred = cls_pred.sigmoid()
        # Per-class costs of predicting background vs. foreground.
        neg_cost = -(1 - cls_pred + eps).log() * (1 - alpha) * cls_pred.pow(gamma)
        pos_cost = -(cls_pred + eps).log() * alpha * (1 - cls_pred).pow(gamma)
        # Gather the columns of the matched GT classes.
        cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
        return cls_cost * weight

    def bevbox_cost(self, bboxes, gt_bboxes, point_cloud_range):
        """L1 distance between BEV centers, normalized by the scene extent.

        Args:
            bboxes: (num_query, >=2) predicted boxes (x, y first).
            gt_bboxes: (num_gt, >=2) ground-truth boxes.
            point_cloud_range: [x_min, y_min, z_min, x_max, y_max, z_max].

        Returns:
            (num_query, num_gt) weighted cost matrix.
        """
        weight = self.reg_cost.get('weight', 0.25)
        pc_start = bboxes.new(point_cloud_range[0:2])
        pc_range = bboxes.new(point_cloud_range[3:5]) - bboxes.new(point_cloud_range[0:2])
        # normalize the box center to [0, 1]
        normalized_bboxes_xy = (bboxes[:, :2] - pc_start) / pc_range
        normalized_gt_bboxes_xy = (gt_bboxes[:, :2] - pc_start) / pc_range
        reg_cost = torch.cdist(normalized_bboxes_xy, normalized_gt_bboxes_xy, p=1)
        return reg_cost * weight

    def iou3d_cost(self, bboxes, gt_bboxes):
        """Negated 3D IoU as a cost; also returns the raw IoU matrix."""
        iou = overlaps(bboxes, gt_bboxes)
        weight = self.iou_cost.get('weight', 0.25)
        iou_cost = -iou
        return iou_cost * weight, iou

    def assign(self, bboxes, gt_bboxes, gt_labels, cls_pred, point_cloud_range):
        """Match predictions to ground truth.

        Args:
            bboxes: (num_query, box_dim) predicted boxes.
            gt_bboxes: (num_gt, box_dim) ground-truth boxes.
            gt_labels: (num_gt,) ground-truth class indices.
            cls_pred: classification logits; ``cls_pred[0].T`` must be
                (num_query, num_class).
            point_cloud_range: scene extent used by the BEV cost.

        Returns:
            On the degenerate path (no GT or no predictions):
            ``(num_gts, assigned_gt_inds, max_overlaps, assigned_labels)``.
            On the matched path: ``(assigned_gt_inds, max_overlaps)``.
            NOTE(review): the two paths return different arities — callers
            must handle both; confirm against TransFusionHead's usage.
        """
        num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)

        # 1. assign -1 by default
        assigned_gt_inds = bboxes.new_full((num_bboxes,), -1, dtype=torch.long)
        assigned_labels = bboxes.new_full((num_bboxes,), -1, dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment.
            # BUGFIX: max_overlaps was referenced here without ever being
            # defined, raising NameError whenever a sample had no GT boxes.
            max_overlaps = bboxes.new_zeros((num_bboxes,))
            if num_gts == 0:
                # No ground truth, assign all predictions to background.
                assigned_gt_inds[:] = 0
            return num_gts, assigned_gt_inds, max_overlaps, assigned_labels

        # 2. compute the weighted costs
        cls_cost = self.focal_loss_cost(cls_pred[0].T, gt_labels)
        reg_cost = self.bevbox_cost(bboxes, gt_bboxes, point_cloud_range)
        iou_cost, iou = self.iou3d_cost(bboxes, gt_bboxes)

        # weighted sum of above three costs
        cost = cls_cost + reg_cost + iou_cost

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(bboxes.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(bboxes.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results (GT index is 1-based).
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]

        max_overlaps = torch.zeros_like(iou.max(1).values)
        max_overlaps[matched_row_inds] = iou[matched_row_inds, matched_col_inds]
        return assigned_gt_inds, max_overlaps
\ No newline at end of file
pcdet/models/dense_heads/transfusion_head.py
0 → 100644
View file @
4dc18496
This diff is collapsed.
Click to expand it.
pcdet/models/detectors/__init__.py
View file @
4dc18496
...
...
@@ -13,6 +13,7 @@ from .mppnet import MPPNet
from
.mppnet_e2e
import
MPPNetE2E
from
.pillarnet
import
PillarNet
from
.voxelnext
import
VoxelNeXt
from
.transfusion
import
TransFusion
__all__
=
{
'Detector3DTemplate'
:
Detector3DTemplate
,
...
...
@@ -30,7 +31,8 @@ __all__ = {
'MPPNet'
:
MPPNet
,
'MPPNetE2E'
:
MPPNetE2E
,
'PillarNet'
:
PillarNet
,
'VoxelNeXt'
:
VoxelNeXt
'VoxelNeXt'
:
VoxelNeXt
,
'TransFusion'
:
TransFusion
,
}
...
...
pcdet/models/detectors/transfusion.py
0 → 100644
View file @
4dc18496
from
.detector3d_template
import
Detector3DTemplate
class TransFusion(Detector3DTemplate):
    """TransFusion-Lidar detector.

    Runs the configured module pipeline; the dense head computes its own
    loss and stores it in ``batch_dict``, which this wrapper simply
    forwards during training.
    """

    def __init__(self, model_cfg, num_class, dataset):
        super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
        self.module_list = self.build_networks()

    def forward(self, batch_dict):
        # Every configured module consumes and augments batch_dict in order.
        for module in self.module_list:
            batch_dict = module(batch_dict)

        if not self.training:
            pred_dicts, recall_dicts = self.post_processing(batch_dict)
            return pred_dicts, recall_dicts

        loss, tb_dict, disp_dict = self.get_training_loss(batch_dict)
        return {'loss': loss}, tb_dict, disp_dict

    def get_training_loss(self, batch_dict):
        # The head has already computed its loss and logging dict.
        loss_trans = batch_dict['loss']
        tb_dict = {'loss_trans': loss_trans.item(), **batch_dict['tb_dict']}
        return loss_trans, tb_dict, {}

    def post_processing(self, batch_dict):
        """Return the head's final boxes and accumulate recall statistics."""
        post_process_cfg = self.model_cfg.POST_PROCESSING
        final_pred_dict = batch_dict['final_box_dicts']
        recall_dict = {}
        for index in range(batch_dict['batch_size']):
            recall_dict = self.generate_recall_record(
                box_preds=final_pred_dict[index]['pred_boxes'],
                recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
                thresh_list=post_process_cfg.RECALL_THRESH_LIST
            )
        return final_pred_dict, recall_dict
pcdet/models/model_utils/transfusion_utils.py
0 → 100644
View file @
4dc18496
import
torch
from
torch
import
nn
import
torch.nn.functional
as
F
def
clip_sigmoid
(
x
,
eps
=
1e-4
):
y
=
torch
.
clamp
(
x
.
sigmoid_
(),
min
=
eps
,
max
=
1
-
eps
)
return
y
class PositionEmbeddingLearned(nn.Module):
    """Learned absolute positional embedding.

    A small MLP of 1x1 convolutions maps per-point coordinates of shape
    (N, P, input_channel) to embeddings of shape (N, num_pos_feats, P).
    """

    def __init__(self, input_channel, num_pos_feats=288):
        super().__init__()
        self.position_embedding_head = nn.Sequential(
            nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
            nn.BatchNorm1d(num_pos_feats),
            nn.ReLU(inplace=True),
            nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1),
        )

    def forward(self, xyz):
        # (N, P, C) -> (N, C, P): Conv1d expects channel-first input.
        channel_first = xyz.transpose(1, 2).contiguous()
        return self.position_embedding_head(channel_first)
class
TransformerDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
d_model
,
nhead
,
dim_feedforward
=
2048
,
dropout
=
0.1
,
activation
=
"relu"
,
self_posembed
=
None
,
cross_posembed
=
None
,
cross_only
=
False
):
super
().
__init__
()
self
.
cross_only
=
cross_only
if
not
self
.
cross_only
:
self
.
self_attn
=
nn
.
MultiheadAttention
(
d_model
,
nhead
,
dropout
=
dropout
)
self
.
multihead_attn
=
nn
.
MultiheadAttention
(
d_model
,
nhead
,
dropout
=
dropout
)
# Implementation of Feedforward model
self
.
linear1
=
nn
.
Linear
(
d_model
,
dim_feedforward
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
linear2
=
nn
.
Linear
(
dim_feedforward
,
d_model
)
self
.
norm1
=
nn
.
LayerNorm
(
d_model
)
self
.
norm2
=
nn
.
LayerNorm
(
d_model
)
self
.
norm3
=
nn
.
LayerNorm
(
d_model
)
self
.
dropout1
=
nn
.
Dropout
(
dropout
)
self
.
dropout2
=
nn
.
Dropout
(
dropout
)
self
.
dropout3
=
nn
.
Dropout
(
dropout
)
def
_get_activation_fn
(
activation
):
"""Return an activation function given a string"""
if
activation
==
"relu"
:
return
F
.
relu
if
activation
==
"gelu"
:
return
F
.
gelu
if
activation
==
"glu"
:
return
F
.
glu
raise
RuntimeError
(
F
"activation should be relu/gelu, not
{
activation
}
."
)
self
.
activation
=
_get_activation_fn
(
activation
)
self
.
self_posembed
=
self_posembed
self
.
cross_posembed
=
cross_posembed
def
with_pos_embed
(
self
,
tensor
,
pos_embed
):
return
tensor
if
pos_embed
is
None
else
tensor
+
pos_embed
def
forward
(
self
,
query
,
key
,
query_pos
,
key_pos
,
key_padding_mask
=
None
,
attn_mask
=
None
):
# NxCxP to PxNxC
if
self
.
self_posembed
is
not
None
:
query_pos_embed
=
self
.
self_posembed
(
query_pos
).
permute
(
2
,
0
,
1
)
else
:
query_pos_embed
=
None
if
self
.
cross_posembed
is
not
None
:
key_pos_embed
=
self
.
cross_posembed
(
key_pos
).
permute
(
2
,
0
,
1
)
else
:
key_pos_embed
=
None
query
=
query
.
permute
(
2
,
0
,
1
)
key
=
key
.
permute
(
2
,
0
,
1
)
if
not
self
.
cross_only
:
q
=
k
=
v
=
self
.
with_pos_embed
(
query
,
query_pos_embed
)
query2
=
self
.
self_attn
(
q
,
k
,
value
=
v
)[
0
]
query
=
query
+
self
.
dropout1
(
query2
)
query
=
self
.
norm1
(
query
)
query2
=
self
.
multihead_attn
(
query
=
self
.
with_pos_embed
(
query
,
query_pos_embed
),
key
=
self
.
with_pos_embed
(
key
,
key_pos_embed
),
value
=
self
.
with_pos_embed
(
key
,
key_pos_embed
),
key_padding_mask
=
key_padding_mask
,
attn_mask
=
attn_mask
)[
0
]
query
=
query
+
self
.
dropout2
(
query2
)
query
=
self
.
norm2
(
query
)
query2
=
self
.
linear2
(
self
.
dropout
(
self
.
activation
(
self
.
linear1
(
query
))))
query
=
query
+
self
.
dropout3
(
query2
)
query
=
self
.
norm3
(
query
)
# NxCxP to PxNxC
query
=
query
.
permute
(
1
,
2
,
0
)
return
query
pcdet/utils/loss_utils.py
View file @
4dc18496
...
...
@@ -560,4 +560,49 @@ class IouRegLossSparse(nn.Module):
loss
+=
(
1.
-
iou
).
sum
()
loss
=
loss
/
(
mask
.
sum
()
+
1e-4
)
return
loss
\ No newline at end of file
return
loss
class L1Loss(nn.Module):
    """Element-wise L1 loss (no reduction applied)."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return ``|pred - target|`` element-wise.

        An empty ``target`` yields a zero scalar that stays connected to
        ``pred``'s autograd graph.
        """
        if target.numel() == 0:
            return pred.sum() * 0
        assert pred.size() == target.size()
        return (pred - target).abs()
class GaussianFocalLoss(nn.Module):
    """GaussianFocalLoss is a variant of focal loss.

    More details can be found in the `paper
    <https://arxiv.org/abs/1808.01244>`_
    Code is modified from `kp_utils.py
    <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_ # noqa: E501
    Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
    not 0/1 binary target.

    Args:
        alpha (float): Power of prediction.
        gamma (float): Power of target for negative samples.
    """

    def __init__(self, alpha=2.0, gamma=4.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        """Element-wise loss; ``target`` is a gaussian heatmap in [0, 1]."""
        eps = 1e-12
        # Only exact heatmap peaks (target == 1) count as positives.
        is_peak = target.eq(1)
        # Negatives are down-weighted near peaks by (1 - target)^gamma.
        neg_decay = (1 - target).pow(self.gamma)
        pos_term = -(pred + eps).log() * (1 - pred).pow(self.alpha) * is_peak
        neg_term = -(1 - pred + eps).log() * pred.pow(self.alpha) * neg_decay
        return pos_term + neg_term
\ No newline at end of file
tools/cfgs/nuscenes_models/cbgs_transfusion_lidar.yaml
0 → 100644
View file @
4dc18496
CLASS_NAMES
:
[
'
car'
,
'
truck'
,
'
construction_vehicle'
,
'
bus'
,
'
trailer'
,
'
barrier'
,
'
motorcycle'
,
'
bicycle'
,
'
pedestrian'
,
'
traffic_cone'
]
DATA_CONFIG
:
_BASE_CONFIG_
:
cfgs/dataset_configs/nuscenes_dataset.yaml
POINT_CLOUD_RANGE
:
[
-54.0
,
-54.0
,
-5.0
,
54.0
,
54.0
,
3.0
]
# TODO(sc): debug-only INFO_PATH override (with_cam_2d pkl files); restore the dataset defaults before release
INFO_PATH
:
{
'
train'
:
[
nuscenes_infos_10sweeps_train_with_cam_2d.pkl
],
'
test'
:
[
nuscenes_infos_10sweeps_val_with_cam_2d.pkl
],
}
DATA_AUGMENTOR
:
DISABLE_AUG_LIST
:
[
'
placeholder'
]
AUG_CONFIG_LIST
:
-
NAME
:
gt_sampling
DB_INFO_PATH
:
-
nuscenes_dbinfos_10sweeps_withvelo.pkl
PREPARE
:
{
filter_by_min_points
:
[
'
car:5'
,
'
truck:5'
,
'
construction_vehicle:5'
,
'
bus:5'
,
'
trailer:5'
,
'
barrier:5'
,
'
motorcycle:5'
,
'
bicycle:5'
,
'
pedestrian:5'
,
'
traffic_cone:5'
],
}
SAMPLE_GROUPS
:
[
'
car:2'
,
'
truck:3'
,
'
construction_vehicle:7'
,
'
bus:4'
,
'
trailer:6'
,
'
barrier:2'
,
'
motorcycle:6'
,
'
bicycle:6'
,
'
pedestrian:2'
,
'
traffic_cone:2'
]
NUM_POINT_FEATURES
:
5
DATABASE_WITH_FAKELIDAR
:
False
REMOVE_EXTRA_WIDTH
:
[
0.0
,
0.0
,
0.0
]
LIMIT_WHOLE_SCENE
:
True
-
NAME
:
random_world_flip
ALONG_AXIS_LIST
:
[
'
x'
,
'
y'
]
-
NAME
:
random_world_rotation
WORLD_ROT_ANGLE
:
[
-0.78539816
,
0.78539816
]
-
NAME
:
random_world_scaling
WORLD_SCALE_RANGE
:
[
0.9
,
1.1
]
-
NAME
:
random_world_translation
NOISE_TRANSLATE_STD
:
[
0.5
,
0.5
,
0.5
]
DATA_PROCESSOR
:
-
NAME
:
mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES
:
True
-
NAME
:
shuffle_points
SHUFFLE_ENABLED
:
{
'
train'
:
True
,
'
test'
:
True
}
-
NAME
:
transform_points_to_voxels
VOXEL_SIZE
:
[
0.075
,
0.075
,
0.2
]
MAX_POINTS_PER_VOXEL
:
10
MAX_NUMBER_OF_VOXELS
:
{
'
train'
:
120000
,
'
test'
:
160000
}
MODEL
:
NAME
:
TransFusion
VFE
:
NAME
:
MeanVFE
BACKBONE_3D
:
NAME
:
VoxelResBackBone8x
USE_BIAS
:
False
MAP_TO_BEV
:
NAME
:
HeightCompression
NUM_BEV_FEATURES
:
256
BACKBONE_2D
:
NAME
:
BaseBEVBackbone
LAYER_NUMS
:
[
5
,
5
]
LAYER_STRIDES
:
[
1
,
2
]
NUM_FILTERS
:
[
128
,
256
]
UPSAMPLE_STRIDES
:
[
1
,
2
]
NUM_UPSAMPLE_FILTERS
:
[
256
,
256
]
USE_CONV_FOR_NO_STRIDE
:
True
DENSE_HEAD
:
CLASS_AGNOSTIC
:
False
NAME
:
TransFusionHead
USE_BIAS_BEFORE_NORM
:
False
NUM_PROPOSALS
:
200
HIDDEN_CHANNEL
:
128
NUM_CLASSES
:
10
NUM_HEADS
:
8
NMS_KERNEL_SIZE
:
3
FFN_CHANNEL
:
256
DROPOUT
:
0.1
BN_MOMENTUM
:
0.1
ACTIVATION
:
relu
NUM_HM_CONV
:
2
SEPARATE_HEAD_CFG
:
HEAD_ORDER
:
[
'
center'
,
'
height'
,
'
dim'
,
'
rot'
,
'
vel'
]
HEAD_DICT
:
{
'
center'
:
{
'
out_channels'
:
2
,
'
num_conv'
:
2
},
'
height'
:
{
'
out_channels'
:
1
,
'
num_conv'
:
2
},
'
dim'
:
{
'
out_channels'
:
3
,
'
num_conv'
:
2
},
'
rot'
:
{
'
out_channels'
:
2
,
'
num_conv'
:
2
},
'
vel'
:
{
'
out_channels'
:
2
,
'
num_conv'
:
2
},
}
TARGET_ASSIGNER_CONFIG
:
FEATURE_MAP_STRIDE
:
8
DATASET
:
nuScenes
GAUSSIAN_OVERLAP
:
0.1
MIN_RADIUS
:
2
HUNGARIAN_ASSIGNER
:
cls_cost
:
{
'
gamma'
:
2.0
,
'
alpha'
:
0.25
,
'
weight'
:
0.15
}
reg_cost
:
{
'
weight'
:
0.25
}
iou_cost
:
{
'
weight'
:
0.25
}
LOSS_CONFIG
:
LOSS_WEIGHTS
:
{
'
cls_weight'
:
1.0
,
'
bbox_weight'
:
0.25
,
'
hm_weight'
:
1.0
,
'
code_weights'
:
[
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
0.2
,
0.2
]
}
LOSS_CLS
:
use_sigmoid
:
true
gamma
:
2.0
alpha
:
0.25
POST_PROCESSING
:
SCORE_THRESH
:
0.0
POST_CENTER_RANGE
:
[
-61.2
,
-61.2
,
-10.0
,
61.2
,
61.2
,
10.0
]
POST_PROCESSING
:
RECALL_THRESH_LIST
:
[
0.3
,
0.5
,
0.7
]
SCORE_THRESH
:
0.1
OUTPUT_RAW_SCORE
:
False
EVAL_METRIC
:
kitti
OPTIMIZATION
:
BATCH_SIZE_PER_GPU
:
4
NUM_EPOCHS
:
20
OPTIMIZER
:
adam_onecycle
LR
:
0.001
WEIGHT_DECAY
:
0.01
MOMENTUM
:
0.9
BETAS
:
[
0.9
,
0.999
]
MOMS
:
[
0.9
,
0.8052631
]
PCT_START
:
0.4
DIV_FACTOR
:
10
DECAY_STEP_LIST
:
[
35
,
45
]
LR_DECAY
:
0.1
LR_CLIP
:
0.0000001
LR_WARMUP
:
False
WARMUP_EPOCH
:
1
GRAD_NORM_CLIP
:
35
HOOK
:
DisableAugmentationHook
:
DISABLE_AUG_LIST
:
[
'
gt_sampling'
]
NUM_LAST_EPOCHS
:
5
\ No newline at end of file
tools/train.py
View file @
4dc18496
...
...
@@ -195,7 +195,8 @@ def main():
ckpt_save_time_interval
=
args
.
ckpt_save_time_interval
,
use_logger_to_record
=
not
args
.
use_tqdm_to_record
,
show_gpu_stat
=
not
args
.
wo_gpu_stat
,
use_amp
=
args
.
use_amp
use_amp
=
args
.
use_amp
,
cfg
=
cfg
)
if
hasattr
(
train_set
,
'use_shared_memory'
)
and
train_set
.
use_shared_memory
:
...
...
tools/train_utils/optimization/__init__.py
View file @
4dc18496
...
...
@@ -25,8 +25,9 @@ def build_optimizer(model, optim_cfg):
flatten_model
=
lambda
m
:
sum
(
map
(
flatten_model
,
m
.
children
()),
[])
if
num_children
(
m
)
else
[
m
]
get_layer_groups
=
lambda
m
:
[
nn
.
Sequential
(
*
flatten_model
(
m
))]
optimizer_func
=
partial
(
optim
.
Adam
,
betas
=
(
0.9
,
0.99
))
betas
=
optim_cfg
.
get
(
'BETAS'
,
(
0.9
,
0.99
))
betas
=
tuple
(
betas
)
optimizer_func
=
partial
(
optim
.
Adam
,
betas
=
betas
)
optimizer
=
OptimWrapper
.
create
(
optimizer_func
,
3e-3
,
get_layer_groups
(
model
),
wd
=
optim_cfg
.
WEIGHT_DECAY
,
true_wd
=
True
,
bn_wd
=
True
)
...
...
tools/train_utils/train_utils.py
View file @
4dc18496
...
...
@@ -151,8 +151,13 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
start_epoch
,
total_epochs
,
start_iter
,
rank
,
tb_log
,
ckpt_save_dir
,
train_sampler
=
None
,
lr_warmup_scheduler
=
None
,
ckpt_save_interval
=
1
,
max_ckpt_save_num
=
50
,
merge_all_iters_to_one_epoch
=
False
,
use_amp
=
False
,
use_logger_to_record
=
False
,
logger
=
None
,
logger_iter_interval
=
None
,
ckpt_save_time_interval
=
None
,
show_gpu_stat
=
False
):
use_logger_to_record
=
False
,
logger
=
None
,
logger_iter_interval
=
None
,
ckpt_save_time_interval
=
None
,
show_gpu_stat
=
False
,
cfg
=
None
):
accumulated_iter
=
start_iter
# use for disable data augmentation hook
hook_config
=
cfg
.
get
(
'HOOK'
,
None
)
augment_disable_flag
=
False
with
tqdm
.
trange
(
start_epoch
,
total_epochs
,
desc
=
'epochs'
,
dynamic_ncols
=
True
,
leave
=
(
rank
==
0
))
as
tbar
:
total_it_each_epoch
=
len
(
train_loader
)
if
merge_all_iters_to_one_epoch
:
...
...
@@ -170,6 +175,8 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
cur_scheduler
=
lr_warmup_scheduler
else
:
cur_scheduler
=
lr_scheduler
augment_disable_flag
=
disable_augmentation_hook
(
hook_config
,
dataloader_iter
,
total_epochs
,
cur_epoch
,
cfg
,
augment_disable_flag
,
logger
)
accumulated_iter
=
train_one_epoch
(
model
,
optimizer
,
train_loader
,
model_func
,
lr_scheduler
=
cur_scheduler
,
...
...
@@ -245,3 +252,21 @@ def save_checkpoint(state, filename='checkpoint'):
torch
.
save
(
state
,
filename
,
_use_new_zipfile_serialization
=
False
)
else
:
torch
.
save
(
state
,
filename
)
def disable_augmentation_hook(hook_config, dataloader, total_epochs, cur_epoch, cfg, flag, logger):
    """Turn off selected data augmentations for the last training epochs.

    Fires once: when the current epoch enters the final
    ``NUM_LAST_EPOCHS`` window and the hook has not fired yet, it rewrites
    the dataset config's disable list and tells the dataloader's augmentor
    to switch those augmentations off.

    Args:
        hook_config: dict-like config, or ``None`` when the hook is unused.
        dataloader: training dataloader; its ``_dataset.data_augmentor``
            is updated in place.
        total_epochs: total number of training epochs.
        cur_epoch: index of the current epoch.
        cfg: full config; ``cfg.DATA_CONFIG`` is mutated.
        flag: True when the hook already fired earlier.
        logger: logger used to announce the change.

    Returns:
        Updated ``flag`` — True once augmentation has been disabled.
    """
    if hook_config is None:
        return flag
    hook_cfg = hook_config.get('DisableAugmentationHook', None)
    if hook_cfg is None:
        return flag
    # Not yet inside the final window, or already applied: nothing to do.
    if cur_epoch < total_epochs - hook_cfg.NUM_LAST_EPOCHS or flag:
        return flag

    aug_names = hook_cfg.DISABLE_AUG_LIST
    dataset_cfg = cfg.DATA_CONFIG
    logger.info(f'Disable augmentations: {aug_names}')
    dataset_cfg.DATA_AUGMENTOR.DISABLE_AUG_LIST = aug_names
    dataloader._dataset.data_augmentor.disableAugmentation(dataset_cfg.DATA_AUGMENTOR)
    return True
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment