Unverified commit aaf9cbeb authored by Cody Reading, committed by GitHub

Support Monocular 3D Detector CaDDN (#538)

* Added CaDDN detector and support for image, depth map, and 2D GT box
dataloading

* Moved image flip augmentation to augmentor_utils

* Updated default get item list to include points

* Moved utils functions into transform_utils

* Combined FFE + F2V into ImageVFE, renamed FFE to FFN, moved depth downsample into data_processor

* Updated README with updated CaDDN weights

* Updated comments for image vfe
parent e3bec15f
import torchvision
from .ddn_template import DDNTemplate
class DDNDeepLabV3(DDNTemplate):
def __init__(self, backbone_name, **kwargs):
"""
Initializes DDNDeepLabV3 model
Args:
backbone_name: string, ResNet Backbone Name [ResNet50/ResNet101]
"""
if backbone_name == "ResNet50":
constructor = torchvision.models.segmentation.deeplabv3_resnet50
elif backbone_name == "ResNet101":
constructor = torchvision.models.segmentation.deeplabv3_resnet101
else:
raise NotImplementedError
super().__init__(constructor=constructor, **kwargs)
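Usage sketch (illustrative, not part of the commit): the wrapper only selects the torchvision constructor; all remaining arguments are forwarded to DDNTemplate below. num_classes=81 assumes the 80 depth bins + 1 out-of-range bin used by the CaDDN config further down.
model = DDNDeepLabV3(backbone_name="ResNet101", feat_extract_layer="layer1", num_classes=81, pretrained_path=None)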
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import kornia
class DDNTemplate(nn.Module):
def __init__(self, constructor, feat_extract_layer, num_classes, pretrained_path=None, aux_loss=None):
"""
Initializes depth distribution network.
Args:
constructor: function, Model constructor
feat_extract_layer: string, Layer to extract features from
num_classes: int, Number of classes
pretrained_path: string, (Optional) Path of the model to load weights from
aux_loss: bool, Flag to include auxiliary loss
"""
super().__init__()
self.num_classes = num_classes
self.pretrained_path = pretrained_path
self.pretrained = pretrained_path is not None
self.aux_loss = aux_loss
if self.pretrained:
# Preprocess Module
self.norm_mean = torch.Tensor([0.485, 0.456, 0.406])
self.norm_std = torch.Tensor([0.229, 0.224, 0.225])
# Model
self.model = self.get_model(constructor=constructor)
self.feat_extract_layer = feat_extract_layer
self.model.backbone.return_layers = {
feat_extract_layer: 'features',
**self.model.backbone.return_layers
}
def get_model(self, constructor):
"""
Get model
Args:
constructor: function, Model constructor
Returns:
model: nn.Module, Model
"""
# Get model
model = constructor(pretrained=False,
pretrained_backbone=False,
num_classes=self.num_classes,
aux_loss=self.aux_loss)
# Update weights
if self.pretrained_path is not None:
model_dict = model.state_dict()
# Get pretrained state dict
pretrained_dict = torch.load(self.pretrained_path)
pretrained_dict = self.filter_pretrained_dict(model_dict=model_dict,
pretrained_dict=pretrained_dict)
# Update current model state dict
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
return model
def filter_pretrained_dict(self, model_dict, pretrained_dict):
"""
Removes layers from pretrained state dict that are not used or changed in model
Args:
model_dict: dict, Default model state dictionary
pretrained_dict: dict, Pretrained model state dictionary
Returns:
pretrained_dict: dict, Pretrained model state dictionary with removed weights
"""
# Removes aux classifier weights if not used
if "aux_classifier.0.weight" in pretrained_dict and "aux_classifier.0.weight" not in model_dict:
pretrained_dict = {key: value for key, value in pretrained_dict.items()
if "aux_classifier" not in key}
# Removes final conv layer from weights if the number of classes is different
model_num_classes = model_dict["classifier.4.weight"].shape[0]
pretrained_num_classes = pretrained_dict["classifier.4.weight"].shape[0]
if model_num_classes != pretrained_num_classes:
pretrained_dict.pop("classifier.4.weight")
pretrained_dict.pop("classifier.4.bias")
return pretrained_dict
def forward(self, images):
"""
Forward pass
Args:
images: (N, 3, H_in, W_in), Input images
Returns:
result: dict[torch.Tensor], Depth distribution result
features: (N, C, H_out, W_out), Image features
logits: (N, num_classes, H_out, W_out), Classification logits
aux: (N, num_classes, H_out, W_out), Auxiliary classification logits
"""
# Preprocess images
x = self.preprocess(images)
# Extract features
result = OrderedDict()
features = self.model.backbone(x)
result['features'] = features['features']
feat_shape = features['features'].shape[-2:]
# Predict classification logits
x = features["out"]
x = self.model.classifier(x)
x = F.interpolate(x, size=feat_shape, mode='bilinear', align_corners=False)
result["logits"] = x
# Predict auxiliary classification logits
if self.model.aux_classifier is not None:
x = features["aux"]
x = self.model.aux_classifier(x)
x = F.interpolate(x, size=feat_shape, mode='bilinear', align_corners=False)
result["aux"] = x
return result
def preprocess(self, images):
"""
Preprocess images
Args:
images: (N, 3, H, W), Input images
Returns:
x: (N, 3, H, W), Preprocessed images
"""
x = images
if self.pretrained:
# Create a mask for padded pixels
mask = torch.isnan(x)
# Match ResNet pretrained preprocessing
x = kornia.normalize(x, mean=self.norm_mean, std=self.norm_std)
# Make padded pixels = 0
x[mask] = 0
return x
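A minimal forward-pass sketch for DDNTemplate (illustrative; assumes a torchvision version contemporaneous with this commit that still accepts the pretrained/pretrained_backbone kwargs). With pretrained_path=None the preprocessing is a passthrough, and the returned dict carries the layer1 features plus depth-bin logits resized to the feature resolution.
import torch
import torchvision
ddn = DDNTemplate(constructor=torchvision.models.segmentation.deeplabv3_resnet50, feat_extract_layer="layer1", num_classes=81, pretrained_path=None)
result = ddn(torch.rand(1, 3, 384, 1280))  # dummy batch; 384x1280 is an assumed padded KITTI size
print(result["features"].shape)  # torch.Size([1, 256, 96, 320]): layer1 output at 1/4 resolution
print(result["logits"].shape)    # torch.Size([1, 81, 96, 320]): interpolated to the feature shape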
from .ddn_loss import DDNLoss
__all__ = {
"DDNLoss": DDNLoss
}
import torch
import torch.nn as nn
from pcdet.utils import loss_utils
class Balancer(nn.Module):
def __init__(self, fg_weight, bg_weight, downsample_factor=1):
"""
Initialize fixed foreground/background loss balancer
Args:
fg_weight: float, Foreground loss weight
bg_weight: float, Background loss weight
downsample_factor: int, Depth map downsample factor
"""
super().__init__()
self.fg_weight = fg_weight
self.bg_weight = bg_weight
self.downsample_factor = downsample_factor
def forward(self, loss, gt_boxes2d):
"""
Forward pass
Args:
loss: (B, H, W), Pixel-wise loss
gt_boxes2d: (B, N, 4), 2D box labels for foreground/background balancing
Returns:
loss: (1), Total loss after foreground/background balancing
tb_dict: dict[float], All losses to log in tensorboard
"""
# Compute masks
fg_mask = loss_utils.compute_fg_mask(gt_boxes2d=gt_boxes2d,
shape=loss.shape,
downsample_factor=self.downsample_factor,
device=loss.device)
bg_mask = ~fg_mask
# Compute balancing weights
weights = self.fg_weight * fg_mask + self.bg_weight * bg_mask
num_pixels = fg_mask.sum() + bg_mask.sum()
# Compute losses
loss *= weights
fg_loss = loss[fg_mask].sum() / num_pixels
bg_loss = loss[bg_mask].sum() / num_pixels
# Get total loss
loss = fg_loss + bg_loss
tb_dict = {"balancer_loss": loss.item(), "fg_loss": fg_loss.item(), "bg_loss": bg_loss.item()}
return loss, tb_dict
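A small CPU sketch of the balancing (illustrative; assumes pcdet is importable for the compute_fg_mask helper added to loss_utils later in this commit). The weights follow the LOSS.ARGS of the CaDDN config below (fg_weight=13, bg_weight=1, downsample_factor=4).
import torch
balancer = Balancer(fg_weight=13.0, bg_weight=1.0, downsample_factor=4)
loss_map = torch.rand(2, 94, 311)  # (B, H, W) pixel-wise loss at the downsampled resolution
boxes2d = torch.tensor([[[100.0, 80.0, 300.0, 200.0]]]).repeat(2, 1, 1)  # (B, N, 4), full-res pixels
total_loss, tb_dict = balancer(loss=loss_map, gt_boxes2d=boxes2d)
print(total_loss.item(), tb_dict)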
import torch
import torch.nn as nn
import kornia
from .balancer import Balancer
from pcdet.utils import transform_utils
class DDNLoss(nn.Module):
def __init__(self,
weight,
alpha,
gamma,
disc_cfg,
fg_weight,
bg_weight,
downsample_factor):
"""
Initializes DDNLoss module
Args:
weight: float, Loss function weight
alpha: float, Alpha value for Focal Loss
gamma: float, Gamma value for Focal Loss
disc_cfg: dict, Depth discretization configuration
fg_weight: float, Foreground loss weight
bg_weight: float, Background loss weight
downsample_factor: int, Depth map downsample factor
"""
super().__init__()
self.device = torch.cuda.current_device()
self.disc_cfg = disc_cfg
self.balancer = Balancer(downsample_factor=downsample_factor,
fg_weight=fg_weight,
bg_weight=bg_weight)
# Set loss function
self.alpha = alpha
self.gamma = gamma
self.loss_func = kornia.losses.FocalLoss(alpha=self.alpha, gamma=self.gamma, reduction="none")
self.weight = weight
def forward(self, depth_logits, depth_maps, gt_boxes2d):
"""
Gets DDN loss
Args:
depth_logits: (B, D+1, H, W), Predicted depth logits
depth_maps: (B, H, W), Depth map [m]
gt_boxes2d: torch.Tensor (B, N, 4), 2D box labels for foreground/background balancing
Returns:
loss: (1), Depth distribution network loss
tb_dict: dict[float], All losses to log in tensorboard
"""
tb_dict = {}
# Bin depth map to create target
depth_target = transform_utils.bin_depths(depth_maps, **self.disc_cfg, target=True)
# Compute loss
loss = self.loss_func(depth_logits, depth_target)
# Compute foreground/background balancing
loss, tb_dict = self.balancer(loss=loss, gt_boxes2d=gt_boxes2d)
# Final loss
loss *= self.weight
tb_dict.update({"ddn_loss": loss.item()})
return loss, tb_dict
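An end-to-end loss sketch (illustrative): disc_cfg mirrors the DISCRETIZE block of the config below. Note DDNLoss calls torch.cuda.current_device() in __init__, so this assumes a CUDA-enabled setup.
import torch
disc_cfg = {"mode": "LID", "num_bins": 80, "depth_min": 2.0, "depth_max": 46.8}
ddn_loss = DDNLoss(weight=3.0, alpha=0.25, gamma=2.0, disc_cfg=disc_cfg, fg_weight=13, bg_weight=1, downsample_factor=4)
depth_logits = torch.rand(2, 81, 94, 311)         # (B, D+1, H, W) predicted depth logits
depth_maps = 2.0 + 44.8 * torch.rand(2, 94, 311)  # (B, H, W) depths in [2.0, 46.8) m
boxes2d = torch.tensor([[[100.0, 80.0, 300.0, 200.0]]]).repeat(2, 1, 1)
loss, tb_dict = ddn_loss(depth_logits, depth_maps, boxes2d)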
import torch.nn as nn
import torch.nn.functional as F
from . import ddn, ddn_loss
from pcdet.models.model_utils.basic_block_2d import BasicBlock2D
class DepthFFN(nn.Module):
def __init__(self, model_cfg, downsample_factor):
"""
Initialize frustum feature network via depth distribution estimation
Args:
model_cfg: EasyDict, Depth classification network config
downsample_factor: int, Depth map downsample factor
"""
super().__init__()
self.model_cfg = model_cfg
self.disc_cfg = model_cfg.DISCRETIZE
self.downsample_factor = downsample_factor
# Create modules
self.ddn = ddn.__all__[model_cfg.DDN.NAME](
num_classes=self.disc_cfg["num_bins"] + 1,
backbone_name=model_cfg.DDN.BACKBONE_NAME,
**model_cfg.DDN.ARGS
)
self.channel_reduce = BasicBlock2D(**model_cfg.CHANNEL_REDUCE)
self.ddn_loss = ddn_loss.__all__[model_cfg.LOSS.NAME](
disc_cfg=self.disc_cfg,
downsample_factor=downsample_factor,
**model_cfg.LOSS.ARGS
)
self.forward_ret_dict = {}
def get_output_feature_dim(self):
return self.channel_reduce.out_channels
def forward(self, batch_dict):
"""
Predicts depths and creates image depth feature volume using depth distributions
Args:
batch_dict:
images: (N, 3, H_in, W_in), Input images
Returns:
batch_dict:
frustum_features: (N, C, D, H_out, W_out), Image depth features
"""
# Pixel-wise depth classification
images = batch_dict["images"]
ddn_result = self.ddn(images)
image_features = ddn_result["features"]
depth_logits = ddn_result["logits"]
# Channel reduce
if self.channel_reduce is not None:
image_features = self.channel_reduce(image_features)
# Create image feature plane-sweep volume
frustum_features = self.create_frustum_features(image_features=image_features,
depth_logits=depth_logits)
batch_dict["frustum_features"] = frustum_features
if self.training:
self.forward_ret_dict["depth_maps"] = batch_dict["depth_maps"]
self.forward_ret_dict["gt_boxes2d"] = batch_dict["gt_boxes2d"]
self.forward_ret_dict["depth_logits"] = depth_logits
return batch_dict
def create_frustum_features(self, image_features, depth_logits):
"""
Create image depth feature volume by multiplying image features with depth distributions
Args:
image_features: (N, C, H, W), Image features
depth_logits: (N, D+1, H, W), Depth classification logits
Returns:
frustum_features: (N, C, D, H, W), Image depth features
"""
channel_dim = 1
depth_dim = 2
# Resize to match dimensions
image_features = image_features.unsqueeze(depth_dim)
depth_logits = depth_logits.unsqueeze(channel_dim)
# Apply softmax along depth axis and remove last depth category (> Max Range)
depth_probs = F.softmax(depth_logits, dim=depth_dim)
depth_probs = depth_probs[:, :, :-1]
# Multiply to form image depth feature volume
frustum_features = depth_probs * image_features
return frustum_features
def get_loss(self):
"""
Gets DDN loss
Returns:
loss: (1), Depth distribution network loss
tb_dict: dict[float], All losses to log in tensorboard
"""
loss, tb_dict = self.ddn_loss(**self.forward_ret_dict)
return loss, tb_dict
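The frustum construction is a pure broadcast, so it can be sanity-checked standalone (illustrative shapes matching the config below: C=64 after channel reduce, D=80 depth bins):
import torch
import torch.nn.functional as F
image_features = torch.rand(2, 64, 94, 311)  # (N, C, H, W)
depth_logits = torch.rand(2, 81, 94, 311)    # (N, D+1, H, W)
depth_probs = F.softmax(depth_logits.unsqueeze(1), dim=2)[:, :, :-1]  # drop the > max-range bin
frustum_features = depth_probs * image_features.unsqueeze(2)
print(frustum_features.shape)  # torch.Size([2, 64, 80, 94, 311]) = (N, C, D, H, W)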
@@ -50,7 +50,7 @@ class PFNLayer(nn.Module):
class PillarVFE(VFETemplate):
-def __init__(self, model_cfg, num_point_features, voxel_size, point_cloud_range):
+def __init__(self, model_cfg, num_point_features, voxel_size, point_cloud_range, **kwargs):
super().__init__(model_cfg=model_cfg)
self.use_norm = self.model_cfg.USE_NORM
......
@@ -5,6 +5,7 @@ from .pointpillar import PointPillar
from .pv_rcnn import PVRCNN
from .second_net import SECONDNet
from .second_net_iou import SECONDNetIoU
+from .caddn import CaDDN
__all__ = {
'Detector3DTemplate': Detector3DTemplate,
@@ -13,7 +14,8 @@ __all__ = {
'PVRCNN': PVRCNN,
'PointPillar': PointPillar,
'PointRCNN': PointRCNN,
-'SECONDNetIoU': SECONDNetIoU
+'SECONDNetIoU': SECONDNetIoU,
+'CaDDN': CaDDN
}
......
from .detector3d_template import Detector3DTemplate
class CaDDN(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self):
disp_dict = {}
loss_rpn, tb_dict_rpn = self.dense_head.get_loss()
loss_depth, tb_dict_depth = self.vfe.get_loss()
tb_dict = {
'loss_rpn': loss_rpn.item(),
'loss_depth': loss_depth.item(),
**tb_dict_rpn,
**tb_dict_depth
}
loss = loss_rpn + loss_depth
return loss, tb_dict, disp_dict
@@ -38,7 +38,8 @@ class Detector3DTemplate(nn.Module):
'num_point_features': self.dataset.point_feature_encoder.num_point_features,
'grid_size': self.dataset.grid_size,
'point_cloud_range': self.dataset.point_cloud_range,
-'voxel_size': self.dataset.voxel_size
+'voxel_size': self.dataset.voxel_size,
+'depth_downsample_factor': self.dataset.depth_downsample_factor
}
for module_name in self.module_topology:
module, model_info_dict = getattr(self, 'build_%s' % module_name)(
@@ -55,7 +56,9 @@ class Detector3DTemplate(nn.Module):
model_cfg=self.model_cfg.VFE,
num_point_features=model_info_dict['num_rawpoint_features'],
point_cloud_range=model_info_dict['point_cloud_range'],
-voxel_size=model_info_dict['voxel_size']
+voxel_size=model_info_dict['voxel_size'],
+grid_size=model_info_dict['grid_size'],
+depth_downsample_factor=model_info_dict['depth_downsample_factor']
)
model_info_dict['num_point_features'] = vfe_module.get_output_feature_dim()
model_info_dict['module_list'].append(vfe_module)
......
import torch.nn as nn
class BasicBlock2D(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
"""
Initializes convolutional block
Args:
in_channels: int, Number of input channels
out_channels: int, Number of output channels
**kwargs: Dict, Extra arguments for nn.Conv2d
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
**kwargs)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, features):
"""
Applies convolutional block
Args:
features: (B, C_in, H, W), Input features
Returns:
x: (B, C_out, H, W), Output features
"""
x = self.conv(features)
x = self.bn(x)
x = self.relu(x)
return x
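Usage sketch (illustrative), matching the CHANNEL_REDUCE block of the CaDDN config below:
import torch
block = BasicBlock2D(in_channels=256, out_channels=64, kernel_size=1, stride=1, bias=False)
out = block(torch.rand(2, 256, 94, 311))
print(out.shape)  # torch.Size([2, 64, 94, 311])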
@@ -105,6 +105,24 @@ def set_random_seed(seed):
torch.backends.cudnn.benchmark = False
def get_pad_params(desired_size, cur_size):
"""
Get padding parameters for np.pad function
Args:
desired_size: int, Desired padded output size
cur_size: int, Current size. Should always be less than or equal to desired_size
Returns:
pad_params: tuple(int), Number of values padded to the edges (before, after)
"""
assert desired_size >= cur_size
# Calculate amount to pad
diff = desired_size - cur_size
pad_params = (0, diff)
return pad_params
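For example (illustrative sizes), padding a raw KITTI image up to a fixed network input with np.pad:
import numpy as np
image = np.zeros((375, 1242, 3), dtype=np.float32)  # assumed raw KITTI image size
pad_h = get_pad_params(desired_size=384, cur_size=image.shape[0])   # (0, 9)
pad_w = get_pad_params(desired_size=1280, cur_size=image.shape[1])  # (0, 38)
image = np.pad(image, pad_width=(pad_h, pad_w, (0, 0)), mode='constant')
print(image.shape)  # (384, 1280, 3)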
def keep_arrays_by_name(gt_names, used_classes):
inds = [i for i, x in enumerate(gt_names) if x in used_classes]
inds = np.array(inds, dtype=np.int64)
......
@@ -230,3 +230,32 @@ def get_corner_loss_lidar(pred_bbox3d: torch.Tensor, gt_bbox3d: torch.Tensor):
corner_loss = WeightedSmoothL1Loss.smooth_l1_loss(corner_dist, beta=1.0)
return corner_loss.mean(dim=1)
def compute_fg_mask(gt_boxes2d, shape, downsample_factor=1, device=torch.device("cpu")):
"""
Compute foreground mask for images
Args:
gt_boxes2d: (B, N, 4), 2D box labels
shape: torch.Size or tuple, Foreground mask desired shape
downsample_factor: int, Downsample factor for image
device: torch.device, Foreground mask desired device
Returns:
fg_mask (shape), Foreground mask
"""
fg_mask = torch.zeros(shape, dtype=torch.bool, device=device)
# Set box corners
gt_boxes2d /= downsample_factor
gt_boxes2d[:, :, :2] = torch.floor(gt_boxes2d[:, :, :2])
gt_boxes2d[:, :, 2:] = torch.ceil(gt_boxes2d[:, :, 2:])
gt_boxes2d = gt_boxes2d.long()
# Set all values within each box to True
B, N = gt_boxes2d.shape[:2]
for b in range(B):
for n in range(N):
u1, v1, u2, v2 = gt_boxes2d[b, n]
fg_mask[b, v1:v2, u1:u2] = True
return fg_mask
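A direct check (illustrative): one box in full-resolution pixel coordinates against a 4x-downsampled mask. Note the helper scales gt_boxes2d in place, hence the clone:
import torch
boxes2d = torch.tensor([[[100.0, 80.0, 300.0, 200.0]]])  # (B=1, N=1, 4) as (u1, v1, u2, v2)
fg_mask = compute_fg_mask(boxes2d.clone(), shape=(1, 94, 311), downsample_factor=4)
print(fg_mask.sum().item())  # 1500: a 30 x 50 region of the downsampled map is foreground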
import math
import torch
import kornia
def project_to_image(project, points):
"""
Project points to image
Args:
project: (..., 3, 4), Projection matrix
points: (..., 3), 3D points
Returns:
points_img: (..., 2), Points in image
points_depth: (...), Depth of each point
"""
# Reshape tensors to expected shape
points = kornia.convert_points_to_homogeneous(points)
points = points.unsqueeze(dim=-1)
project = project.unsqueeze(dim=1)
# Transform points to image and get depths
points_t = project @ points
points_t = points_t.squeeze(dim=-1)
points_img = kornia.convert_points_from_homogeneous(points_t)
points_depth = points_t[..., -1] - project[..., 2, 3]
return points_img, points_depth
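A pinhole sanity check (illustrative, with made-up intrinsics): for P = K[I | 0], a camera-frame point at depth 10 m projects to K-scaled pixel coordinates and reports depth 10:
import torch
K = torch.tensor([[700.0, 0.0, 600.0], [0.0, 700.0, 180.0], [0.0, 0.0, 1.0]])
project = torch.cat((K, torch.zeros(3, 1)), dim=1).unsqueeze(0)  # (B=1, 3, 4)
points = torch.tensor([[[2.0, 1.0, 10.0]]])                      # (B=1, N=1, 3)
points_img, points_depth = project_to_image(project, points)
print(points_img)    # tensor([[[740., 250.]]]) = (700*2/10 + 600, 700*1/10 + 180)
print(points_depth)  # tensor([[10.]])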
def normalize_coords(coords, shape):
"""
Normalize coordinates of a grid between [-1, 1]
Args:
coords: (..., 3), Coordinates in grid
shape: (3), Grid shape
Returns:
norm_coords: (..., 3), Normalized coordinates in grid
"""
min_n = -1
max_n = 1
shape = torch.flip(shape, dims=[0]) # Reverse ordering of shape
# Subtract 1 since pixel indexing from [0, shape - 1]
norm_coords = coords / (shape - 1) * (max_n - min_n) + min_n
return norm_coords
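Illustrative: coordinates come in (x, y, z) order while the grid shape is (D, H, W), which is why the shape is flipped before scaling to [-1, 1] (the convention F.grid_sample expects):
import torch
grid_shape = torch.tensor([80, 94, 311])  # (D, H, W), e.g. the frustum grid
coords = torch.tensor([[0.0, 0.0, 0.0], [310.0, 93.0, 79.0]])  # (x, y, z) voxel coordinates
print(normalize_coords(coords, grid_shape))  # tensor([[-1., -1., -1.], [ 1.,  1.,  1.]])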
def bin_depths(depth_map, mode, depth_min, depth_max, num_bins, target=False):
"""
Converts depth map into bin indices
Args:
depth_map: (H, W), Depth Map
mode: string, Discretization mode (See https://arxiv.org/pdf/2005.13423.pdf for more details)
UD: Uniform discretization
LID: Linear-increasing discretization
SID: Spacing-increasing discretization
depth_min: float, Minimum depth value
depth_max: float, Maximum depth value
num_bins: int, Number of depth bins
target: bool, Whether the depth bin indices will be used as a target tensor in loss computation
Returns:
indices: (H, W), Depth bin indices
"""
if mode == "UD":
bin_size = (depth_max - depth_min) / num_bins
indices = ((depth_map - depth_min) / bin_size)
elif mode == "LID":
bin_size = 2 * (depth_max - depth_min) / (num_bins * (1 + num_bins))
indices = -0.5 + 0.5 * torch.sqrt(1 + 8 * (depth_map - depth_min) / bin_size)
elif mode == "SID":
indices = num_bins * (torch.log(1 + depth_map) - math.log(1 + depth_min)) / \
(math.log(1 + depth_max) - math.log(1 + depth_min))
else:
raise NotImplementedError
if target:
# Assign indices outside of bounds (or non-finite) to the extra out-of-range bin
mask = (indices < 0) | (indices > num_bins) | (~torch.isfinite(indices))
indices[mask] = num_bins
# Convert to integer
indices = indices.type(torch.int64)
return indices
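A worked LID example (illustrative) with the DISCRETIZE values from the config below: bin_size = 2 * (46.8 - 2.0) / (80 * 81) ≈ 0.0138 m, and bin widths grow linearly with the bin index. With target=True, out-of-range or invalid depths land in the extra bin num_bins:
import torch
depth_map = torch.tensor([[1.0, 2.0, 10.0, 46.0, 60.0]])  # [m]; 1.0 and 60.0 are out of range
indices = bin_depths(depth_map, mode="LID", depth_min=2.0, depth_max=46.8, num_bins=80, target=True)
print(indices)  # tensor([[80,  0, 33, 79, 80]])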
@@ -13,9 +13,9 @@ INFO_PATH: {
'test': [kitti_infos_val.pkl],
}
GET_ITEM_LIST: ["points"]
FOV_POINTS_ONLY: True
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
......
CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/kitti_dataset.yaml
POINT_CLOUD_RANGE: [2, -30.08, -3.0, 46.8, 30.08, 1.0]
GET_ITEM_LIST: ["images", "depth_maps", "calib_matricies", "gt_boxes2d"]
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: calculate_grid_size
VOXEL_SIZE: [0.16, 0.16, 0.16]
- NAME: downsample_depth_map
DOWNSAMPLE_FACTOR: 4
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
- NAME: random_image_flip
ALONG_AXIS_LIST: ['horizontal']
MODEL:
NAME: CaDDN
VFE:
NAME: ImageVFE
FFN:
NAME: DepthFFN
DDN:
NAME: DDNDeepLabV3
BACKBONE_NAME: ResNet101
ARGS: {
"feat_extract_layer": "layer1",
"pretrained_path": "../checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth"
}
CHANNEL_REDUCE: {
"in_channels": 256,
"out_channels": 64,
"kernel_size": 1,
"stride": 1,
"bias": False
}
DISCRETIZE: {
"mode": LID,
"num_bins": 80,
"depth_min": 2.0,
"depth_max": 46.8
}
LOSS:
NAME: DDNLoss
ARGS: {
'weight': 3.0,
'alpha': 0.25,
'gamma': 2.0,
'fg_weight': 13,
'bg_weight': 1
}
F2V:
NAME: FrustumToVoxel
SAMPLER: {
"mode": "bilinear",
"padding_mode": "zeros"
}
MAP_TO_BEV:
NAME: Conv2DCollapse
NUM_BEV_FEATURES: 64
ARGS: {
"kernel_size": 1,
"stride": 1,
"bias": False
}
BACKBONE_2D:
NAME: BaseBEVBackbone
LAYER_NUMS: [10, 10, 10]
LAYER_STRIDES: [2, 2, 2]
NUM_FILTERS: [64, 128, 256]
UPSAMPLE_STRIDES: [1, 2, 4]
NUM_UPSAMPLE_FILTERS: [128, 128, 128]
DENSE_HEAD:
NAME: AnchorHeadSingle
CLASS_AGNOSTIC: False
USE_DIRECTION_CLASSIFIER: True
DIR_OFFSET: 0.78539
DIR_LIMIT_OFFSET: 0.0
NUM_DIR_BINS: 2
ANCHOR_GENERATOR_CONFIG: [
{
'class_name': 'Car',
'anchor_sizes': [[3.9, 1.6, 1.56]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.78],
'align_center': False,
'feature_map_stride': 2,
'matched_threshold': 0.6,
'unmatched_threshold': 0.45
},
{
'class_name': 'Pedestrian',
'anchor_sizes': [[0.8, 0.6, 1.73]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.6],
'align_center': False,
'feature_map_stride': 2,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
},
{
'class_name': 'Cyclist',
'anchor_sizes': [[1.76, 0.6, 1.73]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.6],
'align_center': False,
'feature_map_stride': 2,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
}
]
TARGET_ASSIGNER_CONFIG:
NAME: AxisAlignedTargetAssigner
POS_FRACTION: -1.0
SAMPLE_SIZE: 512
NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False
BOX_CODER: ResidualCoder
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 2.0,
'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
SCORE_THRESH: 0.1
OUTPUT_RAW_SCORE: False
EVAL_METRIC: kitti
NMS_CONFIG:
MULTI_CLASSES_NMS: False
NMS_TYPE: nms_gpu
NMS_THRESH: 0.01
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 500
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 4
NUM_EPOCHS: 80
OPTIMIZER: adam_onecycle
LR: 0.001
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 10