Unverified commit 83954d03, authored by yukang and committed by GitHub
Browse files

Add support for VoxelNeXt (#1309)

* VoxelNeXt
parent 31f6758a
......@@ -7,5 +7,5 @@
#include <cuda_runtime_api.h>
int boxes_iou_bev_cpu(at::Tensor boxes_a_tensor, at::Tensor boxes_b_tensor, at::Tensor ans_iou_tensor);
int boxes_aligned_iou_bev_cpu(at::Tensor boxes_a_tensor, at::Tensor boxes_b_tensor, at::Tensor ans_iou_tensor);
#endif
......@@ -39,13 +39,36 @@ inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=t
const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
void boxesalignedoverlapLauncher(const int num_box, const float *boxes_a, const float *boxes_b, float *ans_overlap);
void boxesoverlapLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_overlap);
void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_iou);
void nmsLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh);
void nmsNormalLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh);
int boxes_aligned_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_overlap){
    // Pairwise (aligned) BEV overlap: overlap area of boxes_a[i] with boxes_b[i].
    // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
    // params boxes_b: (N, 7) [x, y, z, dx, dy, dz, heading]
    // params ans_overlap: (N, 1) output overlap per pair
    CHECK_INPUT(boxes_a);
    CHECK_INPUT(boxes_b);
    CHECK_INPUT(ans_overlap);

    int num_box = boxes_a.size(0);
    int num_b = boxes_b.size(0);
    // Aligned mode requires a one-to-one pairing of the two box sets.
    assert(num_box == num_b);

    // data_ptr<T>() replaces the deprecated at::Tensor::data<T>() accessor.
    const float * boxes_a_data = boxes_a.data_ptr<float>();
    const float * boxes_b_data = boxes_b.data_ptr<float>();
    float * ans_overlap_data = ans_overlap.data_ptr<float>();

    boxesalignedoverlapLauncher(num_box, boxes_a_data, boxes_b_data, ans_overlap_data);

    return 1;
}
int boxes_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_overlap){
// params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
......
......@@ -3,9 +3,11 @@
#include <torch/serialize/tensor.h>
#include <vector>
#include <assert.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
int boxes_aligned_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_overlap);
int boxes_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_overlap);
int boxes_iou_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_iou);
int nms_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh);
......
......@@ -9,9 +9,11 @@
// Python bindings: expose the oriented-box overlap / IoU / NMS operators
// (GPU and CPU variants) as a torch extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("boxes_aligned_overlap_bev_gpu", &boxes_aligned_overlap_bev_gpu, "aligned oriented boxes overlap"); // pairwise: boxes_a[i] vs boxes_b[i]
m.def("boxes_overlap_bev_gpu", &boxes_overlap_bev_gpu, "oriented boxes overlap"); // all-pairs N x M
m.def("boxes_iou_bev_gpu", &boxes_iou_bev_gpu, "oriented boxes iou");
m.def("nms_gpu", &nms_gpu, "oriented nms gpu");
m.def("nms_normal_gpu", &nms_normal_gpu, "nms gpu");
m.def("boxes_aligned_iou_bev_cpu", &boxes_aligned_iou_bev_cpu, "aligned oriented boxes iou");
m.def("boxes_iou_bev_cpu", &boxes_iou_bev_cpu, "oriented boxes iou");
}
......@@ -248,6 +248,19 @@ __global__ void boxes_overlap_kernel(const int num_a, const float *boxes_a, cons
ans_overlap[a_idx * num_b + b_idx] = s_overlap;
}
__global__ void boxes_aligned_overlap_kernel(const int num_box, const float *boxes_a, const float *boxes_b, float *ans_overlap){
    // One thread per box pair: writes the BEV overlap of boxes_a[i] with boxes_b[i].
    // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
    // params boxes_b: (N, 7) [x, y, z, dx, dy, dz, heading]
    const int box_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
    if (box_idx < num_box){
        // Each box occupies 7 consecutive floats.
        const float *pair_a = boxes_a + box_idx * 7;
        const float *pair_b = boxes_b + box_idx * 7;
        ans_overlap[box_idx] = box_overlap(pair_a, pair_b);
    }
}
__global__ void boxes_iou_bev_kernel(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_iou){
// params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
......@@ -386,6 +399,17 @@ void boxesoverlapLauncher(const int num_a, const float *boxes_a, const int num_b
#endif
}
// Host launcher: one thread per aligned box pair. Note this is a 1-D launch,
// unlike the 2-D (col/row) grids used by the all-pairs launchers.
void boxesalignedoverlapLauncher(const int num_box, const float *boxes_a, const float *boxes_b, float *ans_overlap){
dim3 blocks(DIVUP(num_box, THREADS_PER_BLOCK)); // 1-D grid over the N box pairs
dim3 threads(THREADS_PER_BLOCK);
boxes_aligned_overlap_kernel<<<blocks, threads>>>(num_box, boxes_a, boxes_b, ans_overlap);
#ifdef DEBUG
cudaDeviceSynchronize(); // for using printf in kernel function
#endif
}
void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_iou){
dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK), DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
......
......@@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from . import box_utils
from pcdet.ops.iou3d_nms import iou3d_nms_utils
class SigmoidFocalClassificationLoss(nn.Module):
......@@ -300,6 +301,37 @@ def neg_loss_cornernet(pred, gt, mask=None):
return loss
def neg_loss_sparse(pred, gt):
    """
    Sparse variant of the CornerNet-style modified focal loss.
    Refer to https://github.com/tianweiy/CenterPoint.
    Args:
        pred: (batch x c x n) predicted heatmap scores, expected in (0, 1)
        gt: (batch x c x n) target heatmap; exactly-1 entries mark positives
    Returns:
        Scalar focal loss tensor.
    """
    positive_mask = gt.eq(1).float()
    negative_mask = gt.lt(1).float()

    # Down-weight negatives that sit close to a positive center (gt near 1).
    negative_weights = torch.pow(1 - gt, 4)

    positive_term = (torch.log(pred) * torch.pow(1 - pred, 2) * positive_mask).sum()
    negative_term = (torch.log(1 - pred) * torch.pow(pred, 2)
                     * negative_weights * negative_mask).sum()

    num_positives = positive_mask.sum()
    if num_positives == 0:
        # No positives in this sample: only the negative term contributes.
        return -negative_term
    return -(positive_term + negative_term) / num_positives
class FocalLossCenterNet(nn.Module):
"""
Refer to https://github.com/tianweiy/CenterPoint
......@@ -385,3 +417,147 @@ class RegLossCenterNet(nn.Module):
pred = _transpose_and_gather_feat(output, ind)
loss = _reg_loss(pred, target, mask)
return loss
class FocalLossSparse(nn.Module):
    """
    Module wrapper around the sparse focal loss.
    Refer to https://github.com/tianweiy/CenterPoint
    """

    def __init__(self):
        super().__init__()
        # Stored as an attribute so the underlying loss fn can be swapped.
        self.neg_loss = neg_loss_sparse

    def forward(self, out, target):
        return self.neg_loss(out, target)
class RegLossSparse(nn.Module):
    """
    Regression loss for sparse (per-voxel, stacked-batch) outputs.
    Refer to https://github.com/tianweiy/CenterPoint
    """

    def __init__(self):
        super().__init__()

    def forward(self, output, mask, ind=None, target=None, batch_index=None):
        """
        Args:
            output: (N x dim) stacked predictions for the whole batch
            mask: (batch x max_objects)
            ind: (batch x max_objects) row indices into each sample's slice of `output`
            target: (batch x max_objects x dim)
            batch_index: per-row sample id for `output` rows
        Returns:
            Masked regression loss.
        """
        batch_size = mask.shape[0]
        # Gather, per sample, the predicted rows at the target indices.
        gathered = [
            output[batch_index == sample_idx][ind[sample_idx]]
            for sample_idx in range(batch_size)
        ]
        pred = torch.stack(gathered)
        return _reg_loss(pred, target, mask)
class IouLossSparse(nn.Module):
    '''IouLoss loss for an output tensor
    Arguments:
        output (batch x dim x h x w)
        mask (batch x max_objects)
        ind (batch x max_objects)
        target (batch x max_objects x dim)
    '''

    def __init__(self):
        super().__init__()

    def forward(self, iou_pred, mask, ind, box_pred, box_gt, batch_index):
        # No valid objects anywhere in the batch -> zero loss.
        if mask.sum() == 0:
            return iou_pred.new_zeros((1))

        mask = mask.bool()
        total = 0
        for sample_idx in range(mask.shape[0]):
            sample_rows = batch_index == sample_idx
            keep = mask[sample_idx]
            pred = iou_pred[sample_rows][ind[sample_idx]][keep]
            pred_box = box_pred[sample_rows][ind[sample_idx]][keep]
            target = iou3d_nms_utils.boxes_aligned_iou3d_gpu(pred_box, box_gt[sample_idx])
            # Rescale the IoU target from [0, 1] to [-1, 1].
            target = 2 * target - 1
            total = total + F.l1_loss(pred, target, reduction='sum')

        return total / (mask.sum() + 1e-4)
class IouRegLossSparse(nn.Module):
    '''Distance-IoU (DIoU) regression loss for output boxes
    Arguments:
        output (batch x dim x h x w)
        mask (batch x max_objects)
        ind (batch x max_objects)
        target (batch x max_objects x dim)
    '''

    def __init__(self, type="DIoU"):
        # `type` is accepted for config compatibility; only DIoU is computed here.
        super().__init__()

    def center_to_corner2d(self, center, dim):
        # Expand each (x, y) center and (dx, dy) size into its 4 BEV corners.
        unit_square = torch.tensor([[-0.5, -0.5], [-0.5, 0.5], [0.5, 0.5], [0.5, -0.5]],
                                   dtype=torch.float32, device=dim.device)
        corners = dim.view([-1, 1, 2]) * unit_square.view([1, 4, 2])
        return corners + center.view(-1, 1, 2)

    def bbox3d_iou_func(self, pred_boxes, gt_boxes):
        # Boxes are [x, y, z, dx, dy, dz, ...]; index 6 (heading) is never read,
        # so the IoU is computed on axis-aligned footprints.
        assert pred_boxes.shape[0] == gt_boxes.shape[0]

        pred_corners = self.center_to_corner2d(pred_boxes[:, :2], pred_boxes[:, 3:5])
        gt_corners = self.center_to_corner2d(gt_boxes[:, :2], gt_boxes[:, 3:5])

        # Corner 0 holds (min x, min y), corner 2 holds (max x, max y).
        inter_max_xy = torch.minimum(pred_corners[:, 2], gt_corners[:, 2])
        inter_min_xy = torch.maximum(pred_corners[:, 0], gt_corners[:, 0])
        out_max_xy = torch.maximum(pred_corners[:, 2], gt_corners[:, 2])
        out_min_xy = torch.minimum(pred_corners[:, 0], gt_corners[:, 0])

        # Box volumes.
        volume_pred = pred_boxes[:, 3] * pred_boxes[:, 4] * pred_boxes[:, 5]
        volume_gt = gt_boxes[:, 3] * gt_boxes[:, 4] * gt_boxes[:, 5]

        # Vertical (z) extent of the intersection.
        inter_h = torch.minimum(pred_boxes[:, 2] + 0.5 * pred_boxes[:, 5], gt_boxes[:, 2] + 0.5 * gt_boxes[:, 5]) - \
                  torch.maximum(pred_boxes[:, 2] - 0.5 * pred_boxes[:, 5], gt_boxes[:, 2] - 0.5 * gt_boxes[:, 5])
        inter_h = torch.clamp(inter_h, min=0)

        inter_wl = torch.clamp(inter_max_xy - inter_min_xy, min=0)
        volume_inter = inter_wl[:, 0] * inter_wl[:, 1] * inter_h
        volume_union = volume_gt + volume_pred - volume_inter

        # Squared distance between the two box centers.
        inter_diag = torch.pow(gt_boxes[:, 0:3] - pred_boxes[:, 0:3], 2).sum(-1)

        # Squared diagonal of the smallest enclosing axis-aligned box.
        outer_h = torch.maximum(gt_boxes[:, 2] + 0.5 * gt_boxes[:, 5], pred_boxes[:, 2] + 0.5 * pred_boxes[:, 5]) - \
                  torch.minimum(gt_boxes[:, 2] - 0.5 * gt_boxes[:, 5], pred_boxes[:, 2] - 0.5 * pred_boxes[:, 5])
        outer_h = torch.clamp(outer_h, min=0)
        outer_wl = torch.clamp(out_max_xy - out_min_xy, min=0)
        outer_diag = outer_wl[:, 0] ** 2 + outer_wl[:, 1] ** 2 + outer_h ** 2

        dious = volume_inter / volume_union - inter_diag / outer_diag
        return torch.clamp(dious, min=-1.0, max=1.0)

    def forward(self, box_pred, mask, ind, box_gt, batch_index):
        # No valid objects anywhere in the batch -> zero loss.
        if mask.sum() == 0:
            return box_pred.new_zeros((1))

        mask = mask.bool()
        total = 0
        for sample_idx in range(mask.shape[0]):
            sample_rows = batch_index == sample_idx
            pred_box = box_pred[sample_rows][ind[sample_idx]]
            iou = self.bbox3d_iou_func(pred_box[mask[sample_idx]], box_gt[sample_idx])
            total = total + (1. - iou).sum()

        return total / (mask.sum() + 1e-4)
\ No newline at end of file
CLASS_NAMES: ['Regular_vehicle', 'Pedestrian', 'Bicyclist', 'Motorcyclist', 'Wheeled_rider',
'Bollard', 'Construction_cone', 'Sign', 'Construction_barrel', 'Stop_sign', 'Mobile_pedestrian_crossing_sign',
'Large_vehicle', 'Bus', 'Box_truck', 'Truck', 'Vehicular_trailer', 'Truck_cab', 'School_bus', 'Articulated_bus',
'Message_board_trailer', 'Bicycle', 'Motorcycle', 'Wheeled_device', 'Wheelchair', 'Stroller', 'Dog']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/argo2_dataset.yaml
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': True
}
- NAME: transform_points_to_voxels
VOXEL_SIZE: [0.1, 0.1, 0.2]
MAX_POINTS_PER_VOXEL: 20
MAX_NUMBER_OF_VOXELS: {
'train': 120000,
'test': 160000
}
MODEL:
NAME: VoxelNeXt
VFE:
NAME: MeanVFE
BACKBONE_3D:
NAME: VoxelResBackBone8xVoxelNeXt
DENSE_HEAD:
NAME: VoxelNeXtHead
CLASS_AGNOSTIC: False
INPUT_FEATURES: 128
CLASS_NAMES_EACH_HEAD: [
['Regular_vehicle',],
['Pedestrian', 'Bicyclist', 'Motorcyclist', 'Wheeled_rider'],
['Bollard', 'Construction_cone', 'Sign', 'Construction_barrel', 'Stop_sign', 'Mobile_pedestrian_crossing_sign'],
['Large_vehicle', 'Bus', 'Box_truck', 'Truck', 'Vehicular_trailer', 'Truck_cab', 'School_bus', 'Articulated_bus', 'Message_board_trailer'],
['Bicycle', 'Motorcycle', 'Wheeled_device', 'Wheelchair', 'Stroller'],
['Dog'],
]
KERNEL_SIZE_HEAD: 1
SHARED_CONV_CHANNEL: 128
USE_BIAS_BEFORE_NORM: True
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: ['center', 'center_z', 'dim', 'rot']
HEAD_DICT: {
'center': {'out_channels': 2, 'num_conv': 2},
'center_z': {'out_channels': 1, 'num_conv': 2},
'dim': {'out_channels': 3, 'num_conv': 2},
'rot': {'out_channels': 2, 'num_conv': 2},
}
TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 8
NUM_MAX_OBJS: 500
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 0.25,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
POST_PROCESSING:
SCORE_THRESH: 0.1
POST_CENTER_LIMIT_RANGE: [-200, -200, -20, 200, 200, 20]
MAX_OBJ_PER_SAMPLE: 500
NMS_CONFIG:
NMS_TYPE: nms_gpu
NMS_THRESH: 0.2
NMS_PRE_MAXSIZE: 1000
NMS_POST_MAXSIZE: 83
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
EVAL_METRIC: kitti
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 4
NUM_EPOCHS: 6
OPTIMIZER: adam_onecycle
LR: 0.003
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 10
CLASS_NAMES: ['Regular_vehicle', 'Pedestrian', 'Bicyclist', 'Motorcyclist', 'Wheeled_rider',
'Bollard', 'Construction_cone', 'Sign', 'Construction_barrel', 'Stop_sign', 'Mobile_pedestrian_crossing_sign',
'Large_vehicle', 'Bus', 'Box_truck', 'Truck', 'Vehicular_trailer', 'Truck_cab', 'School_bus', 'Articulated_bus',
'Message_board_trailer', 'Bicycle', 'Motorcycle', 'Wheeled_device', 'Wheelchair', 'Stroller', 'Dog']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/argo2_dataset.yaml
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': True
}
- NAME: transform_points_to_voxels
VOXEL_SIZE: [0.1, 0.1, 0.2]
MAX_POINTS_PER_VOXEL: 20
MAX_NUMBER_OF_VOXELS: {
'train': 120000,
'test': 160000
}
MODEL:
NAME: VoxelNeXt
VFE:
NAME: MeanVFE
BACKBONE_3D:
NAME: VoxelResBackBone8xVoxelNeXt
DENSE_HEAD:
NAME: VoxelNeXtHead
CLASS_AGNOSTIC: False
INPUT_FEATURES: 128
CLASS_NAMES_EACH_HEAD: [
['Regular_vehicle',],
['Pedestrian', 'Bicyclist', 'Motorcyclist', 'Wheeled_rider'],
['Bollard', 'Construction_cone', 'Sign', 'Construction_barrel', 'Stop_sign', 'Mobile_pedestrian_crossing_sign'],
['Large_vehicle', 'Bus', 'Box_truck', 'Truck', 'Vehicular_trailer', 'Truck_cab', 'School_bus', 'Articulated_bus', 'Message_board_trailer'],
['Bicycle', 'Motorcycle', 'Wheeled_device', 'Wheelchair', 'Stroller'],
['Dog'],
]
KERNEL_SIZE_HEAD: 3
SHARED_CONV_CHANNEL: 128
USE_BIAS_BEFORE_NORM: True
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: ['center', 'center_z', 'dim', 'rot']
HEAD_DICT: {
'center': {'out_channels': 2, 'num_conv': 2},
'center_z': {'out_channels': 1, 'num_conv': 2},
'dim': {'out_channels': 3, 'num_conv': 2},
'rot': {'out_channels': 2, 'num_conv': 2},
}
TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 8
NUM_MAX_OBJS: 500
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 0.25,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
POST_PROCESSING:
SCORE_THRESH: 0.1
POST_CENTER_LIMIT_RANGE: [-200, -200, -20, 200, 200, 20]
MAX_OBJ_PER_SAMPLE: 500
NMS_CONFIG:
NMS_TYPE: nms_gpu
NMS_THRESH: 0.2
NMS_PRE_MAXSIZE: 1000
NMS_POST_MAXSIZE: 83
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
EVAL_METRIC: kitti
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 4
NUM_EPOCHS: 6
OPTIMIZER: adam_onecycle
LR: 0.003
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 10
DATASET: 'Argo2Dataset'
DATA_PATH: '../data/argo2'
POINT_CLOUD_RANGE: [-200, -200, -20, 200, 200, 20]
DATA_SPLIT: {
'train': train,
'test': val
}
INFO_PATH: {
'train': [argo2_infos_train.pkl],
'test': [argo2_infos_val.pkl],
}
GET_ITEM_LIST: ["points"]
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
- NAME: random_world_flip
ALONG_AXIS_LIST: ['x']
- NAME: random_world_rotation
WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
- NAME: random_world_scaling
WORLD_SCALE_RANGE: [0.95, 1.05]
POINT_FEATURE_ENCODING: {
encoding_type: absolute_coordinates_encoding,
used_feature_list: ['x', 'y', 'z', 'intensity'],
src_feature_list: ['x', 'y', 'z', 'intensity'],
}
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': False
}
- NAME: transform_points_to_voxels
VOXEL_SIZE: [0.1, 0.1, 0.2]
MAX_POINTS_PER_VOXEL: 5
MAX_NUMBER_OF_VOXELS: {
'train': 160000,
'test': 400000
}
CLASS_NAMES: ['car','truck', 'construction_vehicle', 'bus', 'trailer',
'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/nuscenes_dataset.yaml
POINT_CLOUD_RANGE: [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
INFO_PATH: {
'train': [nuscenes_infos_10sweeps_train.pkl],
'test': [nuscenes_infos_10sweeps_val.pkl],
}
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
- NAME: gt_sampling
DB_INFO_PATH:
- nuscenes_dbinfos_10sweeps_withvelo.pkl
USE_SHARED_MEMORY: False #True # set it to True to speed up (it costs about 15GB shared memory)
DB_DATA_PATH:
- nuscenes_dbinfos_10sweeps_withvelo_global.pkl.npy
PREPARE: {
filter_by_min_points: [
'car:5','truck:5', 'construction_vehicle:5', 'bus:5', 'trailer:5',
'barrier:5', 'motorcycle:5', 'bicycle:5', 'pedestrian:5', 'traffic_cone:5'
],
}
SAMPLE_GROUPS: [
'car:2','truck:2', 'construction_vehicle:2', 'bus:2', 'trailer:2',
'barrier:2', 'motorcycle:2', 'bicycle:2', 'pedestrian:2', 'traffic_cone:2'
]
NUM_POINT_FEATURES: 5
DATABASE_WITH_FAKELIDAR: False
REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
LIMIT_WHOLE_SCENE: True
- NAME: random_world_flip
ALONG_AXIS_LIST: ['x', 'y']
- NAME: random_world_rotation
WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
- NAME: random_world_scaling
WORLD_SCALE_RANGE: [0.9, 1.1]
- NAME: random_world_translation
NOISE_TRANSLATE_STD: [0.5, 0.5, 0.5]
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': True
}
- NAME: transform_points_to_voxels
VOXEL_SIZE: [0.075, 0.075, 0.2]
MAX_POINTS_PER_VOXEL: 10
MAX_NUMBER_OF_VOXELS: {
'train': 120000,
'test': 160000
}
MODEL:
NAME: VoxelNeXt
VFE:
NAME: MeanVFE
BACKBONE_3D:
NAME: VoxelResBackBone8xVoxelNeXt
DENSE_HEAD:
NAME: VoxelNeXtHead
CLASS_AGNOSTIC: False
INPUT_FEATURES: 128
CLASS_NAMES_EACH_HEAD: [
['car'],
['truck', 'construction_vehicle'],
['bus', 'trailer'],
['barrier'],
['motorcycle', 'bicycle'],
['pedestrian', 'traffic_cone'],
]
SHARED_CONV_CHANNEL: 128
KERNEL_SIZE_HEAD: 1
USE_BIAS_BEFORE_NORM: True
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: ['center', 'center_z', 'dim', 'rot', 'vel']
HEAD_DICT: {
'center': {'out_channels': 2, 'num_conv': 2},
'center_z': {'out_channels': 1, 'num_conv': 2},
'dim': {'out_channels': 3, 'num_conv': 2},
'rot': {'out_channels': 2, 'num_conv': 2},
'vel': {'out_channels': 2, 'num_conv': 2},
}
TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 8
NUM_MAX_OBJS: 500
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 0.25,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 1.0, 1.0]
}
POST_PROCESSING:
SCORE_THRESH: 0.1
POST_CENTER_LIMIT_RANGE: [-61.2, -61.2, -10.0, 61.2, 61.2, 10.0]
MAX_OBJ_PER_SAMPLE: 500
NMS_CONFIG:
NMS_TYPE: nms_gpu
NMS_THRESH: 0.2
NMS_PRE_MAXSIZE: 1000
NMS_POST_MAXSIZE: 83
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
EVAL_METRIC: kitti
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 4
NUM_EPOCHS: 20
OPTIMIZER: adam_onecycle
LR: 0.001
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 10
CLASS_NAMES: ['car','truck', 'construction_vehicle', 'bus', 'trailer',
'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/nuscenes_dataset.yaml
POINT_CLOUD_RANGE: [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
INFO_PATH: {
'train': [nuscenes_infos_10sweeps_train.pkl],
'test': [nuscenes_infos_10sweeps_val.pkl],
}
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
- NAME: gt_sampling
DB_INFO_PATH:
- nuscenes_dbinfos_10sweeps_withvelo.pkl
USE_SHARED_MEMORY: True #True # set it to True to speed up (it costs about 15GB shared memory)
DB_DATA_PATH:
- nuscenes_dbinfos_10sweeps_withvelo_global.pkl.npy
PREPARE: {
filter_by_min_points: [
'car:5','truck:5', 'construction_vehicle:5', 'bus:5', 'trailer:5',
'barrier:5', 'motorcycle:5', 'bicycle:5', 'pedestrian:5', 'traffic_cone:5'
],
}
SAMPLE_GROUPS: [
'car:2','truck:2', 'construction_vehicle:2', 'bus:2', 'trailer:2',
'barrier:2', 'motorcycle:2', 'bicycle:2', 'pedestrian:2', 'traffic_cone:2'
]
NUM_POINT_FEATURES: 5
DATABASE_WITH_FAKELIDAR: False
REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
LIMIT_WHOLE_SCENE: True
- NAME: random_world_flip
ALONG_AXIS_LIST: ['x', 'y']
- NAME: random_world_rotation
WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
- NAME: random_world_scaling
WORLD_SCALE_RANGE: [0.9, 1.1]
- NAME: random_world_translation
NOISE_TRANSLATE_STD: [0.5, 0.5, 0.5]
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': True
}
- NAME: transform_points_to_voxels
VOXEL_SIZE: [0.075, 0.075, 0.2]
MAX_POINTS_PER_VOXEL: 10
MAX_NUMBER_OF_VOXELS: {
'train': 120000,
'test': 160000
}
DOUBLE_FLIP: True
MODEL:
NAME: VoxelNeXt
VFE:
NAME: MeanVFE
BACKBONE_3D:
NAME: VoxelResBackBone8xVoxelNeXt
DENSE_HEAD:
NAME: VoxelNeXtHead
CLASS_AGNOSTIC: False
INPUT_FEATURES: 128
DOUBLE_FLIP: True
CLASS_NAMES_EACH_HEAD: [
['car'],
['truck', 'construction_vehicle'],
['bus', 'trailer'],
['barrier'],
['motorcycle', 'bicycle'],
['pedestrian', 'traffic_cone'],
]
SHARED_CONV_CHANNEL: 128
KERNEL_SIZE_HEAD: 1
USE_BIAS_BEFORE_NORM: True
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: ['center', 'center_z', 'dim', 'rot', 'vel']
HEAD_DICT: {
'center': {'out_channels': 2, 'num_conv': 2},
'center_z': {'out_channels': 1, 'num_conv': 2},
'dim': {'out_channels': 3, 'num_conv': 2},
'rot': {'out_channels': 2, 'num_conv': 2},
'vel': {'out_channels': 2, 'num_conv': 2},
}
TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 8
NUM_MAX_OBJS: 500
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 0.25,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 1.0, 1.0]
}
POST_PROCESSING:
SCORE_THRESH: 0.1
POST_CENTER_LIMIT_RANGE: [-61.2, -61.2, -10.0, 61.2, 61.2, 10.0]
MAX_OBJ_PER_SAMPLE: 500
NMS_CONFIG:
NMS_TYPE: nms_gpu
NMS_THRESH: 0.2
NMS_PRE_MAXSIZE: 1000
NMS_POST_MAXSIZE: 83
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
EVAL_METRIC: kitti
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 1 #4
NUM_EPOCHS: 20
OPTIMIZER: adam_onecycle
LR: 0.001
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 10
CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset.yaml
MODEL:
NAME: VoxelNeXt
VFE:
NAME: DynamicPillarVFESimple2D
WITH_DISTANCE: False
USE_ABSLOTE_XYZ: True
USE_CLUSTER_XYZ: False
USE_NORM: True
NUM_FILTERS: [32]
BACKBONE_3D:
NAME: VoxelResBackBone8xVoxelNeXt2D
DENSE_HEAD:
NAME: VoxelNeXtHead
IOU_BRANCH: True
CLASS_AGNOSTIC: False
INPUT_FEATURES: 256
CLASS_NAMES_EACH_HEAD: [
['Vehicle', 'Pedestrian', 'Cyclist'],
]
SHARED_CONV_CHANNEL: 256
USE_BIAS_BEFORE_NORM: True
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: ['center', 'center_z', 'dim', 'rot']
HEAD_DICT: {
'center': {'out_channels': 2, 'num_conv': 2},
'center_z': {'out_channels': 1, 'num_conv': 2},
'dim': {'out_channels': 3, 'num_conv': 2},
'rot': {'out_channels': 2, 'num_conv': 2},
'iou': {'out_channels': 1, 'num_conv': 2},
}
RECTIFIER: [0.68, 0.71, 0.65]
TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 8
NUM_MAX_OBJS: 500
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 2.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}
POST_PROCESSING:
SCORE_THRESH: 0.1
POST_CENTER_LIMIT_RANGE: [-75.2, -75.2, -2, 75.2, 75.2, 4]
MAX_OBJ_PER_SAMPLE: 500
NMS_CONFIG:
NMS_TYPE: nms_gpu
NMS_THRESH: [0.8, 0.55, 0.55] #0.7
NMS_PRE_MAXSIZE: [2048, 1024, 1024] #[4096]
NMS_POST_MAXSIZE: [200, 150, 150] #500
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
EVAL_METRIC: waymo
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 4
NUM_EPOCHS: 12
OPTIMIZER: adam_onecycle
LR: 0.003
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 10
CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset.yaml
MODEL:
NAME: VoxelNeXt
VFE:
NAME: MeanVFE
BACKBONE_3D:
NAME: VoxelResBackBone8xVoxelNeXt
DENSE_HEAD:
NAME: VoxelNeXtHead
IOU_BRANCH: True
CLASS_AGNOSTIC: False
INPUT_FEATURES: 128
CLASS_NAMES_EACH_HEAD: [
['Vehicle', 'Pedestrian', 'Cyclist']
]
SHARED_CONV_CHANNEL: 128
USE_BIAS_BEFORE_NORM: True
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: ['center', 'center_z', 'dim', 'rot']
HEAD_DICT: {
'center': {'out_channels': 2, 'num_conv': 2},
'center_z': {'out_channels': 1, 'num_conv': 2},
'dim': {'out_channels': 3, 'num_conv': 2},
'rot': {'out_channels': 2, 'num_conv': 2},
'iou': {'out_channels': 1, 'num_conv': 2},
}
RECTIFIER: [0.68, 0.71, 0.65]
TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 8
NUM_MAX_OBJS: 500
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 2.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}
POST_PROCESSING:
SCORE_THRESH: 0.1
POST_CENTER_LIMIT_RANGE: [-75.2, -75.2, -2, 75.2, 75.2, 4]
MAX_OBJ_PER_SAMPLE: 500
NMS_CONFIG:
NMS_TYPE: nms_gpu
NMS_THRESH: [0.8, 0.55, 0.55] #0.7
NMS_PRE_MAXSIZE: [2048, 1024, 1024] #[4096]
NMS_POST_MAXSIZE: [200, 150, 150] #500
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
EVAL_METRIC: waymo
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 4
NUM_EPOCHS: 12
OPTIMIZER: adam_onecycle
LR: 0.003
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 10
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment