Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
72c608ce
"pcdet/ops/pointnet2/pointnet2_stack/src/voxel_query.cpp" did not exist on "05009423dc53f53f7b4ea07ae3c21070dcf2ed29"
Commit
72c608ce
authored
Jun 11, 2023
by
chenshi3
Browse files
Add support for DSVT
parent
02ac3e17
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
514 additions
and
1 deletion
+514
-1
pcdet/ops/iou3d_nms/src/iou3d_nms.h
pcdet/ops/iou3d_nms/src/iou3d_nms.h
+1
-0
pcdet/ops/iou3d_nms/src/iou3d_nms_api.cpp
pcdet/ops/iou3d_nms/src/iou3d_nms_api.cpp
+1
-0
pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu
pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu
+26
-0
pcdet/utils/box_utils.py
pcdet/utils/box_utils.py
+54
-0
pcdet/utils/loss_utils.py
pcdet/utils/loss_utils.py
+42
-1
setup.py
setup.py
+8
-0
tools/cfgs/waymo_models/dsvt_pillar.yaml
tools/cfgs/waymo_models/dsvt_pillar.yaml
+190
-0
tools/cfgs/waymo_models/dsvt_voxel.yaml
tools/cfgs/waymo_models/dsvt_voxel.yaml
+192
-0
No files found.
pcdet/ops/iou3d_nms/src/iou3d_nms.h
View file @
72c608ce
...
...
@@ -9,6 +9,7 @@
int
boxes_aligned_overlap_bev_gpu
(
at
::
Tensor
boxes_a
,
at
::
Tensor
boxes_b
,
at
::
Tensor
ans_overlap
);
int
boxes_overlap_bev_gpu
(
at
::
Tensor
boxes_a
,
at
::
Tensor
boxes_b
,
at
::
Tensor
ans_overlap
);
int
paired_boxes_overlap_bev_gpu
(
at
::
Tensor
boxes_a
,
at
::
Tensor
boxes_b
,
at
::
Tensor
ans_overlap
);
int
boxes_iou_bev_gpu
(
at
::
Tensor
boxes_a
,
at
::
Tensor
boxes_b
,
at
::
Tensor
ans_iou
);
int
nms_gpu
(
at
::
Tensor
boxes
,
at
::
Tensor
keep
,
float
nms_overlap_thresh
);
int
nms_normal_gpu
(
at
::
Tensor
boxes
,
at
::
Tensor
keep
,
float
nms_overlap_thresh
);
...
...
pcdet/ops/iou3d_nms/src/iou3d_nms_api.cpp
View file @
72c608ce
...
...
@@ -11,6 +11,7 @@
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"boxes_aligned_overlap_bev_gpu"
,
&
boxes_aligned_overlap_bev_gpu
,
"aligned oriented boxes overlap"
);
m
.
def
(
"boxes_overlap_bev_gpu"
,
&
boxes_overlap_bev_gpu
,
"oriented boxes overlap"
);
m
.
def
(
"paired_boxes_overlap_bev_gpu"
,
&
paired_boxes_overlap_bev_gpu
,
"oriented boxes overlap"
);
m
.
def
(
"boxes_iou_bev_gpu"
,
&
boxes_iou_bev_gpu
,
"oriented boxes iou"
);
m
.
def
(
"nms_gpu"
,
&
nms_gpu
,
"oriented nms gpu"
);
m
.
def
(
"nms_normal_gpu"
,
&
nms_normal_gpu
,
"nms gpu"
);
...
...
pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu
View file @
72c608ce
...
...
@@ -248,6 +248,21 @@ __global__ void boxes_overlap_kernel(const int num_a, const float *boxes_a, cons
ans_overlap
[
a_idx
*
num_b
+
b_idx
]
=
s_overlap
;
}
__global__
void
paired_boxes_overlap_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
){
// params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (N, 7) [x, y, z, dx, dy, dz, heading]
const
int
idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
if
(
idx
>=
num_a
){
return
;
}
const
float
*
cur_box_a
=
boxes_a
+
idx
*
7
;
const
float
*
cur_box_b
=
boxes_b
+
idx
*
7
;
float
s_overlap
=
box_overlap
(
cur_box_a
,
cur_box_b
);
// printf("idx=%d, box_a=(%.3f, %.3f, %.3f, ), box_b=(%.3f, %.3f, %.3f, ), overlap=%.5f\n", idx, cur_box_a[0], cur_box_a[1], cur_box_a[2], cur_box_b[0], cur_box_b[1], cur_box_b[2], s_overlap);
ans_overlap
[
idx
]
=
s_overlap
;
}
__global__
void
boxes_aligned_overlap_kernel
(
const
int
num_box
,
const
float
*
boxes_a
,
const
float
*
boxes_b
,
float
*
ans_overlap
){
// params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (N, 7) [x, y, z, dx, dy, dz, heading]
...
...
@@ -399,6 +414,17 @@ void boxesoverlapLauncher(const int num_a, const float *boxes_a, const int num_b
#endif
}
void
PairedBoxesOverlapLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
){
dim3
blocks
(
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
paired_boxes_overlap_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
#ifdef DEBUG
cudaDeviceSynchronize
();
// for using printf in kernel function
#endif
}
void
boxesalignedoverlapLauncher
(
const
int
num_box
,
const
float
*
boxes_a
,
const
float
*
boxes_b
,
float
*
ans_overlap
){
dim3
blocks
(
DIVUP
(
num_box
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
...
...
pcdet/utils/box_utils.py
View file @
72c608ce
...
...
@@ -384,3 +384,57 @@ def pairwise_iou(boxes1, boxes2) -> torch.Tensor:
torch
.
zeros
(
1
,
dtype
=
inter
.
dtype
,
device
=
inter
.
device
),
)
return
iou
def
center_to_corner2d
(
center
,
dim
):
corners_norm
=
torch
.
tensor
([[
-
0.5
,
-
0.5
],
[
-
0.5
,
0.5
],
[
0.5
,
0.5
],
[
0.5
,
-
0.5
]],
device
=
dim
.
device
).
type_as
(
center
)
# (4, 2)
corners
=
dim
.
view
([
-
1
,
1
,
2
])
*
corners_norm
.
view
([
1
,
4
,
2
])
# (N, 4, 2)
corners
=
corners
+
center
.
view
(
-
1
,
1
,
2
)
return
corners
def
bbox3d_overlaps_diou
(
pred_boxes
,
gt_boxes
):
"""
https://github.com/agent-sgs/PillarNet/blob/master/det3d/core/utils/center_utils.py
Args:
pred_boxes (N, 7):
gt_boxes (N, 7):
Returns:
_type_: _description_
"""
assert
pred_boxes
.
shape
[
0
]
==
gt_boxes
.
shape
[
0
]
qcorners
=
center_to_corner2d
(
pred_boxes
[:,
:
2
],
pred_boxes
[:,
3
:
5
])
# (N, 4, 2)
gcorners
=
center_to_corner2d
(
gt_boxes
[:,
:
2
],
gt_boxes
[:,
3
:
5
])
# (N, 4, 2)
inter_max_xy
=
torch
.
minimum
(
qcorners
[:,
2
],
gcorners
[:,
2
])
inter_min_xy
=
torch
.
maximum
(
qcorners
[:,
0
],
gcorners
[:,
0
])
out_max_xy
=
torch
.
maximum
(
qcorners
[:,
2
],
gcorners
[:,
2
])
out_min_xy
=
torch
.
minimum
(
qcorners
[:,
0
],
gcorners
[:,
0
])
# calculate area
volume_pred_boxes
=
pred_boxes
[:,
3
]
*
pred_boxes
[:,
4
]
*
pred_boxes
[:,
5
]
volume_gt_boxes
=
gt_boxes
[:,
3
]
*
gt_boxes
[:,
4
]
*
gt_boxes
[:,
5
]
inter_h
=
torch
.
minimum
(
pred_boxes
[:,
2
]
+
0.5
*
pred_boxes
[:,
5
],
gt_boxes
[:,
2
]
+
0.5
*
gt_boxes
[:,
5
])
-
\
torch
.
maximum
(
pred_boxes
[:,
2
]
-
0.5
*
pred_boxes
[:,
5
],
gt_boxes
[:,
2
]
-
0.5
*
gt_boxes
[:,
5
])
inter_h
=
torch
.
clamp
(
inter_h
,
min
=
0
)
inter
=
torch
.
clamp
((
inter_max_xy
-
inter_min_xy
),
min
=
0
)
volume_inter
=
inter
[:,
0
]
*
inter
[:,
1
]
*
inter_h
volume_union
=
volume_gt_boxes
+
volume_pred_boxes
-
volume_inter
# boxes_iou3d_gpu(pred_boxes, gt_boxes)
inter_diag
=
torch
.
pow
(
gt_boxes
[:,
0
:
3
]
-
pred_boxes
[:,
0
:
3
],
2
).
sum
(
-
1
)
outer_h
=
torch
.
maximum
(
gt_boxes
[:,
2
]
+
0.5
*
gt_boxes
[:,
5
],
pred_boxes
[:,
2
]
+
0.5
*
pred_boxes
[:,
5
])
-
\
torch
.
minimum
(
gt_boxes
[:,
2
]
-
0.5
*
gt_boxes
[:,
5
],
pred_boxes
[:,
2
]
-
0.5
*
pred_boxes
[:,
5
])
outer_h
=
torch
.
clamp
(
outer_h
,
min
=
0
)
outer
=
torch
.
clamp
((
out_max_xy
-
out_min_xy
),
min
=
0
)
outer_diag
=
outer
[:,
0
]
**
2
+
outer
[:,
1
]
**
2
+
outer_h
**
2
dious
=
volume_inter
/
volume_union
-
inter_diag
/
outer_diag
dious
=
torch
.
clamp
(
dious
,
min
=-
1.0
,
max
=
1.0
)
return
dious
\ No newline at end of file
pcdet/utils/loss_utils.py
View file @
72c608ce
...
...
@@ -605,4 +605,45 @@ class GaussianFocalLoss(nn.Module):
pos_loss
=
-
(
pred
+
eps
).
log
()
*
(
1
-
pred
).
pow
(
self
.
alpha
)
*
pos_weights
neg_loss
=
-
(
1
-
pred
+
eps
).
log
()
*
pred
.
pow
(
self
.
alpha
)
*
neg_weights
return
pos_loss
+
neg_loss
\ No newline at end of file
return
pos_loss
+
neg_loss
def
calculate_iou_loss_centerhead
(
iou_preds
,
batch_box_preds
,
mask
,
ind
,
gt_boxes
):
"""
Args:
iou_preds: (batch x 1 x h x w)
batch_box_preds: (batch x (7 or 9) x h x w)
mask: (batch x max_objects)
ind: (batch x max_objects)
gt_boxes: (batch x N, 7 or 9)
Returns:
"""
if
mask
.
sum
()
==
0
:
return
iou_preds
.
new_zeros
((
1
))
mask
=
mask
.
bool
()
selected_iou_preds
=
_transpose_and_gather_feat
(
iou_preds
,
ind
)[
mask
]
selected_box_preds
=
_transpose_and_gather_feat
(
batch_box_preds
,
ind
)[
mask
]
iou_target
=
iou3d_nms_utils
.
paired_boxes_iou3d_gpu
(
selected_box_preds
[:,
0
:
7
],
gt_boxes
[
mask
][:,
0
:
7
])
# iou_target = iou3d_nms_utils.boxes_iou3d_gpu(selected_box_preds[:, 0:7].clone(), gt_boxes[mask][:, 0:7].clone()).diag()
iou_target
=
iou_target
*
2
-
1
# [0, 1] ==> [-1, 1]
# print(selected_iou_preds.view(-1), iou_target)
loss
=
F
.
l1_loss
(
selected_iou_preds
.
view
(
-
1
),
iou_target
,
reduction
=
'sum'
)
loss
=
loss
/
torch
.
clamp
(
mask
.
sum
(),
min
=
1e-4
)
return
loss
def
calculate_iou_reg_loss_centerhead
(
batch_box_preds
,
mask
,
ind
,
gt_boxes
):
if
mask
.
sum
()
==
0
:
return
batch_box_preds
.
new_zeros
((
1
))
mask
=
mask
.
bool
()
selected_box_preds
=
_transpose_and_gather_feat
(
batch_box_preds
,
ind
)
iou
=
box_utils
.
bbox3d_overlaps_diou
(
selected_box_preds
[
mask
][:,
0
:
7
],
gt_boxes
[
mask
][:,
0
:
7
])
loss
=
(
1.0
-
iou
).
sum
()
/
torch
.
clamp
(
mask
.
sum
(),
min
=
1e-4
)
return
loss
setup.py
View file @
72c608ce
...
...
@@ -125,5 +125,13 @@ if __name__ == '__main__':
"src/bev_pool_cuda.cu"
,
],
),
make_cuda_ext
(
name
=
'ingroup_inds_cuda'
,
module
=
'pcdet.ops.ingroup_inds'
,
sources
=
[
'src/ingroup_inds.cpp'
,
'src/ingroup_inds_kernel.cu'
,
]
),
],
)
tools/cfgs/waymo_models/dsvt_pillar.yaml
0 → 100644
View file @
72c608ce
CLASS_NAMES
:
[
'
Vehicle'
,
'
Pedestrian'
,
'
Cyclist'
]
DATA_CONFIG
:
_BASE_CONFIG_
:
cfgs/dataset_configs/waymo_dataset.yaml
SAMPLED_INTERVAL
:
{
'
train'
:
1
,
'
test'
:
1
}
POINT_CLOUD_RANGE
:
[
-74.88
,
-74.88
,
-2
,
74.88
,
74.88
,
4.0
]
POINTS_TANH_DIM
:
[
3
,
4
]
DATA_AUGMENTOR
:
DISABLE_AUG_LIST
:
[
'
placeholder'
]
AUG_CONFIG_LIST
:
-
NAME
:
gt_sampling
USE_ROAD_PLANE
:
False
DB_INFO_PATH
:
-
waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl
USE_SHARED_MEMORY
:
True
# set it to True to speed up (it costs about 15GB shared memory)
DB_DATA_PATH
:
-
waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global.npy
BACKUP_DB_INFO
:
# if the above DB_INFO cannot be found, will use this backup one
DB_INFO_PATH
:
waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1_multiframe_-4_to_0.pkl
DB_DATA_PATH
:
waymo_processed_data_v0_5_0_gt_database_train_sampled_1_multiframe_-4_to_0_global.npy
NUM_POINT_FEATURES
:
6
PREPARE
:
{
filter_by_min_points
:
[
'
Vehicle:5'
,
'
Pedestrian:10'
,
'
Cyclist:10'
],
filter_by_difficulty
:
[
-1
],
}
SAMPLE_GROUPS
:
[
'
Vehicle:15'
,
'
Pedestrian:10'
,
'
Cyclist:10'
]
NUM_POINT_FEATURES
:
5
REMOVE_EXTRA_WIDTH
:
[
0.0
,
0.0
,
0.0
]
LIMIT_WHOLE_SCENE
:
True
-
NAME
:
random_world_flip
ALONG_AXIS_LIST
:
[
'
x'
,
'
y'
]
-
NAME
:
random_world_rotation
WORLD_ROT_ANGLE
:
[
-0.78539816
,
0.78539816
]
-
NAME
:
random_world_scaling
WORLD_SCALE_RANGE
:
[
0.95
,
1.05
]
-
NAME
:
random_world_translation
NOISE_TRANSLATE_STD
:
[
0.5
,
0.5
,
0.5
]
DATA_PROCESSOR
:
-
NAME
:
mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES
:
True
-
NAME
:
shuffle_points
SHUFFLE_ENABLED
:
{
'
train'
:
True
,
'
test'
:
True
}
-
NAME
:
transform_points_to_voxels_placeholder
VOXEL_SIZE
:
[
0.32
,
0.32
,
6
]
MODEL
:
NAME
:
CenterPoint
VFE
:
NAME
:
DynamicVoxelVFE
WITH_DISTANCE
:
False
USE_ABSLOTE_XYZ
:
True
USE_NORM
:
True
NUM_FILTERS
:
[
192
,
192
]
BACKBONE_3D
:
NAME
:
DSVT
INPUT_LAYER
:
sparse_shape
:
[
468
,
468
,
1
]
downsample_stride
:
[]
d_model
:
[
192
]
set_info
:
[[
36
,
4
]]
window_shape
:
[[
12
,
12
,
1
]]
hybrid_factor
:
[
2
,
2
,
1
]
# x, y, z
shifts_list
:
[[[
0
,
0
,
0
],
[
6
,
6
,
0
]]]
normalize_pos
:
False
block_name
:
[
'
DSVTBlock'
]
set_info
:
[[
36
,
4
]]
d_model
:
[
192
]
nhead
:
[
8
]
dim_feedforward
:
[
384
]
dropout
:
0.0
activation
:
gelu
output_shape
:
[
468
,
468
]
conv_out_channel
:
192
# You can enable torch.utils.checkpoint to save GPU memory
USE_CHECKPOINT
:
True
MAP_TO_BEV
:
NAME
:
PointPillarScatter3d
INPUT_SHAPE
:
[
468
,
468
,
1
]
NUM_BEV_FEATURES
:
192
BACKBONE_2D
:
NAME
:
BaseBEVResBackbone
LAYER_NUMS
:
[
1
,
2
,
2
]
LAYER_STRIDES
:
[
1
,
2
,
2
]
NUM_FILTERS
:
[
128
,
128
,
256
]
UPSAMPLE_STRIDES
:
[
1
,
2
,
4
]
NUM_UPSAMPLE_FILTERS
:
[
128
,
128
,
128
]
DENSE_HEAD
:
NAME
:
CenterHead
CLASS_AGNOSTIC
:
False
CLASS_NAMES_EACH_HEAD
:
[
[
'
Vehicle'
,
'
Pedestrian'
,
'
Cyclist'
]
]
SHARED_CONV_CHANNEL
:
64
USE_BIAS_BEFORE_NORM
:
False
NUM_HM_CONV
:
2
BN_EPS
:
0.001
BN_MOM
:
0.01
SEPARATE_HEAD_CFG
:
HEAD_ORDER
:
[
'
center'
,
'
center_z'
,
'
dim'
,
'
rot'
]
HEAD_DICT
:
{
'
center'
:
{
'
out_channels'
:
2
,
'
num_conv'
:
2
},
'
center_z'
:
{
'
out_channels'
:
1
,
'
num_conv'
:
2
},
'
dim'
:
{
'
out_channels'
:
3
,
'
num_conv'
:
2
},
'
rot'
:
{
'
out_channels'
:
2
,
'
num_conv'
:
2
},
'
iou'
:
{
'
out_channels'
:
1
,
'
num_conv'
:
2
},
}
TARGET_ASSIGNER_CONFIG
:
FEATURE_MAP_STRIDE
:
1
NUM_MAX_OBJS
:
500
GAUSSIAN_OVERLAP
:
0.1
MIN_RADIUS
:
2
IOU_REG_LOSS
:
True
LOSS_CONFIG
:
LOSS_WEIGHTS
:
{
'
cls_weight'
:
1.0
,
'
loc_weight'
:
2.0
,
'
code_weights'
:
[
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
]
}
POST_PROCESSING
:
SCORE_THRESH
:
0.1
POST_CENTER_LIMIT_RANGE
:
[
-80
,
-80
,
-10.0
,
80
,
80
,
10.0
]
MAX_OBJ_PER_SAMPLE
:
500
USE_IOU_TO_RECTIFY_SCORE
:
True
IOU_RECTIFIER
:
[
0.68
,
0.71
,
0.65
]
NMS_CONFIG
:
NMS_TYPE
:
class_specific_nms
NMS_THRESH
:
[
0.75
,
0.6
,
0.55
]
NMS_PRE_MAXSIZE
:
[
4096
,
4096
,
4096
]
NMS_POST_MAXSIZE
:
[
500
,
500
,
500
]
POST_PROCESSING
:
RECALL_THRESH_LIST
:
[
0.3
,
0.5
,
0.7
]
EVAL_METRIC
:
waymo
OPTIMIZATION
:
BATCH_SIZE_PER_GPU
:
3
NUM_EPOCHS
:
24
OPTIMIZER
:
adam_onecycle
LR
:
0.003
#
WEIGHT_DECAY
:
0.05
MOMENTUM
:
0.9
MOMS
:
[
0.95
,
0.85
]
PCT_START
:
0.1
DIV_FACTOR
:
100
DECAY_STEP_LIST
:
[
35
,
45
]
LR_DECAY
:
0.1
LR_CLIP
:
0.0000001
LR_WARMUP
:
False
WARMUP_EPOCH
:
1
GRAD_NORM_CLIP
:
10
LOSS_SCALE_FP16
:
32.0
HOOK
:
DisableAugmentationHook
:
DISABLE_AUG_LIST
:
[
'
gt_sampling'
,
'
random_world_flip'
,
'
random_world_rotation'
,
'
random_world_scaling'
,
'
random_world_translation'
]
NUM_LAST_EPOCHS
:
1
\ No newline at end of file
tools/cfgs/waymo_models/dsvt_voxel.yaml
0 → 100644
View file @
72c608ce
CLASS_NAMES
:
[
'
Vehicle'
,
'
Pedestrian'
,
'
Cyclist'
]
DATA_CONFIG
:
_BASE_CONFIG_
:
cfgs/dataset_configs/waymo_dataset.yaml
SAMPLED_INTERVAL
:
{
'
train'
:
1
,
'
test'
:
1
}
POINT_CLOUD_RANGE
:
[
-74.88
,
-74.88
,
-2
,
74.88
,
74.88
,
4.0
]
POINTS_TANH_DIM
:
[
3
,
4
]
DATA_AUGMENTOR
:
DISABLE_AUG_LIST
:
[
'
placeholder'
]
AUG_CONFIG_LIST
:
-
NAME
:
gt_sampling
USE_ROAD_PLANE
:
False
DB_INFO_PATH
:
-
waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl
USE_SHARED_MEMORY
:
True
# set it to True to speed up (it costs about 15GB shared memory)
DB_DATA_PATH
:
-
waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global.npy
BACKUP_DB_INFO
:
# if the above DB_INFO cannot be found, will use this backup one
DB_INFO_PATH
:
waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1_multiframe_-4_to_0.pkl
DB_DATA_PATH
:
waymo_processed_data_v0_5_0_gt_database_train_sampled_1_multiframe_-4_to_0_global.npy
NUM_POINT_FEATURES
:
6
PREPARE
:
{
filter_by_min_points
:
[
'
Vehicle:5'
,
'
Pedestrian:10'
,
'
Cyclist:10'
],
filter_by_difficulty
:
[
-1
],
}
SAMPLE_GROUPS
:
[
'
Vehicle:15'
,
'
Pedestrian:10'
,
'
Cyclist:10'
]
NUM_POINT_FEATURES
:
5
REMOVE_EXTRA_WIDTH
:
[
0.0
,
0.0
,
0.0
]
LIMIT_WHOLE_SCENE
:
True
-
NAME
:
random_world_flip
ALONG_AXIS_LIST
:
[
'
x'
,
'
y'
]
-
NAME
:
random_world_rotation
WORLD_ROT_ANGLE
:
[
-0.78539816
,
0.78539816
]
-
NAME
:
random_world_scaling
WORLD_SCALE_RANGE
:
[
0.95
,
1.05
]
-
NAME
:
random_world_translation
NOISE_TRANSLATE_STD
:
[
0.5
,
0.5
,
0.5
]
DATA_PROCESSOR
:
-
NAME
:
mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES
:
True
-
NAME
:
shuffle_points
SHUFFLE_ENABLED
:
{
'
train'
:
True
,
'
test'
:
True
}
-
NAME
:
transform_points_to_voxels_placeholder
VOXEL_SIZE
:
[
0.32
,
0.32
,
0.1875
]
MODEL
:
NAME
:
CenterPoint
VFE
:
NAME
:
DynamicVoxelVFE
WITH_DISTANCE
:
False
USE_ABSLOTE_XYZ
:
True
USE_NORM
:
True
NUM_FILTERS
:
[
192
,
192
]
BACKBONE_3D
:
NAME
:
DSVT
INPUT_LAYER
:
sparse_shape
:
[
468
,
468
,
32
]
downsample_stride
:
[[
1
,
1
,
4
],
[
1
,
1
,
4
],
[
1
,
1
,
2
]]
d_model
:
[
192
,
192
,
192
,
192
]
set_info
:
[[
48
,
1
],
[
48
,
1
],
[
48
,
1
],
[
48
,
1
]]
window_shape
:
[[
12
,
12
,
32
],
[
12
,
12
,
8
],
[
12
,
12
,
2
],
[
12
,
12
,
1
]]
hybrid_factor
:
[
2
,
2
,
1
]
# x, y, z
shifts_list
:
[[[
0
,
0
,
0
],
[
6
,
6
,
0
]],
[[
0
,
0
,
0
],
[
6
,
6
,
0
]],
[[
0
,
0
,
0
],
[
6
,
6
,
0
]],
[[
0
,
0
,
0
],
[
6
,
6
,
0
]]]
normalize_pos
:
False
block_name
:
[
'
DSVTBlock'
,
'
DSVTBlock'
,
'
DSVTBlock'
,
'
DSVTBlock'
]
set_info
:
[[
48
,
1
],
[
48
,
1
],
[
48
,
1
],
[
48
,
1
]]
d_model
:
[
192
,
192
,
192
,
192
]
nhead
:
[
8
,
8
,
8
,
8
]
dim_feedforward
:
[
384
,
384
,
384
,
384
]
dropout
:
0.0
activation
:
gelu
reduction_type
:
'
attention'
output_shape
:
[
468
,
468
]
conv_out_channel
:
192
# You can enable torch.utils.checkpoint to save GPU memory
# USE_CHECKPOINT: True
MAP_TO_BEV
:
NAME
:
PointPillarScatter3d
INPUT_SHAPE
:
[
468
,
468
,
1
]
NUM_BEV_FEATURES
:
192
BACKBONE_2D
:
NAME
:
BaseBEVResBackbone
LAYER_NUMS
:
[
1
,
2
,
2
]
LAYER_STRIDES
:
[
1
,
2
,
2
]
NUM_FILTERS
:
[
128
,
128
,
256
]
UPSAMPLE_STRIDES
:
[
1
,
2
,
4
]
NUM_UPSAMPLE_FILTERS
:
[
128
,
128
,
128
]
DENSE_HEAD
:
NAME
:
CenterHead
CLASS_AGNOSTIC
:
False
CLASS_NAMES_EACH_HEAD
:
[
[
'
Vehicle'
,
'
Pedestrian'
,
'
Cyclist'
]
]
SHARED_CONV_CHANNEL
:
64
USE_BIAS_BEFORE_NORM
:
False
NUM_HM_CONV
:
2
BN_EPS
:
0.001
BN_MOM
:
0.01
SEPARATE_HEAD_CFG
:
HEAD_ORDER
:
[
'
center'
,
'
center_z'
,
'
dim'
,
'
rot'
]
HEAD_DICT
:
{
'
center'
:
{
'
out_channels'
:
2
,
'
num_conv'
:
2
},
'
center_z'
:
{
'
out_channels'
:
1
,
'
num_conv'
:
2
},
'
dim'
:
{
'
out_channels'
:
3
,
'
num_conv'
:
2
},
'
rot'
:
{
'
out_channels'
:
2
,
'
num_conv'
:
2
},
'
iou'
:
{
'
out_channels'
:
1
,
'
num_conv'
:
2
},
}
TARGET_ASSIGNER_CONFIG
:
FEATURE_MAP_STRIDE
:
1
NUM_MAX_OBJS
:
500
GAUSSIAN_OVERLAP
:
0.1
MIN_RADIUS
:
2
IOU_REG_LOSS
:
True
LOSS_CONFIG
:
LOSS_WEIGHTS
:
{
'
cls_weight'
:
1.0
,
'
loc_weight'
:
2.0
,
'
code_weights'
:
[
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
]
}
POST_PROCESSING
:
SCORE_THRESH
:
0.1
POST_CENTER_LIMIT_RANGE
:
[
-80
,
-80
,
-10.0
,
80
,
80
,
10.0
]
MAX_OBJ_PER_SAMPLE
:
500
USE_IOU_TO_RECTIFY_SCORE
:
True
IOU_RECTIFIER
:
[
0.68
,
0.71
,
0.65
]
NMS_CONFIG
:
NMS_TYPE
:
class_specific_nms
NMS_THRESH
:
[
0.75
,
0.6
,
0.55
]
NMS_PRE_MAXSIZE
:
[
4096
,
4096
,
4096
]
NMS_POST_MAXSIZE
:
[
500
,
500
,
500
]
POST_PROCESSING
:
RECALL_THRESH_LIST
:
[
0.3
,
0.5
,
0.7
]
EVAL_METRIC
:
waymo
OPTIMIZATION
:
BATCH_SIZE_PER_GPU
:
3
NUM_EPOCHS
:
24
OPTIMIZER
:
adam_onecycle
LR
:
0.003
WEIGHT_DECAY
:
0.05
MOMENTUM
:
0.9
MOMS
:
[
0.95
,
0.85
]
PCT_START
:
0.1
DIV_FACTOR
:
100
DECAY_STEP_LIST
:
[
35
,
45
]
LR_DECAY
:
0.1
LR_CLIP
:
0.0000001
LR_WARMUP
:
False
WARMUP_EPOCH
:
1
GRAD_NORM_CLIP
:
10
LOSS_SCALE_FP16
:
32.0
HOOK
:
DisableAugmentationHook
:
DISABLE_AUG_LIST
:
[
'
gt_sampling'
,
'
random_world_flip'
,
'
random_world_rotation'
,
'
random_world_scaling'
,
'
random_world_translation'
]
NUM_LAST_EPOCHS
:
1
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment