Commit 29ad7c89 authored by Shaoshuai Shi
Browse files

support PointRCNNHead with roipoint_pool3d, and pointrcnn.yaml

parent 33c0632a
from .roi_head_template import RoIHeadTemplate from .roi_head_template import RoIHeadTemplate
from .partA2_head import PartA2FCHead from .partA2_head import PartA2FCHead
from .pvrcnn_head import PVRCNNHead from .pvrcnn_head import PVRCNNHead
from .pointrcnn_head import PointRCNNHead
__all__ = { __all__ = {
'RoIHeadTemplate': RoIHeadTemplate, 'RoIHeadTemplate': RoIHeadTemplate,
'PartA2FCHead': PartA2FCHead, 'PartA2FCHead': PartA2FCHead,
'PVRCNNHead': PVRCNNHead 'PVRCNNHead': PVRCNNHead,
'PointRCNNHead': PointRCNNHead
} }
import torch
import torch.nn as nn
from .roi_head_template import RoIHeadTemplate
from ...ops.pointnet2.pointnet2_batch import pointnet2_modules
from ...ops.roipoint_pool3d import roipoint_pool3d_utils
from ...utils import common_utils
class PointRCNNHead(RoIHeadTemplate):
    """RoI head that refines proposals from the points pooled inside each RoI.

    For every RoI a fixed number of points is pooled (``roipool3d_gpu``) and
    shifted/rotated into the RoI's own frame. The pooled points are encoded by
    a stack of PointNet set-abstraction modules, and the resulting feature is
    fed to a classification branch and a box-regression branch.
    """

    def __init__(self, input_channels, model_cfg, num_class=1):
        """Build the feature-merging layers, the SA encoder and both heads.

        Args:
            input_channels: input width of the first set-abstraction module.
            model_cfg: config node; this constructor reads ``USE_BN``,
                ``XYZ_UP_LAYER``, ``SA_CONFIG``, ``CLS_FC``, ``REG_FC`` and
                ``ROI_POINT_POOL`` from it.
            num_class: number of classes, forwarded to ``RoIHeadTemplate``.
        """
        super().__init__(num_class=num_class, model_cfg=model_cfg)
        self.model_cfg = model_cfg
        use_bn = self.model_cfg.USE_BN
        self.SA_modules = nn.ModuleList()

        # Leading channels of every pooled point that are lifted by xyz_up_layer.
        self.num_prefix_channels = 3 + 2  # xyz + point_scores + point_depth
        up_dims = [self.num_prefix_channels] + self.model_cfg.XYZ_UP_LAYER
        up_layers = []
        for dim_in, dim_out in zip(up_dims[:-1], up_dims[1:]):
            up_layers.append(nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=not use_bn))
            if use_bn:
                up_layers.append(nn.BatchNorm2d(dim_out))
            up_layers.append(nn.ReLU())
        self.xyz_up_layer = nn.Sequential(*up_layers)

        # The merge layer consumes the lifted prefix features concatenated with
        # the remaining point features (see forward), hence the doubled width.
        merged_dim = self.model_cfg.XYZ_UP_LAYER[-1]
        merge_layers = [nn.Conv2d(merged_dim * 2, merged_dim, kernel_size=1, bias=not use_bn)]
        if use_bn:
            merge_layers.append(nn.BatchNorm2d(merged_dim))
        merge_layers.append(nn.ReLU())
        self.merge_down_layer = nn.Sequential(*merge_layers)

        sa_cfg = self.model_cfg.SA_CONFIG
        channel_in = input_channels
        for k, npoint_cfg in enumerate(sa_cfg.NPOINTS):
            mlps = [channel_in] + sa_cfg.MLPS[k]
            sa_module = pointnet2_modules.PointnetSAModule(
                # -1 in the config is translated to None for the SA module.
                npoint=None if npoint_cfg == -1 else npoint_cfg,
                radius=sa_cfg.RADIUS[k],
                nsample=sa_cfg.NSAMPLE[k],
                mlp=mlps,
                use_xyz=True,
                bn=use_bn,
            )
            self.SA_modules.append(sa_module)
            channel_in = mlps[-1]

        self.cls_layers = self.make_fc_layers(
            input_channels=channel_in,
            output_channels=self.num_class,
            fc_list=self.model_cfg.CLS_FC,
        )
        self.reg_layers = self.make_fc_layers(
            input_channels=channel_in,
            output_channels=self.box_coder.code_size * self.num_class,
            fc_list=self.model_cfg.REG_FC,
        )

        pool_cfg = self.model_cfg.ROI_POINT_POOL
        self.roipoint_pool3d_layer = roipoint_pool3d_utils.RoIPointPool3d(
            num_sampled_points=pool_cfg.NUM_SAMPLED_POINTS,
            pool_extra_width=pool_cfg.POOL_EXTRA_WIDTH,
        )
        self.init_weights(weight_init='xavier')

    def init_weights(self, weight_init='xavier'):
        """Re-initialise all Conv1d/Conv2d weights and zero their biases.

        Args:
            weight_init: one of ``'kaiming'``, ``'xavier'`` or ``'normal'``.

        Raises:
            NotImplementedError: if ``weight_init`` is none of the above.
        """
        initializers = {
            'kaiming': nn.init.kaiming_normal_,
            'xavier': nn.init.xavier_normal_,
            'normal': nn.init.normal_,
        }
        if weight_init not in initializers:
            raise NotImplementedError
        init_func = initializers[weight_init]
        # Only the plain normal initialiser is given explicit statistics.
        init_kwargs = {'mean': 0, 'std': 0.001} if weight_init == 'normal' else {}

        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Conv1d)):
                continue
            init_func(module.weight, **init_kwargs)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        # The last regression layer always gets a small-std normal init,
        # whatever scheme was selected above.
        nn.init.normal_(self.reg_layers[-1].weight, mean=0, std=0.001)

    def roipool3d_gpu(self, batch_dict):
        """Pool a fixed number of points per RoI and move them into the RoI frame.

        Args:
            batch_dict: dict from which the following keys are read:
                batch_size: number of samples in the batch
                rois: (B, num_rois, 7 + C)
                point_coords: (num_points, 4) [bs_idx, x, y, z]
                point_features: (num_points, C)
                point_cls_scores: per-point scores (detached before use)

        Returns:
            pooled_features: (B * num_rois, num_sampled_points, 3 + C'), where the
                pooled channels are xyz, score, depth and then ``point_features``.
                xyz is expressed relative to the RoI centre and rotated by the
                negative RoI heading; rows of RoIs flagged as empty are all zero.
        """
        batch_size = batch_dict['batch_size']
        coords_with_idx = batch_dict['point_coords']
        batch_idx = coords_with_idx[:, 0]
        point_coords = coords_with_idx[:, 1:4]
        point_features = batch_dict['point_features']
        rois = batch_dict['rois']  # (B, num_rois, 7 + C)

        # Every sample must hold the same number of points, because the flat
        # point list is viewed as a dense (B, N, ...) batch below.
        batch_cnt = point_coords.new_zeros(batch_size).int()
        for bs_idx in range(batch_size):
            batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()
        assert batch_cnt.min() == batch_cnt.max()

        point_scores = batch_dict['point_cls_scores'].detach()
        depth_normalizer = self.model_cfg.ROI_POINT_POOL.DEPTH_NORMALIZER
        point_depths = point_coords.norm(dim=1) / depth_normalizer - 0.5
        point_features_all = torch.cat(
            [point_scores.unsqueeze(dim=1), point_depths.unsqueeze(dim=1), point_features],
            dim=1,
        )

        batch_points = point_coords.view(batch_size, -1, 3)
        batch_point_features = point_features_all.view(batch_size, -1, point_features_all.shape[-1])

        with torch.no_grad():
            # pooled_features: (B, num_rois, num_sampled_points, 3 + C)
            # pooled_empty_flag: (B, num_rois)
            pooled_features, pooled_empty_flag = self.roipoint_pool3d_layer(
                batch_points, batch_point_features, rois
            )

            # Canonical transformation, step 1: translate to the RoI centre.
            pooled_features[..., 0:3] -= rois[:, :, 0:3].unsqueeze(dim=2)

            # Step 2: flatten the RoIs and rotate by the negative RoI heading.
            num_sampled, num_channels = pooled_features.shape[-2], pooled_features.shape[-1]
            pooled_features = pooled_features.view(-1, num_sampled, num_channels)
            neg_heading = -rois.view(-1, rois.shape[-1])[:, 6]
            pooled_features[:, :, 0:3] = common_utils.rotate_points_along_z(
                pooled_features[:, :, 0:3], neg_heading
            )

            # RoIs without any pooled point are reset to all zeros.
            pooled_features[pooled_empty_flag.view(-1) > 0] = 0
        return pooled_features

    def forward(self, batch_dict):
        """Generate proposals, pool their points and predict class / box outputs.

        Args:
            batch_dict: network data dict; ``rois`` and ``roi_labels`` are
                overwritten with the assigned targets in training mode.

        Returns:
            batch_dict: the same dict. In eval mode ``batch_cls_preds``,
                ``batch_box_preds`` and ``cls_preds_normalized`` are added; in
                training mode the raw outputs are stored in ``self.forward_ret_dict``.
        """
        nms_mode = 'TRAIN' if self.training else 'TEST'
        targets_dict = self.proposal_layer(batch_dict, nms_config=self.model_cfg.NMS_CONFIG[nms_mode])
        if self.training:
            targets_dict = self.assign_targets(batch_dict)
            batch_dict['rois'] = targets_dict['rois']
            batch_dict['roi_labels'] = targets_dict['roi_labels']

        pooled_features = self.roipool3d_gpu(batch_dict)  # (total_rois, num_sampled_points, 3 + C)

        # Lift the prefix channels, then fuse them with the remaining features.
        prefix = self.num_prefix_channels
        xyz_input = pooled_features[..., 0:prefix].transpose(1, 2).unsqueeze(dim=3)
        xyz_features = self.xyz_up_layer(xyz_input)
        point_features = pooled_features[..., prefix:].transpose(1, 2).unsqueeze(dim=3)
        merged_features = self.merge_down_layer(torch.cat((xyz_features, point_features), dim=1))

        # Run the set-abstraction stack, feeding each module the previous output.
        cur_xyz = pooled_features[..., 0:3].contiguous()
        cur_features = merged_features.squeeze(dim=3).contiguous()
        for sa_module in self.SA_modules:
            cur_xyz, cur_features = sa_module(cur_xyz, cur_features)

        shared_features = cur_features  # (total_rois, num_features, 1)
        rcnn_cls = self.cls_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, 1 or 2)
        rcnn_reg = self.reg_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, C)

        if self.training:
            targets_dict['rcnn_cls'] = rcnn_cls
            targets_dict['rcnn_reg'] = rcnn_reg
            self.forward_ret_dict = targets_dict
        else:
            batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
                batch_size=batch_dict['batch_size'],
                rois=batch_dict['rois'],
                cls_preds=rcnn_cls,
                box_preds=rcnn_reg,
            )
            batch_dict['batch_cls_preds'] = batch_cls_preds
            batch_dict['batch_box_preds'] = batch_box_preds
            batch_dict['cls_preds_normalized'] = False
        return batch_dict
import torch
import torch.nn as nn
from torch.autograd import Function
from . import roipoint_pool3d_cuda
from ...utils import box_utils
class RoIPointPool3d(nn.Module):
    """``nn.Module`` wrapper that stores the pooling settings for ``RoIPointPool3dFunction``."""

    def __init__(self, num_sampled_points=512, pool_extra_width=1.0):
        """Remember the pooling configuration.

        Args:
            num_sampled_points: number of points returned for every box.
            pool_extra_width: value forwarded to the pooling function, which
                uses it to enlarge the boxes before pooling.
        """
        super().__init__()
        self.pool_extra_width = pool_extra_width
        self.num_sampled_points = num_sampled_points

    def forward(self, points, point_features, boxes3d):
        """Pool points and their features for each box.

        Args:
            points: (B, N, 3)
            point_features: (B, N, C)
            boxes3d: (B, M, 7), [x, y, z, dx, dy, dz, heading]

        Returns:
            pooled_features: (B, M, num_sampled_points, 3 + C)
            pooled_empty_flag: (B, M)
        """
        # Function.apply only takes positional arguments, in forward()'s order.
        pool_args = (points, point_features, boxes3d, self.pool_extra_width, self.num_sampled_points)
        return RoIPointPool3dFunction.apply(*pool_args)
class RoIPointPool3dFunction(Function):
    """Forward-only autograd function around the ``roipoint_pool3d_cuda`` extension."""

    @staticmethod
    def forward(ctx, points, point_features, boxes3d, pool_extra_width, num_sampled_points=512):
        """Pool ``num_sampled_points`` points (xyz + features) for every box.

        Args:
            ctx: autograd context (unused).
            points: (B, N, 3)
            point_features: (B, N, C)
            boxes3d: (B, num_boxes, 7), [x, y, z, dx, dy, dz, heading]
            pool_extra_width: passed to ``box_utils.enlarge_box3d`` to grow the
                boxes before points are assigned to them.
            num_sampled_points: number of points stored per box.

        Returns:
            pooled_features: (B, num_boxes, num_sampled_points, 3 + C)
            pooled_empty_flag: (B, num_boxes), int tensor
        """
        assert len(points.shape) == 3 and points.shape[2] == 3
        batch_size, boxes_num, feature_len = points.shape[0], boxes3d.shape[1], point_features.shape[2]
        pooled_boxes3d = box_utils.enlarge_box3d(boxes3d.view(-1, 7), pool_extra_width).view(batch_size, -1, 7)

        # Outputs are allocated zero-filled here and written by the extension.
        pooled_features = point_features.new_zeros((batch_size, boxes_num, num_sampled_points, 3 + feature_len))
        pooled_empty_flag = point_features.new_zeros((batch_size, boxes_num)).int()

        roipoint_pool3d_cuda.forward(
            points.contiguous(), pooled_boxes3d.contiguous(),
            point_features.contiguous(), pooled_features, pooled_empty_flag
        )

        return pooled_features, pooled_empty_flag

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Reject differentiation: this op has no backward pass.

        ``forward`` returns two tensors, so autograd passes two gradients here.
        Accepting ``*grad_outputs`` lets the intended ``NotImplementedError``
        surface instead of a ``TypeError`` about the argument count.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('RoIPointPool3dFunction does not implement a backward pass')
if __name__ == '__main__':
pass
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
// Argument checks for the Python-facing entry point, built on AT_CHECK.
// CHECK_CUDA tests that the tensor lives on a CUDA device.
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
// CHECK_CONTIGUOUS tests that the tensor is contiguous.
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
// CHECK_INPUT runs both checks. It is wrapped in do { ... } while (0) so the
// macro expands to a single statement and stays correct inside an unbraced
// if/else; the trailing semicolon is supplied by the caller.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
void roipool3dLauncher(int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num,
const float *xyz, const float *boxes3d, const float *pts_feature, float *pooled_features, int *pooled_empty_flag);
int roipool3d_gpu(at::Tensor xyz, at::Tensor boxes3d, at::Tensor pts_feature, at::Tensor pooled_features, at::Tensor pooled_empty_flag){
    // Host entry point: validates the tensors, reads the problem sizes from
    // their shapes and hands the raw pointers to roipool3dLauncher.
    //
    // params xyz: (B, N, 3)
    // params boxes3d: (B, M, 7)
    // params pts_feature: (B, N, C)
    // params pooled_features: (B, M, 512, 3+C)  (output)
    // params pooled_empty_flag: (B, M)          (output)
    // Always returns 1.
    CHECK_INPUT(xyz);
    CHECK_INPUT(boxes3d);
    CHECK_INPUT(pts_feature);
    CHECK_INPUT(pooled_features);
    CHECK_INPUT(pooled_empty_flag);

    // Sizes are taken from the tensors themselves, not passed separately.
    const int batch_size = xyz.size(0);
    const int pts_num = xyz.size(1);
    const int boxes_num = boxes3d.size(1);
    const int feature_in_len = pts_feature.size(2);
    const int sampled_pts_num = pooled_features.size(2);

    roipool3dLauncher(
        batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
        xyz.data<float>(),
        boxes3d.data<float>(),
        pts_feature.data<float>(),
        pooled_features.data<float>(),
        pooled_empty_flag.data<int>());

    return 1;
}
// Python bindings for this extension: roipool3d_gpu is exported to Python
// under the name "forward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &roipool3d_gpu, "roipool3d forward (CUDA)");
}
/*
Point cloud feature pooling
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <math.h>
#include <stdio.h>
// Block size (threads along x) used for the kernel launches in roipool3dLauncher.
#define THREADS_PER_BLOCK 256
// Ceiling of m / n: the quotient, plus one when there is a positive remainder.
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
// #define DEBUG
__device__ inline void lidar_to_local_coords(float shift_x, float shift_y, float rot_angle, float &local_x, float &local_y){
    // Rotates the planar offset (shift_x, shift_y) by -rot_angle and writes
    // the rotated offset to local_x / local_y.
    const float neg_angle = -rot_angle;
    const float cos_val = cos(neg_angle);
    const float sin_val = sin(neg_angle);

    local_x = shift_x * cos_val + shift_y * (-sin_val);
    local_y = shift_x * sin_val + shift_y * cos_val;
}
__device__ inline int check_pt_in_box3d(const float *pt, const float *box3d, float &local_x, float &local_y){
    // Returns 1 when the point lies inside the box, 0 otherwise.
    // param pt: (x, y, z)
    // param box3d: [x, y, z, dx, dy, dz, heading] (x, y, z) is the box center
    // local_x / local_y receive the point's offset in the box frame; they are
    // only written when the point passes the z test.
    const float MARGIN = 1e-5;
    const float cx = box3d[0], cy = box3d[1], cz = box3d[2];
    const float dx = box3d[3], dy = box3d[4], dz = box3d[5];
    const float heading = box3d[6];

    // Height test first (no margin is applied along z).
    if (fabsf(pt[2] - cz) > dz / 2.0) return 0;

    // Planar test in the box frame, with MARGIN added to the half extents.
    lidar_to_local_coords(pt[0] - cx, pt[1] - cy, heading, local_x, local_y);
    const bool inside_x = fabs(local_x) < dx / 2.0 + MARGIN;
    const bool inside_y = fabs(local_y) < dy / 2.0 + MARGIN;
    return (inside_x && inside_y) ? 1 : 0;
}
__global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num, const float *xyz, const float *boxes3d, int *pts_assign){
    // One thread per (sample, box, point): records whether the point lies
    // inside the box.
    // params xyz: (B, N, 3)
    // params boxes3d: (B, M, 7)
    // params pts_assign: (B, N, M), output flag: 1 if point n is inside box m, else 0
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int box_idx = blockIdx.y;
    int bs_idx = blockIdx.z;

    if (pt_idx >= pts_num || box_idx >= boxes_num || bs_idx >= batch_size){
        return;
    }

    int assign_idx = bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx;
    int box_offset = bs_idx * boxes_num * 7 + box_idx * 7;
    int pt_offset = bs_idx * pts_num * 3 + pt_idx * 3;

    // check_pt_in_box3d also reports the box-frame offset; it is not needed here.
    float local_x = 0, local_y = 0;

    // Single store of the final flag (the earlier zero-initialising store was
    // always overwritten and has been dropped).
    pts_assign[assign_idx] = check_pt_in_box3d(xyz + pt_offset, boxes3d + box_offset, local_x, local_y);
}
__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num, int sampled_pts_num,
    const int *pts_assign, int *pts_idx, int *pooled_empty_flag){
    // One thread per (sample, box): collects the indices of the first
    // sampled_pts_num points flagged for this box, in increasing point order.
    // Boxes with fewer points are padded by repeating the collected indices
    // cyclically; boxes with no point only get their empty flag set to 1.
    // params pts_assign: (B, N, M)
    // params pts_idx: (B, M, 512)
    // params pooled_empty_flag: (B, M)
    const int boxes_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (boxes_idx >= boxes_num){
        return;
    }
    const int bs_idx = blockIdx.y;

    // Offsets of this box's column in pts_assign and of its row in pts_idx.
    const int assign_base = bs_idx * pts_num * boxes_num + boxes_idx;
    const int base_offset = bs_idx * boxes_num * sampled_pts_num + boxes_idx * sampled_pts_num;

    int cnt = 0;
    for (int k = 0; k < pts_num && cnt < sampled_pts_num; k++){
        if (pts_assign[assign_base + k * boxes_num]){
            pts_idx[base_offset + cnt] = k;
            cnt++;
        }
    }

    if (cnt == 0){
        // The flag is only ever set here; the buffer is expected to arrive zeroed.
        pooled_empty_flag[bs_idx * boxes_num + boxes_idx] = 1;
        return;
    }

    // duplicate same points for sampling
    for (int k = cnt; k < sampled_pts_num; k++){
        pts_idx[base_offset + k] = pts_idx[base_offset + k % cnt];
    }
}
__global__ void roipool3d_forward(int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num,
    const float *xyz, const int *pts_idx, const float *pts_feature,
    float *pooled_features, int *pooled_empty_flag){
    // One thread per (sample, box, sampled slot): copies the xyz and features
    // of the selected source point into the pooled output. Boxes flagged as
    // empty are skipped, so their output rows are left untouched.
    // params xyz: (B, N, 3)
    // params pts_idx: (B, M, 512)
    // params pts_feature: (B, N, C)
    // params pooled_features: (B, M, 512, 3+C)
    // params pooled_empty_flag: (B, M)
    const int sample_pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int box_idx = blockIdx.y;
    const int bs_idx = blockIdx.z;

    if (sample_pt_idx >= sampled_pts_num || box_idx >= boxes_num || bs_idx >= batch_size){
        return;
    }
    if (pooled_empty_flag[bs_idx * boxes_num + box_idx]){
        return;
    }

    const int temp_idx = bs_idx * boxes_num * sampled_pts_num + box_idx * sampled_pts_num + sample_pt_idx;
    const int src_pt_idx = pts_idx[temp_idx];

    float *dst = pooled_features + temp_idx * (3 + feature_in_len);
    const float *src_xyz = xyz + (bs_idx * pts_num * 3 + src_pt_idx * 3);
    const float *src_feature = pts_feature + (bs_idx * pts_num * feature_in_len + src_pt_idx * feature_in_len);

    // Output layout per pooled point: [x, y, z, feature_0, ..., feature_{C-1}].
    for (int j = 0; j < 3; j++){
        dst[j] = src_xyz[j];
    }
    for (int j = 0; j < feature_in_len; j++){
        dst[3 + j] = src_feature[j];
    }
}
void roipool3dLauncher(int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num,
    const float *xyz, const float *boxes3d, const float *pts_feature, float *pooled_features, int *pooled_empty_flag){
    // Host-side driver for RoI point pooling. Runs three kernels in order:
    //   1. assign_pts_to_box3d: flag which points fall inside which box
    //   2. get_pooled_idx:      pick sampled_pts_num point indices per box
    //   3. roipool3d_forward:   gather xyz + features of the picked points
    // All pointers are device pointers; pooled_features and pooled_empty_flag
    // are outputs. pooled_empty_flag is only ever set to 1 by the kernels, so
    // the caller must pass it zero-initialised.

    // Without samples or boxes there is no output element to write, and a
    // launch with a zero-sized grid dimension would be an invalid configuration.
    if (batch_size <= 0 || boxes_num <= 0){
        return;
    }

    // Byte counts are computed in size_t so the product cannot overflow int.
    // NOTE(review): the kernels still index these buffers with int, so element
    // counts above INT_MAX remain unsupported.
    const size_t assign_bytes = (size_t)batch_size * pts_num * boxes_num * sizeof(int);      // (batch_size, N, M)
    const size_t idx_bytes = (size_t)batch_size * boxes_num * sampled_pts_num * sizeof(int); // (batch_size, M, sampled_pts_num)

    int *pts_assign = NULL;
    int *pts_idx = NULL;
    cudaError_t err = cudaMalloc(&pts_assign, assign_bytes);
    if (err == cudaSuccess){
        err = cudaMalloc(&pts_idx, idx_bytes);
    }
    if (err != cudaSuccess){
        // Do not launch kernels on unallocated scratch buffers.
        fprintf(stderr, "roipool3dLauncher: cudaMalloc failed: %s\n", cudaGetErrorString(err));
        cudaFree(pts_assign);
        cudaFree(pts_idx);
        return;
    }

    dim3 threads(THREADS_PER_BLOCK);

    if (pts_num > 0){
        dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);  // blockIdx.x(col), blockIdx.y(row)
        assign_pts_to_box3d<<<blocks, threads>>>(batch_size, pts_num, boxes_num, xyz, boxes3d, pts_assign);
    }

    // Always runs: with no points it still marks every box as empty.
    dim3 blocks2(DIVUP(boxes_num, THREADS_PER_BLOCK), batch_size);  // blockIdx.x(col), blockIdx.y(row)
    get_pooled_idx<<<blocks2, threads>>>(batch_size, pts_num, boxes_num, sampled_pts_num, pts_assign, pts_idx, pooled_empty_flag);

    if (sampled_pts_num > 0){
        dim3 blocks_pool(DIVUP(sampled_pts_num, THREADS_PER_BLOCK), boxes_num, batch_size);
        roipool3d_forward<<<blocks_pool, threads>>>(batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
                                                   xyz, pts_idx, pts_feature, pooled_features, pooled_empty_flag);
    }

    // Report launch failures without clearing the error state, so the calling
    // framework can still observe it.
    err = cudaPeekAtLastError();
    if (err != cudaSuccess){
        fprintf(stderr, "roipool3dLauncher: kernel launch failed: %s\n", cudaGetErrorString(err));
    }

    cudaFree(pts_assign);
    cudaFree(pts_idx);

#ifdef DEBUG
    cudaDeviceSynchronize();  // for using printf in kernel function
#endif
}
\ No newline at end of file
...@@ -67,6 +67,14 @@ if __name__ == '__main__': ...@@ -67,6 +67,14 @@ if __name__ == '__main__':
'src/roiaware_pool3d_kernel.cu', 'src/roiaware_pool3d_kernel.cu',
] ]
), ),
make_cuda_ext(
name='roipoint_pool3d_cuda',
module='pcdet.ops.roipoint_pool3d',
sources=[
'src/roipoint_pool3d.cpp',
'src/roipoint_pool3d_kernel.cu',
]
),
make_cuda_ext( make_cuda_ext(
name='pointnet2_stack_cuda', name='pointnet2_stack_cuda',
module='pcdet.ops.pointnet2.pointnet2_stack', module='pcdet.ops.pointnet2.pointnet2_stack',
......
CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/kitti_dataset.yaml
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: sample_points
NUM_POINTS: {
'train': 16384,
'test': 16384
}
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': False
}
MODEL:
NAME: PointRCNN
BACKBONE_3D:
NAME: PointNet2MSG
SA_CONFIG:
NPOINTS: [4096, 1024, 256, 64]
RADIUS: [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]]
NSAMPLE: [[16, 32], [16, 32], [16, 32], [16, 32]]
MLPS: [[[16, 16, 32], [32, 32, 64]],
[[64, 64, 128], [64, 96, 128]],
[[128, 196, 256], [128, 196, 256]],
[[256, 256, 512], [256, 384, 512]]]
FP_MLPS: [[128, 128], [256, 256], [512, 512], [512, 512]]
POINT_HEAD:
NAME: PointHeadBox
CLS_FC: [256, 256]
REG_FC: [256, 256]
CLASS_AGNOSTIC: False
USE_POINT_FEATURES_BEFORE_FUSION: False
TARGET_CONFIG:
GT_EXTRA_WIDTH: [0.2, 0.2, 0.2]
BOX_CODER: PointResidualCoder
BOX_CODER_CONFIG: {
'use_mean_size': True,
'mean_size': [
[3.9, 1.6, 1.56],
[0.8, 0.6, 1.73],
[1.76, 0.6, 1.73]
]
}
LOSS_CONFIG:
LOSS_REG: WeightedSmoothL1Loss
LOSS_WEIGHTS: {
'point_cls_weight': 1.0,
'point_box_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}
ROI_HEAD:
NAME: PointRCNNHead
CLASS_AGNOSTIC: True
ROI_POINT_POOL:
POOL_EXTRA_WIDTH: [0.0, 0.0, 0.0]
NUM_SAMPLED_POINTS: 512
DEPTH_NORMALIZER: 70.0
XYZ_UP_LAYER: [128, 128]
CLS_FC: [256, 256]
REG_FC: [256, 256]
DP_RATIO: 0.0
USE_BN: False
SA_CONFIG:
NPOINTS: [128, 32, -1]
RADIUS: [0.2, 0.4, 100]
NSAMPLE: [16, 16, 16]
MLPS: [[128, 128, 128],
[128, 128, 256],
[256, 256, 512]]
NMS_CONFIG:
TRAIN:
NMS_TYPE: nms_gpu
MULTI_CLASSES_NMS: False
NMS_PRE_MAXSIZE: 9000
NMS_POST_MAXSIZE: 512
NMS_THRESH: 0.8
TEST:
NMS_TYPE: nms_gpu
MULTI_CLASSES_NMS: False
NMS_PRE_MAXSIZE: 9000
NMS_POST_MAXSIZE: 100
NMS_THRESH: 0.85
TARGET_CONFIG:
BOX_CODER: ResidualCoder
ROI_PER_IMAGE: 128
FG_RATIO: 0.5
SAMPLE_ROI_BY_EACH_CLASS: True
CLS_SCORE_TYPE: cls
CLS_FG_THRESH: 0.6
CLS_BG_THRESH: 0.45
CLS_BG_THRESH_LO: 0.1
HARD_BG_RATIO: 0.8
REG_FG_THRESH: 0.55
LOSS_CONFIG:
CLS_LOSS: BinaryCrossEntropy
REG_LOSS: smooth-l1
CORNER_LOSS_REGULARIZATION: True
LOSS_WEIGHTS: {
'rcnn_cls_weight': 1.0,
'rcnn_reg_weight': 1.0,
'rcnn_corner_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
SCORE_THRESH: 0.1
OUTPUT_RAW_SCORE: False
EVAL_METRIC: kitti
NMS_CONFIG:
MULTI_CLASSES_NMS: False
NMS_TYPE: nms_gpu
NMS_THRESH: 0.1
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 500
OPTIMIZATION:
OPTIMIZER: adam_onecycle
LR: 0.01
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 10
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment